diff --git a/api/tacticalrmm/ee/reporting/urls.py b/api/tacticalrmm/ee/reporting/urls.py index 1c5f3680..a406459a 100644 --- a/api/tacticalrmm/ee/reporting/urls.py +++ b/api/tacticalrmm/ee/reporting/urls.py @@ -17,6 +17,8 @@ urlpatterns = [ path("templates/preview/", views.GenerateReportPreview.as_view()), path("templates/preview/analysis/", views.GetAllowedValues.as_view()), path("templates/import/", views.ImportReportTemplate.as_view()), + # shared templates + path("templates/shared/", views.SharedTemplatesRepo.as_view()), # report assets path("assets/", views.GetReportAssets.as_view()), path("assets/all/", views.GetAllAssets.as_view()), diff --git a/api/tacticalrmm/ee/reporting/utils.py b/api/tacticalrmm/ee/reporting/utils.py index 177617e6..bea10874 100644 --- a/api/tacticalrmm/ee/reporting/utils.py +++ b/api/tacticalrmm/ee/reporting/utils.py @@ -20,7 +20,7 @@ from enum import Enum from .constants import REPORTING_MODELS from .markdown.config import Markdown from .models import ReportAsset, ReportHTMLTemplate, ReportTemplate, ReportDataQuery - +from rest_framework.serializers import ValidationError # regex for db data replacement # will return 3 groups of matches in a tuple when uses with re.findall @@ -612,3 +612,101 @@ def generate_chart( return cast(str, fig.to_html(full_html=False, include_plotlyjs="cdn")) elif format == "image": return cast(str, fig.to_image(format="svg").decode("utf-8")) + + +# import report functions +def _import_base_template( + base_template_data: Optional[Dict[str, Any]] = None, + overwrite: bool = False, +) -> Optional[int]: + if base_template_data: + # Check name conflict and modify name if necessary + name = base_template_data.get("name") + html = base_template_data.get("html") + + if not name: + raise ValidationError("base_template is missing 'name' key") + if not html: + raise ValidationError("base_template is missing 'html' field") + + if ReportHTMLTemplate.objects.filter(name=name).exists(): + base_template = ReportHTMLTemplate.objects.filter(name=name).get() + if overwrite: + base_template.html = html + base_template.save() + else: + name += f"_{_generate_random_string()}" + base_template = ReportHTMLTemplate.objects.create(name=name, html=html) + else: + base_template = ReportHTMLTemplate.objects.create(name=name, html=html) + + base_template.refresh_from_db() + return base_template.id + return None + + +def _import_report_template( + report_template_data: Dict[str, Any], + base_template_id: Optional[int] = None, + overwrite: bool = False, +) -> "ReportTemplate": + if report_template_data: + name = report_template_data.pop("name", None) + template_md = report_template_data.get("template_md") + + if not name: + raise ValidationError("template requires a 'name' key") + if not template_md: + raise ValidationError("template requires a 'template_md' field") + + if ReportTemplate.objects.filter(name=name).exists(): + report_template = ReportTemplate.objects.filter(name=name).get() + if overwrite: + for key, value in report_template_data.items(): + setattr(report_template, key, value) + + report_template.save() + else: + name += f"_{_generate_random_string()}" + report_template = ReportTemplate.objects.create( + name=name, + template_html_id=base_template_id, + **report_template_data, + ) + else: + report_template = ReportTemplate.objects.create( + name=name, template_html_id=base_template_id, **report_template_data + ) + report_template.refresh_from_db() + return report_template + else: + raise ValidationError("'template' key is required in input") + + +def _import_assets(assets: List[Dict[str, Any]]) -> None: + from django.core.files import File + import io + import os + from .storage import report_assets_fs + + if isinstance(assets, list): + for asset in assets: + parent_folder = report_assets_fs.getreldir(path=asset["name"]) + path = report_assets_fs.get_available_name( + os.path.join(parent_folder, asset["name"]) + ) + asset_obj = ReportAsset( + id=asset["id"], + file=File( + io.BytesIO(decode_base64_asset(asset["file"])), + name=path, + ), + ) + asset_obj.save() + + +def _generate_random_string(length: int = 6) -> str: + import random + import string + + return "".join(random.choice(string.ascii_lowercase) for i in range(length)) diff --git a/api/tacticalrmm/ee/reporting/views.py b/api/tacticalrmm/ee/reporting/views.py index 30d02d6c..e4e7a179 100644 --- a/api/tacticalrmm/ee/reporting/views.py +++ b/api/tacticalrmm/ee/reporting/views.py @@ -48,6 +48,9 @@ from .utils import ( generate_pdf, normalize_asset_url, prep_variables_for_template, + _import_report_template, + _import_assets, + _import_base_template, ) @@ -253,7 +256,6 @@ class ExportReportTemplate(APIView): base_template = { "name": template_html.name, "html": template_html.html, - "uuid": template_html.uuid, } assets = base64_encode_assets( @@ -271,7 +273,6 @@ class ExportReportTemplate(APIView): "type": template.type, "depends_on": template.depends_on, "template_variables": template.template_variables, - "uuid": template.uuid, }, "assets": assets, } @@ -279,8 +280,6 @@ class ExportReportTemplate(APIView): class ImportReportTemplate(APIView): - did_overwrite = False - @transaction.atomic def post(self, request: Request) -> Response: try: @@ -288,128 +287,25 @@ class ImportReportTemplate(APIView): overwrite = request.data.get("overwrite", False) # import base template if exists - base_template_id = self._import_base_template( + base_template_id = _import_base_template( template_obj.get("base_template"), overwrite ) # import template if exists - report_template = self._import_report_template( + report_template = _import_report_template( template_obj.get("template"), base_template_id, overwrite ) # import assets if exists - self._import_assets(template_obj.get("assets")) + _import_assets(template_obj.get("assets")) - return Response( - { - "template": ReportTemplateSerializer(report_template).data, - "overwrite": self.did_overwrite, - } - ) + return Response(ReportTemplateSerializer(report_template).data) except Exception as e: # rollback db transaction if any exception occurs transaction.set_rollback(True) return notify_error(str(e)) - def _import_base_template( - self, - base_template_data: Optional[Dict[str, Any]] = None, - overwrite: bool = False, - ) -> Optional[int]: - if base_template_data: - # Check name conflict and modify name if necessary - name = base_template_data.get("name") - html = base_template_data.get("html") - - if not name: - raise ValidationError("base_template is missing 'name' key") - if not html: - raise ValidationError("base_template is missing 'html' field") - - if ReportHTMLTemplate.objects.filter(name=name).exists(): - base_template = ReportHTMLTemplate.objects.filter(name=name).get() - if overwrite: - base_template.html = html - base_template.save() - else: - name += f"_{self._generate_random_string()}" - base_template = ReportHTMLTemplate.objects.create( - name=name, html=html - ) - else: - base_template = ReportHTMLTemplate.objects.create(name=name, html=html) - - base_template.refresh_from_db() - return base_template.id - return None - - def _import_report_template( - self, - report_template_data: Dict[str, Any], - base_template_id: Optional[int] = None, - overwrite: bool = False, - ) -> "ReportTemplate": - if report_template_data: - name = report_template_data.pop("name", None) - template_md = report_template_data.get("template_md") - - if not name: - raise ValidationError("template requires a 'name' key") - if not template_md: - raise ValidationError("template requires a 'template_md' field") - - if ReportTemplate.objects.filter(name=name).exists(): - report_template = ReportTemplate.objects.filter(name=name).get() - if overwrite: - self.did_overwrite = True - for key, value in report_template_data.items(): - setattr(report_template, key, value) - - report_template.save() - else: - name += f"_{self._generate_random_string()}" - report_template = ReportTemplate.objects.create( - name=name, - template_html_id=base_template_id, - **report_template_data, - ) - else: - report_template = ReportTemplate.objects.create( - name=name, template_html_id=base_template_id, **report_template_data - ) - report_template.refresh_from_db() - return report_template - else: - raise ValidationError("'template' key is required in input") - - def _import_assets(self, assets: List[Dict[str, Any]]) -> None: - from django.core.files import File - import io - from .storage import report_assets_fs - - if isinstance(assets, list): - for asset in assets: - parent_folder = report_assets_fs.getreldir(path=asset["name"]) - path = report_assets_fs.get_available_name( - os.path.join(parent_folder, asset["name"]) - ) - asset_obj = ReportAsset( - id=asset["id"], - file=File( - io.BytesIO(decode_base64_asset(asset["file"])), - name=path, - ), - ) - asset_obj.save() - - @staticmethod - def _generate_random_string(length: int = 6) -> str: - import random - import string - - return "".join(random.choice(string.ascii_lowercase) for i in range(length)) - class GetAllowedValues(APIView): def post(self, request: Request) -> Response: @@ -457,6 +353,62 @@ class GetAllowedValues(APIView): return items +class SharedTemplatesRepo(APIView): + def get(self, request: Request) -> Response: + import requests + + try: + url = f"https://api.github.com/repos/amidaware/private-scripts/contents/Reporting%20Templates/" + headers = {"Authorization": f"Bearer {djangosettings.GH_TOKEN}"} + response = requests.get(url, headers=headers) + files = response.json() + return Response( + [ + {"name": file["name"], "url": file["download_url"]} + for file in files + if file["download_url"] + ] + ) + except: + return notify_error("Unable to connect to repo") + + @transaction.atomic + def post(self, request: Request) -> Response: + import requests + + overwrite = request.data.get("overwrite", False) + templates = request.data.get("templates", None) + + if not templates: + return notify_error("No templates to import") + + headers = {"Authorization": f"Bearer {djangosettings.GH_TOKEN}"} + # try: + for template in templates: + response = requests.get(template["url"], headers=headers) + template_obj = response.json() + + # import base template if exists + base_template_id = _import_base_template( + template_obj.get("base_template"), overwrite + ) + + # import template if exists + report_template = _import_report_template( + template_obj.get("template"), base_template_id, overwrite + ) + + # import assets if exists + _import_assets(template_obj.get("assets")) + + return Response() + + # except Exception as e: + # # rollback db transaction if any exception occurs + # transaction.set_rollback(True) + # return notify_error(str(e)) + + class GetReportAssets(APIView): def get(self, request: Request) -> Response: path = request.query_params.get("path", "").lstrip("/")