From 75713de847251cc362095552372cae73568c271d Mon Sep 17 00:00:00 2001 From: Kim Gustyr Date: Sun, 29 Sep 2024 14:09:30 +0100 Subject: [PATCH] handle CORS --- api/app/settings/common.py | 28 +++++++------- api/core/helpers.py | 20 ++++++---- api/custom_auth/apps.py | 1 + api/custom_auth/jwt_cookie/signals.py | 20 ++++++++++ .../test_custom_auth_integration.py | 38 +++++++++++++++++++ api/tests/unit/core/test_helpers.py | 37 ++++++++++++++++-- 6 files changed, 120 insertions(+), 24 deletions(-) create mode 100644 api/custom_auth/jwt_cookie/signals.py diff --git a/api/app/settings/common.py b/api/app/settings/common.py index 5431e9f3f9ce..fd1815c9dff4 100644 --- a/api/app/settings/common.py +++ b/api/app/settings/common.py @@ -418,19 +418,6 @@ MEDIA_URL = "/media/" # unused but needs to be different from STATIC_URL in django 3 -# CORS settings - -CORS_ORIGIN_ALLOW_ALL = True -FLAGSMITH_CORS_EXTRA_ALLOW_HEADERS = env.list( - "FLAGSMITH_CORS_EXTRA_ALLOW_HEADERS", default=["sentry-trace"] -) -CORS_ALLOW_HEADERS = [ - *default_headers, - *FLAGSMITH_CORS_EXTRA_ALLOW_HEADERS, - "X-Environment-Key", - "X-E2E-Test-Auth-Token", -] - DEFAULT_FROM_EMAIL = env("SENDER_EMAIL", default="noreply@flagsmith.com") EMAIL_CONFIGURATION = { # Invitations with name is anticipated to take two arguments. The persons name and the @@ -1046,6 +1033,21 @@ USE_SECURE_COOKIES = env.bool("USE_SECURE_COOKIES", default=True) COOKIE_SAME_SITE = env.str("COOKIE_SAME_SITE", default="none") +# CORS settings + +CORS_ORIGIN_ALLOW_ALL = env.bool("CORS_ORIGIN_ALLOW_ALL", not COOKIE_AUTH_ENABLED) +CORS_ALLOW_CREDENTIALS = env.bool("CORS_ALLOW_CREDENTIALS", COOKIE_AUTH_ENABLED) +FLAGSMITH_CORS_EXTRA_ALLOW_HEADERS = env.list( + "FLAGSMITH_CORS_EXTRA_ALLOW_HEADERS", default=["sentry-trace"] +) +CORS_ALLOWED_ORIGINS = env.list("CORS_ALLOWED_ORIGINS", default=[]) +CORS_ALLOW_HEADERS = [ + *default_headers, + *FLAGSMITH_CORS_EXTRA_ALLOW_HEADERS, + "X-Environment-Key", + "X-E2E-Test-Auth-Token", +] + # use a separate boolean setting so that we add it to the API containers in environments # where we're running the task processor, so we avoid creating unnecessary tasks ENABLE_PIPEDRIVE_LEAD_TRACKING = env.bool("ENABLE_PIPEDRIVE_LEAD_TRACKING", False) diff --git a/api/core/helpers.py b/api/core/helpers.py index 3af9a067664d..4f327e5a8556 100644 --- a/api/core/helpers.py +++ b/api/core/helpers.py @@ -1,7 +1,7 @@ import re from django.conf import settings -from django.contrib.sites.models import Site +from django.contrib.sites import models as sites_models from django.http import HttpRequest from rest_framework.request import Request @@ -11,12 +11,18 @@ def get_current_site_url(request: HttpRequest | Request | None = None) -> str: - if settings.DOMAIN_OVERRIDE: - domain = settings.DOMAIN_OVERRIDE - elif current_site := Site.objects.filter(id=settings.SITE_ID).first(): - domain = current_site.domain - else: - domain = settings.DEFAULT_DOMAIN + if not (domain := settings.DOMAIN_OVERRIDE): + try: + domain = sites_models.Site.objects.get_current(request).domain + except sites_models.Site.DoesNotExist: + # For the rare case when `DOMAIN_OVERRIDE` was not set and no `Site` object present, + # store a default domain `Site` in the sites cache + # so it's correctly invalidated should the user decide to create own `Site` object. + domain = settings.DEFAULT_DOMAIN + sites_models.SITE_CACHE[settings.SITE_ID] = sites_models.Site( + name="Flagsmith", + domain=domain, + ) if request: scheme = request.scheme diff --git a/api/custom_auth/apps.py b/api/custom_auth/apps.py index 2328a449d114..005927503540 100644 --- a/api/custom_auth/apps.py +++ b/api/custom_auth/apps.py @@ -6,3 +6,4 @@ class CustomAuthAppConfig(AppConfig): def ready(self) -> None: from custom_auth import tasks # noqa F401 + from custom_auth.jwt_cookie import signals # noqa F401 diff --git a/api/custom_auth/jwt_cookie/signals.py b/api/custom_auth/jwt_cookie/signals.py new file mode 100644 index 000000000000..1e840eb681d7 --- /dev/null +++ b/api/custom_auth/jwt_cookie/signals.py @@ -0,0 +1,20 @@ +from typing import Any +from urllib.parse import urlparse + +from core.helpers import get_current_site_url +from corsheaders.signals import check_request_enabled +from django.dispatch import receiver +from django.http import HttpRequest + + +@receiver(check_request_enabled) +def cors_allow_current_site(request: HttpRequest, **kwargs: Any) -> bool: + # The signal is expected to only be dispatched: + # - When `settings.CORS_ORIGIN_ALLOW_ALL` is set to `False`. + # - For requests with `HTTP_ORIGIN` set. + origin_url = urlparse(request.META["HTTP_ORIGIN"]) + current_site_url = urlparse(get_current_site_url(request)) + return ( + origin_url.scheme == current_site_url.scheme + and origin_url.netloc == current_site_url.netloc + ) diff --git a/api/tests/integration/custom_auth/end_to_end/test_custom_auth_integration.py b/api/tests/integration/custom_auth/end_to_end/test_custom_auth_integration.py index 9204b367e1dd..9e1943fcbe44 100644 --- a/api/tests/integration/custom_auth/end_to_end/test_custom_auth_integration.py +++ b/api/tests/integration/custom_auth/end_to_end/test_custom_auth_integration.py @@ -408,6 +408,44 @@ def test_login_workflow__jwt_cookie__mfa_enabled( assert not response.data +# In the real world, setting `COOKIE_AUTH_ENABLED` to `True` +# changes default CORS setting values. +# Due to how Django settings are loaded for tests, +# we have to override CORS settings manually. +@override_settings( + COOKIE_AUTH_ENABLED=True, + DOMAIN_OVERRIDE="testhost.com", + CORS_ORIGIN_ALLOW_ALL=False, + CORS_ALLOW_CREDENTIALS=True, +) +def test_login_workflow__jwt_cookie__cors_headers_expected( + db: None, + api_client: APIClient, +) -> None: + # Given + email = "test@example.com" + password = FFAdminUser.objects.make_random_password() + register_url = reverse("api-v1:custom_auth:ffadminuser-list") + protected_resource_url = reverse("api-v1:projects:project-list") + register_data = { + "first_name": "test", + "last_name": "last_name", + "email": email, + "password": password, + "re_password": password, + } + api_client.post(register_url, data=register_data) + + # When + response = api_client.get( + protected_resource_url, + HTTP_ORIGIN="http://testhost.com", + ) + + # Then + assert response.headers["Access-Control-Allow-Origin"] == "http://testhost.com" + + def test_throttle_login_workflows( api_client: APIClient, db: None, diff --git a/api/tests/unit/core/test_helpers.py b/api/tests/unit/core/test_helpers.py index 03e3455a2714..e57483bd7045 100644 --- a/api/tests/unit/core/test_helpers.py +++ b/api/tests/unit/core/test_helpers.py @@ -5,7 +5,7 @@ from django.contrib.sites.models import Site if typing.TYPE_CHECKING: - from pytest_django.fixtures import SettingsWrapper + from pytest_django.fixtures import DjangoAssertNumQueries, SettingsWrapper from pytest_mock import MockerFixture pytestmark = pytest.mark.django_db @@ -26,13 +26,13 @@ def test_get_current_site_url_returns_correct_url_if_site_exists( assert url == f"https://{expected_domain}" -def test_get_current_site_url_uses_default_url_if_site_does_not_exists( +def test_get_current_site_url_uses_default_url_if_site_does_not_exist( settings: "SettingsWrapper", ) -> None: # Given expected_domain = "some-testing-url.com" settings.DEFAULT_DOMAIN = expected_domain - settings.SITE_ID = None + Site.objects.all().delete() # When url = get_current_site_url() @@ -41,6 +41,35 @@ def test_get_current_site_url_uses_default_url_if_site_does_not_exists( assert url == f"https://{expected_domain}" +def test_get_current_site_url__site_created__cached_return_expected( + settings: "SettingsWrapper", + django_assert_num_queries: "DjangoAssertNumQueries", +) -> None: + # Given + expected_domain_without_site = "some-new-testing-url.com" + expected_domain_with_site = "some-testing-url.com" + settings.DEFAULT_DOMAIN = expected_domain_without_site + Site.objects.all().delete() + + # When + with django_assert_num_queries(1): + get_current_site_url() + url_without_site = get_current_site_url() + + settings.SITE_ID = Site.objects.create( + name="test_site", + domain=expected_domain_with_site, + ).id + + with django_assert_num_queries(1): + get_current_site_url() + url_with_site = get_current_site_url() + + # Then + assert url_without_site == f"https://{expected_domain_without_site}" + assert url_with_site == f"https://{expected_domain_with_site}" + + def test_get_current_site__domain_override__with_site__return_expected( settings: "SettingsWrapper", ) -> None: @@ -62,7 +91,7 @@ def test_get_current_site__domain_override__no_site__return_expected( settings: "SettingsWrapper", ) -> None: # Given - settings.SITE_ID = None + Site.objects.all().delete() expected_domain = "some-testing-url.com" settings.DOMAIN_OVERRIDE = expected_domain