move bulk cmd/script to nats
This commit is contained in:
parent
a510854741
commit
e90e527603
|
@ -7,9 +7,7 @@ from itertools import cycle
|
|||
|
||||
from django.conf import settings
|
||||
from django.utils import timezone as djangotime
|
||||
from rest_framework.authtoken.models import Token
|
||||
|
||||
from accounts.models import User
|
||||
from tacticalrmm.test import TacticalTestCase
|
||||
from .serializers import AgentSerializer
|
||||
from winupdate.serializers import WinUpdatePolicySerializer
|
||||
|
@ -530,12 +528,14 @@ class TestAgentViews(TacticalTestCase):
|
|||
self.check_not_authenticated("get", url)
|
||||
|
||||
@patch("winupdate.tasks.bulk_check_for_updates_task.delay")
|
||||
@patch("scripts.tasks.handle_bulk_script_task.delay")
|
||||
@patch("scripts.tasks.handle_bulk_command_task.delay")
|
||||
@patch("agents.models.Agent.salt_batch_async")
|
||||
def test_bulk_cmd_script(self, mock_ret, mock_update):
|
||||
def test_bulk_cmd_script(
|
||||
self, salt_batch_async, bulk_command, bulk_script, mock_update
|
||||
):
|
||||
url = "/agents/bulk/"
|
||||
|
||||
mock_ret.return_value = "ok"
|
||||
|
||||
payload = {
|
||||
"mode": "command",
|
||||
"target": "agents",
|
||||
|
@ -550,6 +550,7 @@ class TestAgentViews(TacticalTestCase):
|
|||
}
|
||||
|
||||
r = self.client.post(url, payload, format="json")
|
||||
bulk_command.assert_called_with([self.agent.pk], "gpupdate /force", "cmd", 300)
|
||||
self.assertEqual(r.status_code, 200)
|
||||
|
||||
payload = {
|
||||
|
@ -581,6 +582,7 @@ class TestAgentViews(TacticalTestCase):
|
|||
|
||||
r = self.client.post(url, payload, format="json")
|
||||
self.assertEqual(r.status_code, 200)
|
||||
bulk_command.assert_called_with([self.agent.pk], "gpupdate /force", "cmd", 300)
|
||||
|
||||
payload = {
|
||||
"mode": "command",
|
||||
|
@ -597,12 +599,7 @@ class TestAgentViews(TacticalTestCase):
|
|||
|
||||
r = self.client.post(url, payload, format="json")
|
||||
self.assertEqual(r.status_code, 200)
|
||||
|
||||
mock_ret.return_value = "timeout"
|
||||
payload["client"] = self.agent.client.id
|
||||
payload["site"] = self.agent.site.id
|
||||
r = self.client.post(url, payload, format="json")
|
||||
self.assertEqual(r.status_code, 400)
|
||||
bulk_command.assert_called_with([self.agent.pk], "gpupdate /force", "cmd", 300)
|
||||
|
||||
payload = {
|
||||
"mode": "scan",
|
||||
|
@ -613,9 +610,8 @@ class TestAgentViews(TacticalTestCase):
|
|||
self.agent.pk,
|
||||
],
|
||||
}
|
||||
mock_ret.return_value = "ok"
|
||||
r = self.client.post(url, payload, format="json")
|
||||
mock_update.assert_called_once()
|
||||
mock_update.assert_called_with(minions=[self.agent.salt_id])
|
||||
self.assertEqual(r.status_code, 200)
|
||||
|
||||
payload = {
|
||||
|
@ -627,6 +623,7 @@ class TestAgentViews(TacticalTestCase):
|
|||
self.agent.pk,
|
||||
],
|
||||
}
|
||||
salt_batch_async.return_value = "ok"
|
||||
r = self.client.post(url, payload, format="json")
|
||||
self.assertEqual(r.status_code, 200)
|
||||
|
||||
|
|
|
@ -32,7 +32,7 @@ from winupdate.serializers import WinUpdatePolicySerializer
|
|||
|
||||
from .tasks import uninstall_agent_task, send_agent_update_task
|
||||
from winupdate.tasks import bulk_check_for_updates_task
|
||||
from scripts.tasks import run_bulk_script_task
|
||||
from scripts.tasks import handle_bulk_command_task, handle_bulk_script_task
|
||||
|
||||
from tacticalrmm.utils import notify_error, reload_nats
|
||||
|
||||
|
@ -760,73 +760,44 @@ def bulk(request):
|
|||
return notify_error("Must select at least 1 agent")
|
||||
|
||||
if request.data["target"] == "client":
|
||||
agents = Agent.objects.filter(site__client_id=request.data["client"])
|
||||
q = Agent.objects.filter(site__client_id=request.data["client"])
|
||||
elif request.data["target"] == "site":
|
||||
agents = Agent.objects.filter(site_id=request.data["site"])
|
||||
q = Agent.objects.filter(site_id=request.data["site"])
|
||||
elif request.data["target"] == "agents":
|
||||
agents = Agent.objects.filter(pk__in=request.data["agentPKs"])
|
||||
q = Agent.objects.filter(pk__in=request.data["agentPKs"])
|
||||
elif request.data["target"] == "all":
|
||||
agents = Agent.objects.all()
|
||||
q = Agent.objects.all()
|
||||
else:
|
||||
return notify_error("Something went wrong")
|
||||
|
||||
minions = [agent.salt_id for agent in agents]
|
||||
minions = [agent.salt_id for agent in q]
|
||||
agents = [agent.pk for agent in q]
|
||||
|
||||
AuditLog.audit_bulk_action(request.user, request.data["mode"], request.data)
|
||||
|
||||
if request.data["mode"] == "command":
|
||||
r = Agent.salt_batch_async(
|
||||
minions=minions,
|
||||
func="cmd.run_bg",
|
||||
kwargs={
|
||||
"cmd": request.data["cmd"],
|
||||
"shell": request.data["shell"],
|
||||
"timeout": request.data["timeout"],
|
||||
},
|
||||
handle_bulk_command_task.delay(
|
||||
agents, request.data["cmd"], request.data["shell"], request.data["timeout"]
|
||||
)
|
||||
if r == "timeout":
|
||||
return notify_error("Salt API not running")
|
||||
return Response(f"Command will now be run on {len(minions)} agents")
|
||||
return Response(f"Command will now be run on {len(agents)} agents")
|
||||
|
||||
elif request.data["mode"] == "script":
|
||||
script = get_object_or_404(Script, pk=request.data["scriptPK"])
|
||||
|
||||
if script.shell == "python":
|
||||
r = Agent.salt_batch_async(
|
||||
minions=minions,
|
||||
func="win_agent.run_script",
|
||||
kwargs={
|
||||
"filepath": script.filepath,
|
||||
"filename": script.filename,
|
||||
"shell": script.shell,
|
||||
"timeout": request.data["timeout"],
|
||||
"args": request.data["args"],
|
||||
"bg": True,
|
||||
},
|
||||
)
|
||||
if r == "timeout":
|
||||
return notify_error("Salt API not running")
|
||||
else:
|
||||
data = {
|
||||
"minions": minions,
|
||||
"scriptpk": script.pk,
|
||||
"timeout": request.data["timeout"],
|
||||
"args": request.data["args"],
|
||||
}
|
||||
run_bulk_script_task.delay(data)
|
||||
|
||||
return Response(f"{script.name} will now be run on {len(minions)} agents")
|
||||
handle_bulk_script_task.delay(
|
||||
script.pk, agents, request.data["args"], request.data["timeout"]
|
||||
)
|
||||
return Response(f"{script.name} will now be run on {len(agents)} agents")
|
||||
|
||||
elif request.data["mode"] == "install":
|
||||
r = Agent.salt_batch_async(minions=minions, func="win_agent.install_updates")
|
||||
if r == "timeout":
|
||||
return notify_error("Salt API not running")
|
||||
return Response(
|
||||
f"Pending updates will now be installed on {len(minions)} agents"
|
||||
f"Pending updates will now be installed on {len(agents)} agents"
|
||||
)
|
||||
elif request.data["mode"] == "scan":
|
||||
bulk_check_for_updates_task.delay(minions=minions)
|
||||
return Response(f"Patch status scan will now run on {len(minions)} agents")
|
||||
return Response(f"Patch status scan will now run on {len(agents)} agents")
|
||||
|
||||
return notify_error("Something went wrong")
|
||||
|
||||
|
|
|
@ -1,21 +1,74 @@
|
|||
import asyncio
|
||||
|
||||
from tacticalrmm.celery import app
|
||||
from agents.models import Agent
|
||||
from .models import Script
|
||||
from scripts.models import Script
|
||||
|
||||
|
||||
@app.task
|
||||
def run_bulk_script_task(data):
|
||||
# for powershell and batch scripts only, workaround for salt bg script bug
|
||||
script = Script.objects.get(pk=data["scriptpk"])
|
||||
def handle_bulk_command_task(agentpks, cmd, shell, timeout):
|
||||
agents = Agent.objects.filter(pk__in=agentpks)
|
||||
|
||||
Agent.salt_batch_async(
|
||||
minions=data["minions"],
|
||||
func="win_agent.run_script",
|
||||
kwargs={
|
||||
"filepath": script.filepath,
|
||||
"filename": script.filename,
|
||||
"shell": script.shell,
|
||||
"timeout": data["timeout"],
|
||||
"args": data["args"],
|
||||
},
|
||||
)
|
||||
agents_nats = [agent for agent in agents if agent.has_nats]
|
||||
agents_salt = [agent for agent in agents if not agent.has_nats]
|
||||
minions = [agent.salt_id for agent in agents_salt]
|
||||
|
||||
if minions:
|
||||
Agent.salt_batch_async(
|
||||
minions=minions,
|
||||
func="cmd.run_bg",
|
||||
kwargs={
|
||||
"cmd": cmd,
|
||||
"shell": shell,
|
||||
"timeout": timeout,
|
||||
},
|
||||
)
|
||||
|
||||
if agents_nats:
|
||||
nats_data = {
|
||||
"func": "rawcmd",
|
||||
"timeout": timeout,
|
||||
"payload": {
|
||||
"command": cmd,
|
||||
"shell": shell,
|
||||
},
|
||||
}
|
||||
for agent in agents_nats:
|
||||
asyncio.run(agent.nats_cmd(nats_data, wait=False))
|
||||
|
||||
|
||||
@app.task
|
||||
def handle_bulk_script_task(scriptpk, agentpks, args, timeout):
|
||||
script = Script.objects.get(pk=scriptpk)
|
||||
agents = Agent.objects.filter(pk__in=agentpks)
|
||||
|
||||
agents_nats = [agent for agent in agents if agent.has_nats]
|
||||
agents_salt = [agent for agent in agents if not agent.has_nats]
|
||||
minions = [agent.salt_id for agent in agents_salt]
|
||||
|
||||
if minions:
|
||||
Agent.salt_batch_async(
|
||||
minions=minions,
|
||||
func="win_agent.run_script",
|
||||
kwargs={
|
||||
"filepath": script.filepath,
|
||||
"filename": script.filename,
|
||||
"shell": script.shell,
|
||||
"timeout": timeout,
|
||||
"args": args,
|
||||
"bg": True if script.shell == "python" else False, # salt bg script bug
|
||||
},
|
||||
)
|
||||
|
||||
if agents_nats:
|
||||
nats_data = {
|
||||
"func": "runscript",
|
||||
"timeout": timeout,
|
||||
"script_args": args,
|
||||
"payload": {
|
||||
"code": script.code,
|
||||
"shell": script.shell,
|
||||
},
|
||||
}
|
||||
for agent in agents_nats:
|
||||
asyncio.run(agent.nats_cmd(nats_data, wait=False))
|
||||
|
|
Loading…
Reference in New Issue