Skip to content

Commit

Permalink
handle CORS
Browse files Browse the repository at this point in the history
  • Loading branch information
khvn26 committed Sep 29, 2024
1 parent 69b8817 commit 75713de
Show file tree
Hide file tree
Showing 6 changed files with 120 additions and 24 deletions.
28 changes: 15 additions & 13 deletions api/app/settings/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,19 +418,6 @@

MEDIA_URL = "/media/" # unused but needs to be different from STATIC_URL in django 3

# CORS settings

CORS_ORIGIN_ALLOW_ALL = True
FLAGSMITH_CORS_EXTRA_ALLOW_HEADERS = env.list(
"FLAGSMITH_CORS_EXTRA_ALLOW_HEADERS", default=["sentry-trace"]
)
CORS_ALLOW_HEADERS = [
*default_headers,
*FLAGSMITH_CORS_EXTRA_ALLOW_HEADERS,
"X-Environment-Key",
"X-E2E-Test-Auth-Token",
]

DEFAULT_FROM_EMAIL = env("SENDER_EMAIL", default="[email protected]")
EMAIL_CONFIGURATION = {
# Invitations with name is anticipated to take two arguments. The persons name and the
Expand Down Expand Up @@ -1046,6 +1033,21 @@
USE_SECURE_COOKIES = env.bool("USE_SECURE_COOKIES", default=True)
COOKIE_SAME_SITE = env.str("COOKIE_SAME_SITE", default="none")

# CORS settings

CORS_ORIGIN_ALLOW_ALL = env.bool("CORS_ORIGIN_ALLOW_ALL", not COOKIE_AUTH_ENABLED)
CORS_ALLOW_CREDENTIALS = env.bool("CORS_ALLOW_CREDENTIALS", COOKIE_AUTH_ENABLED)
FLAGSMITH_CORS_EXTRA_ALLOW_HEADERS = env.list(
"FLAGSMITH_CORS_EXTRA_ALLOW_HEADERS", default=["sentry-trace"]
)
CORS_ALLOWED_ORIGINS = env.list("CORS_ALLOWED_ORIGINS", default=[])
CORS_ALLOW_HEADERS = [
*default_headers,
*FLAGSMITH_CORS_EXTRA_ALLOW_HEADERS,
"X-Environment-Key",
"X-E2E-Test-Auth-Token",
]

# use a separate boolean setting so that we add it to the API containers in environments
# where we're running the task processor, so we avoid creating unnecessary tasks
ENABLE_PIPEDRIVE_LEAD_TRACKING = env.bool("ENABLE_PIPEDRIVE_LEAD_TRACKING", False)
Expand Down
20 changes: 13 additions & 7 deletions api/core/helpers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import re

from django.conf import settings
from django.contrib.sites.models import Site
from django.contrib.sites import models as sites_models
from django.http import HttpRequest
from rest_framework.request import Request

Expand All @@ -11,12 +11,18 @@


def get_current_site_url(request: HttpRequest | Request | None = None) -> str:
if settings.DOMAIN_OVERRIDE:
domain = settings.DOMAIN_OVERRIDE
elif current_site := Site.objects.filter(id=settings.SITE_ID).first():
domain = current_site.domain
else:
domain = settings.DEFAULT_DOMAIN
if not (domain := settings.DOMAIN_OVERRIDE):
try:
domain = sites_models.Site.objects.get_current(request).domain
except sites_models.Site.DoesNotExist:
# For the rare case when `DOMAIN_OVERRIDE` was not set and no `Site` object present,
# store a default domain `Site` in the sites cache
# so it's correctly invalidated should the user decide to create own `Site` object.
domain = settings.DEFAULT_DOMAIN
sites_models.SITE_CACHE[settings.SITE_ID] = sites_models.Site(
name="Flagsmith",
domain=domain,
)

if request:
scheme = request.scheme
Expand Down
1 change: 1 addition & 0 deletions api/custom_auth/apps.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ class CustomAuthAppConfig(AppConfig):

def ready(self) -> None:
from custom_auth import tasks # noqa F401
from custom_auth.jwt_cookie import signals # noqa F401
20 changes: 20 additions & 0 deletions api/custom_auth/jwt_cookie/signals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from typing import Any
from urllib.parse import urlparse

from core.helpers import get_current_site_url
from corsheaders.signals import check_request_enabled
from django.dispatch import receiver
from django.http import HttpRequest


@receiver(check_request_enabled)
def cors_allow_current_site(request: HttpRequest, **kwargs: Any) -> bool:
# The signal is expected to only be dispatched:
# - When `settings.CORS_ORIGIN_ALLOW_ALL` is set to `False`.
# - For requests with `HTTP_ORIGIN` set.
origin_url = urlparse(request.META["HTTP_ORIGIN"])
current_site_url = urlparse(get_current_site_url(request))
return (
origin_url.scheme == current_site_url.scheme
and origin_url.netloc == current_site_url.netloc
)
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,44 @@ def test_login_workflow__jwt_cookie__mfa_enabled(
assert not response.data


# In the real world, setting `COOKIE_AUTH_ENABLED` to `True`
# changes default CORS setting values.
# Due to how Django settings are loaded for tests,
# we have to override CORS settings manually.
@override_settings(
COOKIE_AUTH_ENABLED=True,
DOMAIN_OVERRIDE="testhost.com",
CORS_ORIGIN_ALLOW_ALL=False,
CORS_ALLOW_CREDENTIALS=True,
)
def test_login_workflow__jwt_cookie__cors_headers_expected(
db: None,
api_client: APIClient,
) -> None:
# Given
email = "[email protected]"
password = FFAdminUser.objects.make_random_password()
register_url = reverse("api-v1:custom_auth:ffadminuser-list")
protected_resource_url = reverse("api-v1:projects:project-list")
register_data = {
"first_name": "test",
"last_name": "last_name",
"email": email,
"password": password,
"re_password": password,
}
api_client.post(register_url, data=register_data)

# When
response = api_client.get(
protected_resource_url,
HTTP_ORIGIN="http://testhost.com",
)

# Then
assert response.headers["Access-Control-Allow-Origin"] == "http://testhost.com"


def test_throttle_login_workflows(
api_client: APIClient,
db: None,
Expand Down
37 changes: 33 additions & 4 deletions api/tests/unit/core/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from django.contrib.sites.models import Site

if typing.TYPE_CHECKING:
from pytest_django.fixtures import SettingsWrapper
from pytest_django.fixtures import DjangoAssertNumQueries, SettingsWrapper
from pytest_mock import MockerFixture

pytestmark = pytest.mark.django_db
Expand All @@ -26,13 +26,13 @@ def test_get_current_site_url_returns_correct_url_if_site_exists(
assert url == f"https://{expected_domain}"


def test_get_current_site_url_uses_default_url_if_site_does_not_exists(
def test_get_current_site_url_uses_default_url_if_site_does_not_exist(
settings: "SettingsWrapper",
) -> None:
# Given
expected_domain = "some-testing-url.com"
settings.DEFAULT_DOMAIN = expected_domain
settings.SITE_ID = None
Site.objects.all().delete()

# When
url = get_current_site_url()
Expand All @@ -41,6 +41,35 @@ def test_get_current_site_url_uses_default_url_if_site_does_not_exists(
assert url == f"https://{expected_domain}"


def test_get_current_site_url__site_created__cached_return_expected(
settings: "SettingsWrapper",
django_assert_num_queries: "DjangoAssertNumQueries",
) -> None:
# Given
expected_domain_without_site = "some-new-testing-url.com"
expected_domain_with_site = "some-testing-url.com"
settings.DEFAULT_DOMAIN = expected_domain_without_site
Site.objects.all().delete()

# When
with django_assert_num_queries(1):
get_current_site_url()
url_without_site = get_current_site_url()

settings.SITE_ID = Site.objects.create(
name="test_site",
domain=expected_domain_with_site,
).id

with django_assert_num_queries(1):
get_current_site_url()
url_with_site = get_current_site_url()

# Then
assert url_without_site == f"https://{expected_domain_without_site}"
assert url_with_site == f"https://{expected_domain_with_site}"


def test_get_current_site__domain_override__with_site__return_expected(
settings: "SettingsWrapper",
) -> None:
Expand All @@ -62,7 +91,7 @@ def test_get_current_site__domain_override__no_site__return_expected(
settings: "SettingsWrapper",
) -> None:
# Given
settings.SITE_ID = None
Site.objects.all().delete()

expected_domain = "some-testing-url.com"
settings.DOMAIN_OVERRIDE = expected_domain
Expand Down

0 comments on commit 75713de

Please sign in to comment.