move bulk cmd/script to nats

This commit is contained in:
wh1te909 2020-11-24 04:09:52 +00:00
parent a510854741
commit e90e527603
3 changed files with 94 additions and 73 deletions

View File

@ -7,9 +7,7 @@ from itertools import cycle
from django.conf import settings
from django.utils import timezone as djangotime
from rest_framework.authtoken.models import Token
from accounts.models import User
from tacticalrmm.test import TacticalTestCase
from .serializers import AgentSerializer
from winupdate.serializers import WinUpdatePolicySerializer
@ -530,12 +528,14 @@ class TestAgentViews(TacticalTestCase):
self.check_not_authenticated("get", url)
@patch("winupdate.tasks.bulk_check_for_updates_task.delay")
@patch("scripts.tasks.handle_bulk_script_task.delay")
@patch("scripts.tasks.handle_bulk_command_task.delay")
@patch("agents.models.Agent.salt_batch_async")
def test_bulk_cmd_script(self, mock_ret, mock_update):
def test_bulk_cmd_script(
self, salt_batch_async, bulk_command, bulk_script, mock_update
):
url = "/agents/bulk/"
mock_ret.return_value = "ok"
payload = {
"mode": "command",
"target": "agents",
@ -550,6 +550,7 @@ class TestAgentViews(TacticalTestCase):
}
r = self.client.post(url, payload, format="json")
bulk_command.assert_called_with([self.agent.pk], "gpupdate /force", "cmd", 300)
self.assertEqual(r.status_code, 200)
payload = {
@ -581,6 +582,7 @@ class TestAgentViews(TacticalTestCase):
r = self.client.post(url, payload, format="json")
self.assertEqual(r.status_code, 200)
bulk_command.assert_called_with([self.agent.pk], "gpupdate /force", "cmd", 300)
payload = {
"mode": "command",
@ -597,12 +599,7 @@ class TestAgentViews(TacticalTestCase):
r = self.client.post(url, payload, format="json")
self.assertEqual(r.status_code, 200)
mock_ret.return_value = "timeout"
payload["client"] = self.agent.client.id
payload["site"] = self.agent.site.id
r = self.client.post(url, payload, format="json")
self.assertEqual(r.status_code, 400)
bulk_command.assert_called_with([self.agent.pk], "gpupdate /force", "cmd", 300)
payload = {
"mode": "scan",
@ -613,9 +610,8 @@ class TestAgentViews(TacticalTestCase):
self.agent.pk,
],
}
mock_ret.return_value = "ok"
r = self.client.post(url, payload, format="json")
mock_update.assert_called_once()
mock_update.assert_called_with(minions=[self.agent.salt_id])
self.assertEqual(r.status_code, 200)
payload = {
@ -627,6 +623,7 @@ class TestAgentViews(TacticalTestCase):
self.agent.pk,
],
}
salt_batch_async.return_value = "ok"
r = self.client.post(url, payload, format="json")
self.assertEqual(r.status_code, 200)

View File

@ -32,7 +32,7 @@ from winupdate.serializers import WinUpdatePolicySerializer
from .tasks import uninstall_agent_task, send_agent_update_task
from winupdate.tasks import bulk_check_for_updates_task
from scripts.tasks import run_bulk_script_task
from scripts.tasks import handle_bulk_command_task, handle_bulk_script_task
from tacticalrmm.utils import notify_error, reload_nats
@ -760,73 +760,44 @@ def bulk(request):
return notify_error("Must select at least 1 agent")
if request.data["target"] == "client":
agents = Agent.objects.filter(site__client_id=request.data["client"])
q = Agent.objects.filter(site__client_id=request.data["client"])
elif request.data["target"] == "site":
agents = Agent.objects.filter(site_id=request.data["site"])
q = Agent.objects.filter(site_id=request.data["site"])
elif request.data["target"] == "agents":
agents = Agent.objects.filter(pk__in=request.data["agentPKs"])
q = Agent.objects.filter(pk__in=request.data["agentPKs"])
elif request.data["target"] == "all":
agents = Agent.objects.all()
q = Agent.objects.all()
else:
return notify_error("Something went wrong")
minions = [agent.salt_id for agent in agents]
minions = [agent.salt_id for agent in q]
agents = [agent.pk for agent in q]
AuditLog.audit_bulk_action(request.user, request.data["mode"], request.data)
if request.data["mode"] == "command":
r = Agent.salt_batch_async(
minions=minions,
func="cmd.run_bg",
kwargs={
"cmd": request.data["cmd"],
"shell": request.data["shell"],
"timeout": request.data["timeout"],
},
handle_bulk_command_task.delay(
agents, request.data["cmd"], request.data["shell"], request.data["timeout"]
)
if r == "timeout":
return notify_error("Salt API not running")
return Response(f"Command will now be run on {len(minions)} agents")
return Response(f"Command will now be run on {len(agents)} agents")
elif request.data["mode"] == "script":
script = get_object_or_404(Script, pk=request.data["scriptPK"])
if script.shell == "python":
r = Agent.salt_batch_async(
minions=minions,
func="win_agent.run_script",
kwargs={
"filepath": script.filepath,
"filename": script.filename,
"shell": script.shell,
"timeout": request.data["timeout"],
"args": request.data["args"],
"bg": True,
},
)
if r == "timeout":
return notify_error("Salt API not running")
else:
data = {
"minions": minions,
"scriptpk": script.pk,
"timeout": request.data["timeout"],
"args": request.data["args"],
}
run_bulk_script_task.delay(data)
return Response(f"{script.name} will now be run on {len(minions)} agents")
handle_bulk_script_task.delay(
script.pk, agents, request.data["args"], request.data["timeout"]
)
return Response(f"{script.name} will now be run on {len(agents)} agents")
elif request.data["mode"] == "install":
r = Agent.salt_batch_async(minions=minions, func="win_agent.install_updates")
if r == "timeout":
return notify_error("Salt API not running")
return Response(
f"Pending updates will now be installed on {len(minions)} agents"
f"Pending updates will now be installed on {len(agents)} agents"
)
elif request.data["mode"] == "scan":
bulk_check_for_updates_task.delay(minions=minions)
return Response(f"Patch status scan will now run on {len(minions)} agents")
return Response(f"Patch status scan will now run on {len(agents)} agents")
return notify_error("Something went wrong")

View File

@ -1,21 +1,74 @@
import asyncio
from tacticalrmm.celery import app
from agents.models import Agent
from .models import Script
from scripts.models import Script
@app.task
def run_bulk_script_task(data):
# for powershell and batch scripts only, workaround for salt bg script bug
script = Script.objects.get(pk=data["scriptpk"])
def handle_bulk_command_task(agentpks, cmd, shell, timeout):
agents = Agent.objects.filter(pk__in=agentpks)
Agent.salt_batch_async(
minions=data["minions"],
func="win_agent.run_script",
kwargs={
"filepath": script.filepath,
"filename": script.filename,
"shell": script.shell,
"timeout": data["timeout"],
"args": data["args"],
},
)
agents_nats = [agent for agent in agents if agent.has_nats]
agents_salt = [agent for agent in agents if not agent.has_nats]
minions = [agent.salt_id for agent in agents_salt]
if minions:
Agent.salt_batch_async(
minions=minions,
func="cmd.run_bg",
kwargs={
"cmd": cmd,
"shell": shell,
"timeout": timeout,
},
)
if agents_nats:
nats_data = {
"func": "rawcmd",
"timeout": timeout,
"payload": {
"command": cmd,
"shell": shell,
},
}
for agent in agents_nats:
asyncio.run(agent.nats_cmd(nats_data, wait=False))
@app.task
def handle_bulk_script_task(scriptpk, agentpks, args, timeout):
script = Script.objects.get(pk=scriptpk)
agents = Agent.objects.filter(pk__in=agentpks)
agents_nats = [agent for agent in agents if agent.has_nats]
agents_salt = [agent for agent in agents if not agent.has_nats]
minions = [agent.salt_id for agent in agents_salt]
if minions:
Agent.salt_batch_async(
minions=minions,
func="win_agent.run_script",
kwargs={
"filepath": script.filepath,
"filename": script.filename,
"shell": script.shell,
"timeout": timeout,
"args": args,
"bg": True if script.shell == "python" else False, # salt bg script bug
},
)
if agents_nats:
nats_data = {
"func": "runscript",
"timeout": timeout,
"script_args": args,
"payload": {
"code": script.code,
"shell": script.shell,
},
}
for agent in agents_nats:
asyncio.run(agent.nats_cmd(nats_data, wait=False))