diff --git a/api/tacticalrmm/accounts/management/commands/reset_2fa.py b/api/tacticalrmm/accounts/management/commands/reset_2fa.py index 4a71fb3d..8e659722 100644 --- a/api/tacticalrmm/accounts/management/commands/reset_2fa.py +++ b/api/tacticalrmm/accounts/management/commands/reset_2fa.py @@ -1,10 +1,11 @@ import subprocess import pyotp +from django.conf import settings from django.core.management.base import BaseCommand from accounts.models import User -from tacticalrmm.helpers import get_webdomain +from tacticalrmm.util_settings import get_webdomain class Command(BaseCommand): @@ -26,7 +27,7 @@ class Command(BaseCommand): user.save(update_fields=["totp_key"]) url = pyotp.totp.TOTP(code).provisioning_uri( - username, issuer_name=get_webdomain() + username, issuer_name=get_webdomain(settings.CORS_ORIGIN_WHITELIST[0]) ) subprocess.run(f'qr "{url}"', shell=True) self.stdout.write( diff --git a/api/tacticalrmm/accounts/serializers.py b/api/tacticalrmm/accounts/serializers.py index 78186aef..c6717333 100644 --- a/api/tacticalrmm/accounts/serializers.py +++ b/api/tacticalrmm/accounts/serializers.py @@ -1,11 +1,12 @@ import pyotp +from django.conf import settings from rest_framework.serializers import ( ModelSerializer, ReadOnlyField, SerializerMethodField, ) -from tacticalrmm.helpers import get_webdomain +from tacticalrmm.util_settings import get_webdomain from .models import APIKey, Role, User @@ -63,7 +64,7 @@ class TOTPSetupSerializer(ModelSerializer): def get_qr_url(self, obj): return pyotp.totp.TOTP(obj.totp_key).provisioning_uri( - obj.username, issuer_name=get_webdomain() + obj.username, issuer_name=get_webdomain(settings.CORS_ORIGIN_WHITELIST[0]) ) diff --git a/api/tacticalrmm/core/management/commands/get_config.py b/api/tacticalrmm/core/management/commands/get_config.py index 0a76344f..ba9f26cb 100644 --- a/api/tacticalrmm/core/management/commands/get_config.py +++ b/api/tacticalrmm/core/management/commands/get_config.py @@ -3,7 +3,7 @@ from urllib.parse import urlparse from django.conf import settings from django.core.management.base import BaseCommand -from tacticalrmm.helpers import get_root_domain, get_webdomain +from tacticalrmm.util_settings import get_backend_url, get_root_domain, get_webdomain from tacticalrmm.utils import get_certs @@ -29,8 +29,16 @@ class Command(BaseCommand): self.stdout.write(settings.NATS_SERVER_VER) case "frontend": self.stdout.write(settings.CORS_ORIGIN_WHITELIST[0]) + case "backend_url": + self.stdout.write( + get_backend_url( + settings.ALLOWED_HOSTS[0], + settings.TRMM_PROTO, + settings.TRMM_BACKEND_PORT, + ) + ) case "webdomain": - self.stdout.write(get_webdomain()) + self.stdout.write(get_webdomain(settings.CORS_ORIGIN_WHITELIST[0])) case "djangoadmin": url = f"https://{settings.ALLOWED_HOSTS[0]}/{settings.ADMIN_URL}" self.stdout.write(url) diff --git a/api/tacticalrmm/ee/sso/views.py b/api/tacticalrmm/ee/sso/views.py index 465b80d4..1a8dfe20 100644 --- a/api/tacticalrmm/ee/sso/views.py +++ b/api/tacticalrmm/ee/sso/views.py @@ -7,6 +7,7 @@ For details, see: https://license.tacticalrmm.com/ee import re from allauth.socialaccount.models import SocialAccount, SocialApp +from django.conf import settings from django.contrib.auth import logout from django.core.exceptions import ValidationError from django.shortcuts import get_object_or_404 @@ -16,11 +17,16 @@ from rest_framework import status from rest_framework.authentication import SessionAuthentication from rest_framework.permissions import IsAuthenticated from rest_framework.response import Response -from rest_framework.serializers import ModelSerializer, ReadOnlyField +from rest_framework.serializers import ( + ModelSerializer, + ReadOnlyField, + SerializerMethodField, +) from rest_framework.views import APIView from accounts.permissions import AccountsPerms from logs.models import AuditLog +from tacticalrmm.util_settings import get_backend_url from tacticalrmm.utils import get_core_settings from .permissions import SSOLoginPerms @@ -29,6 +35,15 @@ from .permissions import SSOLoginPerms class SocialAppSerializer(ModelSerializer): server_url = ReadOnlyField(source="settings.server_url") role = ReadOnlyField(source="settings.role") + callback_url = SerializerMethodField() + javascript_origin_url = SerializerMethodField() + + def get_callback_url(self, obj): + backend_url = self.context["backend_url"] + return f"{backend_url}/accounts/oidc/{obj.provider_id}/login/callback/" + + def get_javascript_origin_url(self, obj): + return self.context["frontend_url"] class Meta: model = SocialApp @@ -42,6 +57,8 @@ class SocialAppSerializer(ModelSerializer): "server_url", "settings", "role", + "callback_url", + "javascript_origin_url", ] @@ -49,8 +66,16 @@ class GetAddSSOProvider(APIView): permission_classes = [IsAuthenticated, AccountsPerms] def get(self, request): + ctx = { + "backend_url": get_backend_url( + settings.ALLOWED_HOSTS[0], + settings.TRMM_PROTO, + settings.TRMM_BACKEND_PORT, + ), + "frontend_url": settings.CORS_ORIGIN_WHITELIST[0], + } providers = SocialApp.objects.all() - return Response(SocialAppSerializer(providers, many=True).data) + return Response(SocialAppSerializer(providers, many=True, context=ctx).data) class InputSerializer(ModelSerializer): server_url = ReadOnlyField() diff --git a/api/tacticalrmm/tacticalrmm/helpers.py b/api/tacticalrmm/tacticalrmm/helpers.py index fbd9f796..00ef28f3 100644 --- a/api/tacticalrmm/tacticalrmm/helpers.py +++ b/api/tacticalrmm/tacticalrmm/helpers.py @@ -6,10 +6,8 @@ import secrets import string from pathlib import Path from typing import TYPE_CHECKING, Any, Literal -from urllib.parse import urlparse from zoneinfo import ZoneInfo -import tldextract from cryptography import x509 from django.conf import settings from django.utils import timezone as djangotime @@ -104,16 +102,6 @@ def date_is_in_past(*, datetime_obj: "datetime", agent_tz: str) -> bool: return djangotime.now() > utc_time -def get_webdomain() -> str: - return urlparse(settings.CORS_ORIGIN_WHITELIST[0]).netloc - - -def get_root_domain(subdomain) -> str: - no_fetch_extract = tldextract.TLDExtract(suffix_list_urls=()) - extracted = no_fetch_extract(subdomain) - return f"{extracted.domain}.{extracted.suffix}" - - def rand_range(min: int, max: int) -> float: """ Input is milliseconds. diff --git a/api/tacticalrmm/tacticalrmm/settings.py b/api/tacticalrmm/tacticalrmm/settings.py index 3f9a5e1c..99bfc3de 100644 --- a/api/tacticalrmm/tacticalrmm/settings.py +++ b/api/tacticalrmm/tacticalrmm/settings.py @@ -3,7 +3,8 @@ import sys from contextlib import suppress from datetime import timedelta from pathlib import Path -from tacticalrmm.helpers import get_root_domain, get_webdomain + +from tacticalrmm.util_settings import get_backend_url, get_root_domain, get_webdomain BASE_DIR = Path(__file__).resolve().parent.parent @@ -117,12 +118,12 @@ SWAGGER_ENABLED = False REDIS_HOST = "127.0.0.1" TRMM_LOG_LEVEL = "ERROR" TRMM_LOG_TO = "file" +TRMM_PROTO = "https" +TRMM_BACKEND_PORT = None if not DOCKER_BUILD: ALLOWED_HOSTS = [] CORS_ORIGIN_WHITELIST = [] - TRMM_PROTO = "https" - TRMM_BACKEND_PORT = None with suppress(ImportError): from ee.sso.sso_settings import * # noqa @@ -154,16 +155,14 @@ if "GHACTIONS" in os.environ: if not DOCKER_BUILD: TRMM_ROOT_DOMAIN = get_root_domain(ALLOWED_HOSTS[0]) - frontend_domain = get_webdomain().split(":")[0] + frontend_domain = get_webdomain(CORS_ORIGIN_WHITELIST[0]).split(":")[0] ALLOWED_HOSTS.append(frontend_domain) if DEBUG: ALLOWED_HOSTS.append("*") - backend_url = f"{TRMM_PROTO}://{ALLOWED_HOSTS[0]}" - if TRMM_BACKEND_PORT: - backend_url = f"{backend_url}:{TRMM_BACKEND_PORT}" + backend_url = get_backend_url(ALLOWED_HOSTS[0], TRMM_PROTO, TRMM_BACKEND_PORT) SESSION_COOKIE_DOMAIN = TRMM_ROOT_DOMAIN CSRF_COOKIE_DOMAIN = TRMM_ROOT_DOMAIN diff --git a/api/tacticalrmm/tacticalrmm/util_settings.py b/api/tacticalrmm/tacticalrmm/util_settings.py new file mode 100644 index 00000000..ccd977e6 --- /dev/null +++ b/api/tacticalrmm/tacticalrmm/util_settings.py @@ -0,0 +1,23 @@ +# this file must not import anything from django settings to avoid circular import issues + +from urllib.parse import urlparse + +import tldextract + + +def get_webdomain(url: str) -> str: + return urlparse(url).netloc + + +def get_root_domain(subdomain) -> str: + no_fetch_extract = tldextract.TLDExtract(suffix_list_urls=()) + extracted = no_fetch_extract(subdomain) + return f"{extracted.domain}.{extracted.suffix}" + + +def get_backend_url(subdomain, proto, port) -> str: + url = f"{proto}://{subdomain}" + if port: + url = f"{url}:{port}" + + return url