diff --git a/api/tacticalrmm/agents/tests.py b/api/tacticalrmm/agents/tests.py index 9213df26..2c53f947 100644 --- a/api/tacticalrmm/agents/tests.py +++ b/api/tacticalrmm/agents/tests.py @@ -7,9 +7,7 @@ from itertools import cycle from django.conf import settings from django.utils import timezone as djangotime -from rest_framework.authtoken.models import Token -from accounts.models import User from tacticalrmm.test import TacticalTestCase from .serializers import AgentSerializer from winupdate.serializers import WinUpdatePolicySerializer @@ -530,12 +528,14 @@ class TestAgentViews(TacticalTestCase): self.check_not_authenticated("get", url) @patch("winupdate.tasks.bulk_check_for_updates_task.delay") + @patch("scripts.tasks.handle_bulk_script_task.delay") + @patch("scripts.tasks.handle_bulk_command_task.delay") @patch("agents.models.Agent.salt_batch_async") - def test_bulk_cmd_script(self, mock_ret, mock_update): + def test_bulk_cmd_script( + self, salt_batch_async, bulk_command, bulk_script, mock_update + ): url = "/agents/bulk/" - mock_ret.return_value = "ok" - payload = { "mode": "command", "target": "agents", @@ -550,6 +550,7 @@ class TestAgentViews(TacticalTestCase): } r = self.client.post(url, payload, format="json") + bulk_command.assert_called_with([self.agent.pk], "gpupdate /force", "cmd", 300) self.assertEqual(r.status_code, 200) payload = { @@ -581,6 +582,7 @@ class TestAgentViews(TacticalTestCase): r = self.client.post(url, payload, format="json") self.assertEqual(r.status_code, 200) + bulk_command.assert_called_with([self.agent.pk], "gpupdate /force", "cmd", 300) payload = { "mode": "command", @@ -597,12 +599,7 @@ class TestAgentViews(TacticalTestCase): r = self.client.post(url, payload, format="json") self.assertEqual(r.status_code, 200) - - mock_ret.return_value = "timeout" - payload["client"] = self.agent.client.id - payload["site"] = self.agent.site.id - r = self.client.post(url, payload, format="json") - self.assertEqual(r.status_code, 400) + bulk_command.assert_called_with([self.agent.pk], "gpupdate /force", "cmd", 300) payload = { "mode": "scan", @@ -613,9 +610,8 @@ class TestAgentViews(TacticalTestCase): self.agent.pk, ], } - mock_ret.return_value = "ok" r = self.client.post(url, payload, format="json") - mock_update.assert_called_once() + mock_update.assert_called_with(minions=[self.agent.salt_id]) self.assertEqual(r.status_code, 200) payload = { @@ -627,6 +623,7 @@ class TestAgentViews(TacticalTestCase): self.agent.pk, ], } + salt_batch_async.return_value = "ok" r = self.client.post(url, payload, format="json") self.assertEqual(r.status_code, 200) diff --git a/api/tacticalrmm/agents/views.py b/api/tacticalrmm/agents/views.py index 9c8bd3ac..1cbb5811 100644 --- a/api/tacticalrmm/agents/views.py +++ b/api/tacticalrmm/agents/views.py @@ -32,7 +32,7 @@ from winupdate.serializers import WinUpdatePolicySerializer from .tasks import uninstall_agent_task, send_agent_update_task from winupdate.tasks import bulk_check_for_updates_task -from scripts.tasks import run_bulk_script_task +from scripts.tasks import handle_bulk_command_task, handle_bulk_script_task from tacticalrmm.utils import notify_error, reload_nats @@ -760,73 +760,44 @@ def bulk(request): return notify_error("Must select at least 1 agent") if request.data["target"] == "client": - agents = Agent.objects.filter(site__client_id=request.data["client"]) + q = Agent.objects.filter(site__client_id=request.data["client"]) elif request.data["target"] == "site": - agents = Agent.objects.filter(site_id=request.data["site"]) + q = Agent.objects.filter(site_id=request.data["site"]) elif request.data["target"] == "agents": - agents = Agent.objects.filter(pk__in=request.data["agentPKs"]) + q = Agent.objects.filter(pk__in=request.data["agentPKs"]) elif request.data["target"] == "all": - agents = Agent.objects.all() + q = Agent.objects.all() else: return notify_error("Something went wrong") - minions = [agent.salt_id for agent in agents] + minions = [agent.salt_id for agent in q] + agents = [agent.pk for agent in q] AuditLog.audit_bulk_action(request.user, request.data["mode"], request.data) if request.data["mode"] == "command": - r = Agent.salt_batch_async( - minions=minions, - func="cmd.run_bg", - kwargs={ - "cmd": request.data["cmd"], - "shell": request.data["shell"], - "timeout": request.data["timeout"], - }, + handle_bulk_command_task.delay( + agents, request.data["cmd"], request.data["shell"], request.data["timeout"] ) - if r == "timeout": - return notify_error("Salt API not running") - return Response(f"Command will now be run on {len(minions)} agents") + return Response(f"Command will now be run on {len(agents)} agents") elif request.data["mode"] == "script": script = get_object_or_404(Script, pk=request.data["scriptPK"]) - - if script.shell == "python": - r = Agent.salt_batch_async( - minions=minions, - func="win_agent.run_script", - kwargs={ - "filepath": script.filepath, - "filename": script.filename, - "shell": script.shell, - "timeout": request.data["timeout"], - "args": request.data["args"], - "bg": True, - }, - ) - if r == "timeout": - return notify_error("Salt API not running") - else: - data = { - "minions": minions, - "scriptpk": script.pk, - "timeout": request.data["timeout"], - "args": request.data["args"], - } - run_bulk_script_task.delay(data) - - return Response(f"{script.name} will now be run on {len(minions)} agents") + handle_bulk_script_task.delay( + script.pk, agents, request.data["args"], request.data["timeout"] + ) + return Response(f"{script.name} will now be run on {len(agents)} agents") elif request.data["mode"] == "install": r = Agent.salt_batch_async(minions=minions, func="win_agent.install_updates") if r == "timeout": return notify_error("Salt API not running") return Response( - f"Pending updates will now be installed on {len(minions)} agents" + f"Pending updates will now be installed on {len(agents)} agents" ) elif request.data["mode"] == "scan": bulk_check_for_updates_task.delay(minions=minions) - return Response(f"Patch status scan will now run on {len(minions)} agents") + return Response(f"Patch status scan will now run on {len(agents)} agents") return notify_error("Something went wrong") diff --git a/api/tacticalrmm/scripts/tasks.py b/api/tacticalrmm/scripts/tasks.py index cff93b71..31d0320b 100644 --- a/api/tacticalrmm/scripts/tasks.py +++ b/api/tacticalrmm/scripts/tasks.py @@ -1,21 +1,74 @@ +import asyncio + from tacticalrmm.celery import app from agents.models import Agent -from .models import Script +from scripts.models import Script @app.task -def run_bulk_script_task(data): - # for powershell and batch scripts only, workaround for salt bg script bug - script = Script.objects.get(pk=data["scriptpk"]) +def handle_bulk_command_task(agentpks, cmd, shell, timeout): + agents = Agent.objects.filter(pk__in=agentpks) - Agent.salt_batch_async( - minions=data["minions"], - func="win_agent.run_script", - kwargs={ - "filepath": script.filepath, - "filename": script.filename, - "shell": script.shell, - "timeout": data["timeout"], - "args": data["args"], - }, - ) + agents_nats = [agent for agent in agents if agent.has_nats] + agents_salt = [agent for agent in agents if not agent.has_nats] + minions = [agent.salt_id for agent in agents_salt] + + if minions: + Agent.salt_batch_async( + minions=minions, + func="cmd.run_bg", + kwargs={ + "cmd": cmd, + "shell": shell, + "timeout": timeout, + }, + ) + + if agents_nats: + nats_data = { + "func": "rawcmd", + "timeout": timeout, + "payload": { + "command": cmd, + "shell": shell, + }, + } + for agent in agents_nats: + asyncio.run(agent.nats_cmd(nats_data, wait=False)) + + +@app.task +def handle_bulk_script_task(scriptpk, agentpks, args, timeout): + script = Script.objects.get(pk=scriptpk) + agents = Agent.objects.filter(pk__in=agentpks) + + agents_nats = [agent for agent in agents if agent.has_nats] + agents_salt = [agent for agent in agents if not agent.has_nats] + minions = [agent.salt_id for agent in agents_salt] + + if minions: + Agent.salt_batch_async( + minions=minions, + func="win_agent.run_script", + kwargs={ + "filepath": script.filepath, + "filename": script.filename, + "shell": script.shell, + "timeout": timeout, + "args": args, + "bg": True if script.shell == "python" else False, # salt bg script bug + }, + ) + + if agents_nats: + nats_data = { + "func": "runscript", + "timeout": timeout, + "script_args": args, + "payload": { + "code": script.code, + "shell": script.shell, + }, + } + for agent in agents_nats: + asyncio.run(agent.nats_cmd(nats_data, wait=False))