diff --git a/api/tacticalrmm/automation/models.py b/api/tacticalrmm/automation/models.py index 2b30dd51..1db46c12 100644 --- a/api/tacticalrmm/automation/models.py +++ b/api/tacticalrmm/automation/models.py @@ -78,6 +78,7 @@ class Policy(models.Model): # List of all tasks to be applied tasks = list() + added_task_pks = list() # Get policies applied to agent and agent site and client client = Client.objects.get(client=agent.client) @@ -97,13 +98,19 @@ class Policy(models.Model): if agent_policy and agent_policy.active: for task in agent_policy.autotasks.all(): - tasks.append(task) + if task.pk not in added_task_pks: + tasks.append(task) + added_task_pks.append(task.pk) if site_policy and site_policy.active: for task in site_policy.autotasks.all(): - tasks.append(task) + if task.pk not in added_task_pks: + tasks.append(task) + added_task_pks.append(task.pk) if client_policy and client_policy.active: for task in client_policy.autotasks.all(): - tasks.append(task) + if task.pk not in added_task_pks: + tasks.append(task) + added_task_pks.append(task.pk) return tasks diff --git a/api/tacticalrmm/automation/tasks.py b/api/tacticalrmm/automation/tasks.py index cc339d1e..c61bd8a9 100644 --- a/api/tacticalrmm/automation/tasks.py +++ b/api/tacticalrmm/automation/tasks.py @@ -9,7 +9,7 @@ from tacticalrmm.celery import app @app.task def generate_agent_checks_from_policies_task( - policypk, many=False, clear=False, parent_checks=[] + policypk, many=False, clear=False, parent_checks=[], create_tasks=False ): if many: @@ -19,22 +19,29 @@ def generate_agent_checks_from_policies_task( agent.generate_checks_from_policies( clear=clear, parent_checks=parent_checks ) + if create_tasks: + agent.generate_tasks_from_policies(clear=clear,) else: policy = Policy.objects.get(pk=policypk) for agent in policy.related_agents(): agent.generate_checks_from_policies( clear=clear, parent_checks=parent_checks ) + if create_tasks: + agent.generate_tasks_from_policies(clear=clear,) @app.task def generate_agent_checks_by_location_task( - location, mon_type, clear=False, parent_checks=[] + location, mon_type, clear=False, parent_checks=[], create_tasks=False ): for agent in Agent.objects.filter(**location).filter(monitoring_type=mon_type): agent.generate_checks_from_policies(clear=clear, parent_checks=parent_checks) + if create_tasks: + agent.generate_tasks_from_policies(clear=clear) + @app.task def delete_policy_check_task(checkpk): diff --git a/api/tacticalrmm/automation/tests.py b/api/tacticalrmm/automation/tests.py index 904ff884..ee16e4e2 100644 --- a/api/tacticalrmm/automation/tests.py +++ b/api/tacticalrmm/automation/tests.py @@ -68,8 +68,7 @@ class TestPolicyViews(BaseTestCase): self.check_not_authenticated("post", url) @patch("automation.tasks.generate_agent_checks_from_policies_task.delay") - @patch("automation.tasks.generate_agent_tasks_from_policies_task.delay") - def test_update_policy(self, mock_tasks_task, mock_checks_task): + def test_update_policy(self, mock_checks_task): url = f"/automation/policies/{self.policy.pk}/" valid_payload = { @@ -84,7 +83,6 @@ class TestPolicyViews(BaseTestCase): # only called if active or enforced are updated mock_checks_task.assert_not_called() - mock_tasks_task.assert_not_called() valid_payload = { "name": "Test Policy Update", @@ -95,8 +93,9 @@ class TestPolicyViews(BaseTestCase): resp = self.client.put(url, valid_payload, format="json") self.assertEqual(resp.status_code, 200) - mock_checks_task.assert_called_with(policypk=self.policy.pk, clear=True) - mock_tasks_task.assert_called_with(policypk=self.policy.pk, clear=True) + mock_checks_task.assert_called_with( + policypk=self.policy.pk, clear=True, create_tasks=True + ) self.check_not_authenticated("put", url) @@ -178,14 +177,8 @@ class TestPolicyViews(BaseTestCase): @patch("agents.models.Agent.generate_checks_from_policies") @patch("automation.tasks.generate_agent_checks_by_location_task.delay") - @patch("agents.models.Agent.generate_tasks_from_policies") - @patch("automation.tasks.generate_agent_tasks_by_location_task.delay") def test_update_policy_add( - self, - mock_tasks_location_task, - mock_tasks_task, - mock_checks_location_task, - mock_checks_task, + self, mock_checks_location_task, mock_checks_task, ): url = f"/automation/related/" @@ -227,30 +220,26 @@ class TestPolicyViews(BaseTestCase): # called because the relation changed mock_checks_location_task.assert_called_with( - location={"client": client.client}, mon_type="server", clear=True + location={"client": client.client}, + mon_type="server", + clear=True, + create_tasks=True, ) mock_checks_location_task.reset_mock() - mock_tasks_location_task.assert_called_with( - location={"client": client.client}, mon_type="server", clear=True - ) - mock_tasks_location_task.reset_mock() - # test client workstation policy add resp = self.client.post(url, client_workstation_payload, format="json") self.assertEqual(resp.status_code, 200) # called because the relation changed mock_checks_location_task.assert_called_with( - location={"client": client.client}, mon_type="workstation", clear=True + location={"client": client.client}, + mon_type="workstation", + clear=True, + create_tasks=True, ) mock_checks_location_task.reset_mock() - mock_tasks_location_task.assert_called_with( - location={"client": client.client}, mon_type="workstation", clear=True - ) - mock_tasks_location_task.reset_mock() - # test site add server policy resp = self.client.post(url, site_server_payload, format="json") self.assertEqual(resp.status_code, 200) @@ -260,16 +249,10 @@ class TestPolicyViews(BaseTestCase): location={"client": site.client.client, "site": site.site}, mon_type="server", clear=True, + create_tasks=True, ) mock_checks_location_task.reset_mock() - mock_tasks_location_task.assert_called_with( - location={"client": site.client.client, "site": site.site}, - mon_type="server", - clear=True, - ) - mock_tasks_location_task.reset_mock() - # test site add workstation policy resp = self.client.post(url, site_workstation_payload, format="json") self.assertEqual(resp.status_code, 200) @@ -279,16 +262,10 @@ class TestPolicyViews(BaseTestCase): location={"client": site.client.client, "site": site.site}, mon_type="workstation", clear=True, + create_tasks=True, ) mock_checks_location_task.reset_mock() - mock_tasks_location_task.assert_called_with( - location={"client": site.client.client, "site": site.site}, - mon_type="workstation", - clear=True, - ) - mock_tasks_location_task.reset_mock() - # test agent add resp = self.client.post(url, agent_payload, format="json") self.assertEqual(resp.status_code, 200) @@ -297,30 +274,24 @@ class TestPolicyViews(BaseTestCase): mock_checks_task.assert_called_with(clear=True) mock_checks_task.reset_mock() - mock_tasks_task.assert_called_with(clear=True) - mock_tasks_task.reset_mock() - # Adding the same relations shouldn't trigger mocks resp = self.client.post(url, client_server_payload, format="json") self.assertEqual(resp.status_code, 200) resp = self.client.post(url, client_workstation_payload, format="json") self.assertEqual(resp.status_code, 200) mock_checks_location_task.assert_not_called() - mock_tasks_location_task.assert_not_called() resp = self.client.post(url, site_server_payload, format="json") self.assertEqual(resp.status_code, 200) resp = self.client.post(url, site_workstation_payload, format="json") self.assertEqual(resp.status_code, 200) mock_checks_location_task.assert_not_called() - mock_tasks_location_task.assert_not_called() resp = self.client.post(url, agent_payload, format="json") self.assertEqual(resp.status_code, 200) # called because the relation changed mock_checks_task.assert_not_called() - mock_tasks_task.assert_not_called() # test remove client from policy data client_server_payload = {"type": "client", "pk": client.pk, "server_policy": 0} @@ -347,30 +318,26 @@ class TestPolicyViews(BaseTestCase): # called because the relation changed mock_checks_location_task.assert_called_with( - location={"client": client.client}, mon_type="server", clear=True + location={"client": client.client}, + mon_type="server", + clear=True, + create_tasks=True, ) mock_checks_location_task.reset_mock() - mock_tasks_location_task.assert_called_with( - location={"client": client.client}, mon_type="server", clear=True - ) - mock_tasks_location_task.reset_mock() - # test client workstation policy remove resp = self.client.post(url, client_workstation_payload, format="json") self.assertEqual(resp.status_code, 200) # called because the relation changed mock_checks_location_task.assert_called_with( - location={"client": client.client}, mon_type="workstation", clear=True + location={"client": client.client}, + mon_type="workstation", + clear=True, + create_tasks=True, ) mock_checks_location_task.reset_mock() - mock_tasks_location_task.assert_called_with( - location={"client": client.client}, mon_type="workstation", clear=True - ) - mock_tasks_location_task.reset_mock() - # test site remove server policy resp = self.client.post(url, site_server_payload, format="json") self.assertEqual(resp.status_code, 200) @@ -380,16 +347,10 @@ class TestPolicyViews(BaseTestCase): location={"client": site.client.client, "site": site.site}, mon_type="server", clear=True, + create_tasks=True, ) mock_checks_location_task.reset_mock() - mock_tasks_location_task.assert_called_with( - location={"client": site.client.client, "site": site.site}, - mon_type="server", - clear=True, - ) - mock_tasks_location_task.reset_mock() - # test site remove workstation policy resp = self.client.post(url, site_workstation_payload, format="json") self.assertEqual(resp.status_code, 200) @@ -399,16 +360,10 @@ class TestPolicyViews(BaseTestCase): location={"client": site.client.client, "site": site.site}, mon_type="workstation", clear=True, + create_tasks=True, ) mock_checks_location_task.reset_mock() - mock_tasks_location_task.assert_called_with( - location={"client": site.client.client, "site": site.site}, - mon_type="workstation", - clear=True, - ) - mock_tasks_location_task.reset_mock() - # test agent remove resp = self.client.post(url, agent_payload, format="json") self.assertEqual(resp.status_code, 200) @@ -416,9 +371,6 @@ class TestPolicyViews(BaseTestCase): mock_checks_task.assert_called_with(clear=True) mock_checks_task.reset_mock() - mock_tasks_task.assert_called_with(clear=True) - mock_tasks_task.reset_mock() - # adding the same relations shouldn't trigger mocks resp = self.client.post(url, client_server_payload, format="json") self.assertEqual(resp.status_code, 200) @@ -428,7 +380,6 @@ class TestPolicyViews(BaseTestCase): # shouldn't be called since nothing changed mock_checks_location_task.assert_not_called() - mock_tasks_location_task.assert_not_called() resp = self.client.post(url, site_server_payload, format="json") self.assertEqual(resp.status_code, 200) @@ -438,14 +389,12 @@ class TestPolicyViews(BaseTestCase): # shouldn't be called since nothing changed mock_checks_location_task.assert_not_called() - mock_tasks_location_task.assert_not_called() resp = self.client.post(url, agent_payload, format="json") self.assertEqual(resp.status_code, 200) # shouldn't be called since nothing changed mock_checks_task.assert_not_called() - mock_tasks_task.assert_not_called() self.check_not_authenticated("post", url) diff --git a/api/tacticalrmm/automation/views.py b/api/tacticalrmm/automation/views.py index 036ceb31..4ea40b3e 100644 --- a/api/tacticalrmm/automation/views.py +++ b/api/tacticalrmm/automation/views.py @@ -1,6 +1,8 @@ from django.db import DataError from django.shortcuts import get_object_or_404 +from celery import chain + from rest_framework.views import APIView from rest_framework.response import Response from rest_framework import status @@ -74,12 +76,7 @@ class GetUpdateDeletePolicy(APIView): generate_agent_checks_from_policies_task.delay( policypk=policy.pk, clear=(not saved_policy.active or not saved_policy.enforced), - ) - - # Genereate agent tasks if active was changed - if saved_policy.active != old_active: - generate_agent_tasks_from_policies_task.delay( - policypk=policy.pk, clear=(not saved_policy.active) + create_tasks=(saved_policy.active != old_active), ) return Response("ok") @@ -202,15 +199,12 @@ class GetRelated(APIView): ): client.workstation_policy = policy client.save() + generate_agent_checks_by_location_task.delay( location={"client": client.client}, mon_type="workstation", clear=True, - ) - generate_agent_tasks_by_location_task.delay( - location={"client": client.client}, - mon_type="workstation", - clear=True, + create_tasks=True, ) if related_type == "site": @@ -228,11 +222,7 @@ class GetRelated(APIView): location={"client": site.client.client, "site": site.site}, mon_type="workstation", clear=True, - ) - generate_agent_tasks_by_location_task.delay( - location={"client": site.client.client, "site": site.site}, - mon_type="workstation", - clear=True, + create_tasks=True, ) # server policy is set @@ -254,11 +244,7 @@ class GetRelated(APIView): location={"client": client.client}, mon_type="server", clear=True, - ) - generate_agent_tasks_by_location_task.delay( - location={"client": client.client}, - mon_type="server", - clear=True, + create_tasks=True, ) if related_type == "site": @@ -276,11 +262,7 @@ class GetRelated(APIView): location={"client": site.client.client, "site": site.site}, mon_type="server", clear=True, - ) - generate_agent_tasks_by_location_task.delay( - location={"client": site.client.client, "site": site.site}, - mon_type="server", - clear=True, + create_tasks=True, ) # If workstation policy was cleared @@ -300,11 +282,7 @@ class GetRelated(APIView): location={"client": client.client}, mon_type="workstation", clear=True, - ) - generate_agent_tasks_by_location_task.delay( - location={"client": client.client}, - mon_type="workstation", - clear=True, + create_tasks=True, ) if related_type == "site": @@ -319,11 +297,7 @@ class GetRelated(APIView): location={"client": site.client.client, "site": site.site}, mon_type="workstation", clear=True, - ) - generate_agent_tasks_by_location_task.delay( - location={"client": site.client.client, "site": site.site}, - mon_type="workstation", - clear=True, + create_tasks=True, ) # server policy cleared @@ -341,11 +315,7 @@ class GetRelated(APIView): location={"client": client.client}, mon_type="server", clear=True, - ) - generate_agent_tasks_by_location_task.delay( - location={"client": client.client}, - mon_type="server", - clear=True, + create_tasks=True, ) if related_type == "site": @@ -359,11 +329,7 @@ class GetRelated(APIView): location={"client": site.client.client, "site": site.site}, mon_type="server", clear=True, - ) - generate_agent_tasks_by_location_task.delay( - location={"client": site.client.client, "site": site.site}, - mon_type="server", - clear=True, + create_tasks=True, ) # agent policies diff --git a/api/tacticalrmm/autotasks/models.py b/api/tacticalrmm/autotasks/models.py index 1aff72e4..645afd5a 100644 --- a/api/tacticalrmm/autotasks/models.py +++ b/api/tacticalrmm/autotasks/models.py @@ -109,7 +109,10 @@ class AutomatedTask(models.Model): return "TacticalRMM_" + "".join(random.choice(chars) for i in range(35)) def create_policy_task(self, agent): - assigned_check = agent.agentchecks.get(parent_check=self.assigned_check.pk) + assigned_check = None + if self.assigned_check: + assigned_check = agent.agentchecks.get(parent_check=self.assigned_check.pk) + task = AutomatedTask.objects.create( agent=agent, managed_by_policy=True,