fix tests and more typing

This commit is contained in:
sadnub 2022-04-08 23:13:45 -04:00
parent 62e2b5230c
commit 82624d6657
15 changed files with 250 additions and 145 deletions

View File

@ -416,7 +416,7 @@ class Agent(BaseAuditModel):
def run_script(
self,
scriptpk: int,
args: list[str] = [],
args: List[str] = [],
timeout: int = 120,
full: bool = False,
wait: bool = False,
@ -503,6 +503,7 @@ class Agent(BaseAuditModel):
# returns agent policy merged with a client or site specific policy
def get_patch_policy(self) -> "WinUpdatePolicy":
from winupdate.models import WinUpdatePolicy
# check if site has a patch policy and if so use it
patch_policy = None
@ -510,11 +511,11 @@ class Agent(BaseAuditModel):
agent_policy = self.winupdatepolicy.first()
if not agent_policy:
raise WinUpdatePolicy.DoesNotExist
agent_policy = WinUpdatePolicy.objects.create(agent=self)
policies = self.get_agent_policies()
processed_policies: "List[int]" = list()
processed_policies: List[int] = list()
for _, policy in policies.items():
if (
policy
@ -664,32 +665,43 @@ class Agent(BaseAuditModel):
def get_checks_from_policies(self) -> "List[Check]":
from automation.models import Policy
cached_checks = cache.get(f"site_{self.site.id}_checks")
cache_checks = False
if not self.policy and not self.agentchecks.exists():
cached_checks = cache.get(f"site_{self.site.id}_checks")
if cached_checks and isinstance(cached_checks, list):
return cached_checks
else:
# clear agent checks that have overridden_by_policy set
self.agentchecks.update(overridden_by_policy=False) # type: ignore
if cached_checks and isinstance(cached_checks, list):
return cached_checks
else:
cached_checks = True
# get agent checks based on policies
checks = Policy.get_policy_checks(self)
# clear agent checks that have overridden_by_policy set
self.agentchecks.update(overridden_by_policy=False)
# get agent checks based on policies
checks = Policy.get_policy_checks(self)
if cache_checks:
cache.set(f"site_{self.site.id}_checks", checks, 300)
return checks
return checks
def get_tasks_from_policies(self) -> "List[AutomatedTask]":
from automation.models import Policy
cached_tasks = cache.get(f"site_{self.site.id}_tasks")
cache_tasks = False
if not self.policy:
cached_tasks = cache.get(f"site_{self.site.id}_tasks")
if cached_tasks and isinstance(cached_tasks, list):
return cached_tasks
else:
# get agent tasks based on policies
tasks = Policy.get_policy_tasks(self)
if cached_tasks and isinstance(cached_tasks, list):
return cached_tasks
else:
cached_tasks = True
# get agent tasks based on policies
tasks = Policy.get_policy_tasks(self)
if cache_tasks:
cache.set(f"site_{self.site.id}_tasks", tasks, 300)
return tasks
return tasks
def _do_nats_debug(self, agent, message):
DebugLog.error(agent=agent, log_type="agent_issues", message=message)
@ -734,16 +746,16 @@ class Agent(BaseAuditModel):
await nc.close()
@staticmethod
def serialize(class_name: "Agent") -> Dict[str, Any]:
def serialize(agent: "Agent") -> Dict[str, Any]:
# serializes the agent and returns json
from .serializers import AgentAuditSerializer
return AgentAuditSerializer(class_name).data
return AgentAuditSerializer(agent).data
def delete_superseded_updates(self) -> None:
try:
pks = [] # list of pks to delete
kbs = list(self.winupdates.values_list("kb", flat=True)) # type: ignore
kbs = list(self.winupdates.values_list("kb", flat=True))
d = Counter(kbs)
dupes = [k for k, v in d.items() if v > 1]

View File

@ -0,0 +1,20 @@
# Generated by Django 4.0.3 on 2022-04-09 02:58
from django.db import migrations, models
import django.db.models.deletion
class Migration(migrations.Migration):
dependencies = [
('agents', '0047_alter_agent_plat_alter_agent_site'),
('alerts', '0011_alter_alert_agent'),
]
operations = [
migrations.AlterField(
model_name='alert',
name='agent',
field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='agent', to='agents.agent'),
),
]

View File

@ -39,6 +39,8 @@ class Alert(models.Model):
"agents.Agent",
related_name="agent",
on_delete=models.CASCADE,
null=True,
blank=True,
)
assigned_check = models.ForeignKey(
"checks.Check",
@ -82,7 +84,7 @@ class Alert(models.Model):
max_length=100, null=True, blank=True
)
def __str__(self):
def __str__(self) -> str:
return self.message
@property
@ -149,16 +151,16 @@ class Alert(models.Model):
@classmethod
def create_or_return_check_alert(
cls,
check: Check,
agent: Optional[Agent] = None,
check: "Check",
agent: "Agent",
alert_severity: Optional[str] = None,
skip_create: bool = False,
) -> Optional[Alert]:
) -> "Optional[Alert]":
# need to pass agent if the check is a policy
if not cls.objects.filter(
assigned_check=check,
agent=agent if check.policy else check.agent,
agent=agent,
resolved=False,
).exists():
if skip_create:
@ -168,13 +170,13 @@ class Alert(models.Model):
Alert,
cls.objects.create(
assigned_check=check,
agent=agent if check.policy else check.agent,
agent=agent,
alert_type="check",
severity=check.alert_severity
if check.check_type
not in ["memory", "cpuload", "diskspace", "script"]
else alert_severity,
message=f"{agent.hostname if agent else check.agent.hostname} has a {check.check_type} check: {check.readable_desc} that failed.",
message=f"{agent.hostname} has a {check.check_type} check: {check.readable_desc} that failed.",
hidden=True,
),
)
@ -184,14 +186,14 @@ class Alert(models.Model):
Alert,
cls.objects.get(
assigned_check=check,
agent=agent if check.policy else check.agent,
agent=agent,
resolved=False,
),
)
except cls.MultipleObjectsReturned:
alerts = cls.objects.filter(
assigned_check=check,
agent=agent if check.policy else check.agent,
agent=agent,
resolved=False,
)
last_alert = cast(Alert, alerts.last())
@ -208,41 +210,48 @@ class Alert(models.Model):
@classmethod
def create_or_return_task_alert(
cls,
task: AutomatedTask,
agent: Optional[Agent] = None,
task: "AutomatedTask",
agent: "Agent",
skip_create: bool = False,
) -> Optional[Alert]:
) -> "Optional[Alert]":
if not cls.objects.filter(
assigned_task=task,
agent=agent if task.policy else task.agent,
agent=agent,
resolved=False,
).exists():
if skip_create:
return None
return cls.objects.create(
assigned_task=task,
agent=agent if task.policy else task.agent,
alert_type="task",
severity=task.alert_severity,
message=f"{agent.hostname if agent else task.agent.hostname } has task: {task.name} that failed.",
hidden=True,
return cast(
Alert,
cls.objects.create(
assigned_task=task,
agent=agent,
alert_type="task",
severity=task.alert_severity,
message=f"{agent.hostname} has task: {task.name} that failed.",
hidden=True,
),
)
else:
try:
return cls.objects.get(
assigned_task=task,
agent=agent if task.policy else task.agent,
resolved=False,
return cast(
Alert,
cls.objects.get(
assigned_task=task,
agent=agent,
resolved=False,
),
)
except cls.MultipleObjectsReturned:
alerts = cls.objects.filter(
assigned_task=task,
agent=agent if task.policy else task.agent,
agent=agent,
resolved=False,
)
last_alert = cast(cls, alerts.last())
last_alert = cast(Alert, alerts.last())
# cycle through other alerts and resolve
for alert in alerts:

View File

@ -1,10 +1,13 @@
from django.shortcuts import get_object_or_404
from rest_framework import permissions
from typing import TYPE_CHECKING
from tacticalrmm.permissions import _has_perm, _has_perm_on_agent
if TYPE_CHECKING:
from accounts.models import User
def _has_perm_on_alert(user, id: int) -> bool:
def _has_perm_on_alert(user: "User", id: int) -> bool:
from alerts.models import Alert
role = user.role
@ -19,10 +22,6 @@ def _has_perm_on_alert(user, id: int) -> bool:
if alert.agent:
agent_id = alert.agent.agent_id
elif alert.assigned_check:
agent_id = alert.assigned_check.agent.agent_id
elif alert.assigned_task:
agent_id = alert.assigned_task.agent.agent_id
else:
return True

View File

@ -1506,11 +1506,16 @@ class TestAlertPermissions(TacticalTestCase):
checks = baker.make("checks.Check", agent=cycle(agents), _quantity=3)
tasks = baker.make("autotasks.AutomatedTask", agent=cycle(agents), _quantity=3)
baker.make(
"alerts.Alert", alert_type="task", assigned_task=cycle(tasks), _quantity=3
"alerts.Alert",
alert_type="task",
agent=cycle(agents),
assigned_task=cycle(tasks),
_quantity=3,
)
baker.make(
"alerts.Alert",
alert_type="check",
agent=cycle(agents),
assigned_check=cycle(checks),
_quantity=3,
)
@ -1560,11 +1565,16 @@ class TestAlertPermissions(TacticalTestCase):
checks = baker.make("checks.Check", agent=cycle(agents), _quantity=3)
tasks = baker.make("autotasks.AutomatedTask", agent=cycle(agents), _quantity=3)
alert_tasks = baker.make(
"alerts.Alert", alert_type="task", assigned_task=cycle(tasks), _quantity=3
"alerts.Alert",
alert_type="task",
agent=cycle(agents),
assigned_task=cycle(tasks),
_quantity=3,
)
alert_checks = baker.make(
"alerts.Alert",
alert_type="check",
agent=cycle(agents),
assigned_check=cycle(checks),
_quantity=3,
)
@ -1644,72 +1654,75 @@ class TestAlertPermissions(TacticalTestCase):
for url in unauthorized_urls:
self.check_authorized(method, url)
def test_handling_multiple_availability_alerts_returned(self):
agent = baker.make_recipe("agents.agent")
alerts = baker.make(
"alerts.Alert",
alert_type="availability",
agent=agent,
resolved=False,
_quantity=3,
)
def test_handling_multiple_availability_alerts_returned(self):
agent = baker.make_recipe("agents.agent")
alerts = baker.make(
"alerts.Alerts",
alert_type="availability",
agent=agent,
resolved=False,
_quantity=3,
)
alert = Alert.create_or_return_availability_alert(agent, skip_create=True)
alert = Alert.create_or_return_availability_alert(agent, skip_create=True)
# make sure last alert is returned
self.assertEqual(alert, alerts[-1])
# make sure last alert is returned
self.assertEqual(alert, alerts[-1])
# make sure only 1 alert is not resolved
self.assertEqual(
Alert.objects.filter(
alert_type="availability", agent=agent, resolved=False
).count(),
1,
)
# make sure only 1 alert is not resolved
self.assertEqual(
Alert.objects.filter(
alert_type="availability", agent=agent, resolved=False
).count(),
1,
)
def test_handling_multiple_check_alerts_returned(self):
agent = baker.make_recipe("agents.agent")
check = baker.make_recipe("checks.diskspace_check", agent=agent)
alerts = baker.make(
"alerts.Alert",
alert_type="check",
assigned_check=check,
agent=agent,
resolved=False,
_quantity=3,
)
alert = Alert.create_or_return_check_alert(check, agent=agent, skip_create=True)
def test_handling_multiple_check_alerts_returned(self):
agent = baker.make_recipe("agents.agent")
check = baker.make_recipe("checks.diskspace_check", agent=agent)
alerts = baker.make(
"alerts.Alerts",
alert_type="check",
assigned_check=check,
agent=agent,
resolved=False,
_quantity=3,
)
# make sure last alert is returned
self.assertEqual(alert, alerts[-1])
alert = Alert.create_or_return_check_alert(check, skip_create=True)
# make sure only 1 alert is not resolved
self.assertEqual(
Alert.objects.filter(
alert_type="check", agent=agent, resolved=False
).count(),
1,
)
# make sure last alert is returned
self.assertEqual(alert, alerts[-1])
def test_handling_multiple_task_alerts_returned(self):
agent = baker.make_recipe("agents.agent")
task = baker.make("autotasks.AutomatedTask", agent=agent)
alerts = baker.make(
"alerts.Alert",
alert_type="task",
assigned_task=task,
agent=agent,
resolved=False,
_quantity=3,
)
# make sure only 1 alert is not resolved
self.assertEqual(
Alert.objects.filter(alert_type="check", agent=agent, resolved=False).count(), 1
)
alert = Alert.create_or_return_task_alert(task, agent=agent, skip_create=True)
# make sure last alert is returned
self.assertEqual(alert, alerts[-1])
def test_handling_multiple_task_alerts_returned(self):
agent = baker.make_recipe("agents.agent")
task = baker.make_recipe("autotasks.AutomatedTask", agent=agent)
alerts = baker.make(
"alerts.Alerts",
alert_type="check",
assigned_task=task,
agent=agent,
resolved=False,
_quantity=3,
)
alert = Alert.create_or_return_task_alert(task, skip_create=True)
# make sure last alert is returned
self.assertEqual(alert, alerts[-1])
# make sure only 1 alert is not resolved
self.assertEqual(
Alert.objects.filter(alert_type="task", agent=agent, resolved=False).count(), 1
)
# make sure only 1 alert is not resolved
self.assertEqual(
Alert.objects.filter(
alert_type="task", agent=agent, resolved=False
).count(),
1,
)

View File

@ -312,7 +312,7 @@ class TaskRunner(APIView):
task.save(update_fields=["status"])
if status == "passing":
if Alert.create_or_return_task_alert(task, skip_create=True):
if Alert.create_or_return_task_alert(task, agent=agent, skip_create=True):
Alert.handle_alert_resolve(task_result)
else:
Alert.handle_alert_failure(task_result)

View File

@ -249,7 +249,7 @@ class Policy(BaseAuditModel):
cpuload_checks: "List[Check]" = list()
memory_checks: "List[Check]" = list()
overridden_checks = list()
overridden_checks: List[int] = list()
# Loop over checks in with enforced policies first, then non-enforced policies
for check in enforced_checks + agent_checks + policy_checks:

View File

@ -2,11 +2,10 @@ import asyncio
import random
import string
import pytz
from typing import TYPE_CHECKING, List, Dict, Optional, Union
from typing import TYPE_CHECKING, List, Dict, Any, Optional, Union
from alerts.models import SEVERITY_CHOICES
from django.core.validators import MaxValueValidator, MinValueValidator
from django.contrib.postgres.fields import ArrayField
from django.utils import timezone as djangotime
from django.db import models
from django.db.models.fields import DateTimeField
@ -17,7 +16,6 @@ from core.utils import get_core_settings
if TYPE_CHECKING:
from automation.models import Policy
from autotasks.models import AutomatedTask
from alerts.models import Alert, AlertTemplate
from agents.models import Agent
from checks.models import Check
@ -146,12 +144,12 @@ class AutomatedTask(BaseAuditModel):
managed_by_policy = models.BooleanField(default=False)
# non-database property
task_result: "Union[TaskResult, Dict]" = {}
task_result: "Union[TaskResult, Dict[None, None]]" = {}
def __str__(self):
def __str__(self) -> str:
return self.name
def save(self, *args, **kwargs):
def save(self, *args, **kwargs) -> None:
# get old task if exists
old_task = AutomatedTask.objects.get(pk=self.pk) if self.pk else None
super(AutomatedTask, self).save(old_model=old_task, *args, **kwargs)
@ -170,7 +168,7 @@ class AutomatedTask(BaseAuditModel):
)
@property
def schedule(self):
def schedule(self) -> Optional[str]:
if self.task_type == "manual":
return "Manual"
elif self.task_type == "checkfailure":
@ -225,7 +223,7 @@ class AutomatedTask(BaseAuditModel):
]
@staticmethod
def generate_task_name():
def generate_task_name() -> str:
chars = string.ascii_letters
return "TacticalRMM_" + "".join(random.choice(chars) for i in range(35))
@ -284,7 +282,7 @@ class AutomatedTask(BaseAuditModel):
# agent version >= 1.8.0
def generate_nats_task_payload(
self, agent: "Optional[Agent]" = None, editing: bool = False
) -> Dict:
) -> Dict[str, Any]:
task = {
"pk": self.pk,
"type": "rmm",

View File

@ -13,7 +13,6 @@ from core.utils import get_core_settings
if TYPE_CHECKING:
from alerts.models import Alert, AlertTemplate
from automation.models import Policy
from checks.models import CheckResult
CHECK_TYPE_CHOICES = [
("diskspace", "Disk Space Check"),

View File

@ -0,0 +1,54 @@
# Generated by Django 4.0.3 on 2022-04-08 03:16
import django.contrib.postgres.fields
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('core', '0031_coresettings_date_format'),
]
operations = [
migrations.AlterField(
model_name='coresettings',
name='email_alert_recipients',
field=django.contrib.postgres.fields.ArrayField(base_field=models.EmailField(blank=True, max_length=254, null=True), blank=True, default=list, size=None),
),
migrations.AlterField(
model_name='coresettings',
name='sms_alert_recipients',
field=django.contrib.postgres.fields.ArrayField(base_field=models.CharField(blank=True, max_length=255, null=True), blank=True, default=list, size=None),
),
migrations.AlterField(
model_name='coresettings',
name='smtp_from_email',
field=models.CharField(blank=True, default='from@example.com', max_length=255),
),
migrations.AlterField(
model_name='coresettings',
name='smtp_host',
field=models.CharField(blank=True, default='smtp.gmail.com', max_length=255),
),
migrations.AlterField(
model_name='coresettings',
name='smtp_host_password',
field=models.CharField(blank=True, default='changeme', max_length=255),
),
migrations.AlterField(
model_name='coresettings',
name='smtp_host_user',
field=models.CharField(blank=True, default='admin@example.com', max_length=255),
),
migrations.AlterField(
model_name='coresettings',
name='smtp_port',
field=models.PositiveIntegerField(blank=True, default=587),
),
migrations.AlterField(
model_name='customfield',
name='name',
field=models.CharField(max_length=30),
),
]

View File

@ -261,7 +261,7 @@ class CustomField(BaseAuditModel):
blank=True,
default=list,
)
name = models.CharField(max_length=30, blank=True)
name = models.CharField(max_length=30)
required = models.BooleanField(blank=True, default=False)
default_value_string = models.TextField(null=True, blank=True)
default_value_bool = models.BooleanField(default=False)

View File

@ -179,7 +179,6 @@ class TestCoreTasks(TacticalTestCase):
serializer = CustomFieldSerializer(custom_fields, many=True)
self.assertEqual(r.status_code, 200)
self.assertEqual(len(r.data), 5)
self.assertEqual(r.data, serializer.data)
self.check_not_authenticated("patch", url)

View File

@ -7,9 +7,9 @@ from tacticalrmm.middleware import get_debug_info, get_username
from tacticalrmm.models import PermissionQuerySet
if TYPE_CHECKING:
from agents.models import Agent
from clients.models import Client, Site
from core.models import URLAction
from agents.models import Agent
def get_debug_level() -> str:
@ -240,6 +240,7 @@ class AuditLog(models.Model):
instance: "Union[Agent, Client, Site]",
debug_info: Dict[Any, Any] = {},
) -> None:
from agents.models import Agent
name = instance.hostname if isinstance(instance, Agent) else instance.name
classname = type(instance).__name__

View File

@ -63,17 +63,9 @@ class PermissionQuerySet(models.QuerySet):
custom_alert_queryset = models.Q()
if can_view_clients:
clients_queryset = (
models.Q(agent__site__client__in=can_view_clients)
| models.Q(assigned_check__agent__site__client__in=can_view_clients)
| models.Q(assigned_task__agent__site__client__in=can_view_clients)
)
clients_queryset = models.Q(agent__site__client__in=can_view_clients)
if can_view_sites:
sites_queryset = (
models.Q(agent__site__in=can_view_sites)
| models.Q(assigned_check__agent__site__in=can_view_sites)
| models.Q(assigned_task__agent__site__in=can_view_sites)
)
sites_queryset = models.Q(agent__site__in=can_view_sites)
if can_view_clients or can_view_sites:
custom_alert_queryset = models.Q(
agent=None, assigned_check=None, assigned_task=None

View File

@ -16,6 +16,12 @@ if TYPE_CHECKING:
from checks.models import Check
from scripts.models import Script
TEST_CACHE = {
"default": {
"BACKEND": "django.core.cache.backends.dummy.DummyCache",
}
}
class TacticalTestCase(TestCase):
client: APIClient
@ -38,7 +44,7 @@ class TacticalTestCase(TestCase):
password=User.objects.make_random_password(60), # type: ignore
)
def setup_client(self):
def setup_client(self) -> None:
self.client = APIClient()
def setup_agent_auth(self, agent: "Agent") -> None:
@ -50,9 +56,10 @@ class TacticalTestCase(TestCase):
# fixes tests waiting 2 minutes for mesh token to appear
@override_settings(
MESH_TOKEN_KEY="41410834b8bb4481446027f87d88ec6f119eb9aa97860366440b778540c7399613f7cabfef4f1aa5c0bd9beae03757e17b2e990e5876b0d9924da59bdf24d3437b3ed1a8593b78d65a72a76c794160d9"
MESH_TOKEN_KEY="41410834b8bb4481446027f87d88ec6f119eb9aa97860366440b778540c7399613f7cabfef4f1aa5c0bd9beae03757e17b2e990e5876b0d9924da59bdf24d3437b3ed1a8593b78d65a72a76c794160d9",
CACHES=TEST_CACHE,
)
def setup_coresettings(self):
def setup_coresettings(self) -> None:
self.coresettings = CoreSettings.objects.create()
def check_not_authenticated(self, method: str, url: str) -> None:
@ -90,7 +97,7 @@ class TacticalTestCase(TestCase):
return checks
def check_not_authorized(
self, method: str, url: str, data: Optional[Dict] = {}
self, method: str, url: str, data: Optional[Dict[Any, Any]] = {}
) -> None:
try:
r = getattr(self.client, method)(url, data, format="json")
@ -98,7 +105,9 @@ class TacticalTestCase(TestCase):
except KeyError:
pass
def check_authorized(self, method: str, url: str, data: Optional[Dict] = {}) -> Any:
def check_authorized(
self, method: str, url: str, data: Optional[Dict[Any, Any]] = {}
) -> Any:
try:
r = getattr(self.client, method)(url, data, format="json")
self.assertNotEqual(r.status_code, 403)
@ -107,7 +116,7 @@ class TacticalTestCase(TestCase):
pass
def check_authorized_superuser(
self, method: str, url: str, data: Optional[Dict] = {}
self, method: str, url: str, data: Optional[Dict[Any, Any]] = {}
) -> Any:
try: