diff --git a/cosmos/profiles/__init__.py b/cosmos/profiles/__init__.py index 1399c0a1d..48abdf482 100644 --- a/cosmos/profiles/__init__.py +++ b/cosmos/profiles/__init__.py @@ -13,6 +13,7 @@ from .postgres.user_pass import PostgresUserPasswordProfileMapping from .redshift.user_pass import RedshiftUserPasswordProfileMapping from .snowflake.user_pass import SnowflakeUserPasswordProfileMapping +from .snowflake.user_privatekey import SnowflakePrivateKeyPemProfileMapping from .spark.thrift import SparkThriftProfileMapping from .trino.certificate import TrinoCertificateProfileMapping from .trino.jwt import TrinoJWTProfileMapping @@ -24,6 +25,7 @@ PostgresUserPasswordProfileMapping, RedshiftUserPasswordProfileMapping, SnowflakeUserPasswordProfileMapping, + SnowflakePrivateKeyPemProfileMapping, SparkThriftProfileMapping, ExasolUserPasswordProfileMapping, TrinoLDAPProfileMapping, diff --git a/cosmos/profiles/snowflake/__init__.py b/cosmos/profiles/snowflake/__init__.py index 2433133ed..450dc3772 100644 --- a/cosmos/profiles/snowflake/__init__.py +++ b/cosmos/profiles/snowflake/__init__.py @@ -1,5 +1,6 @@ "Snowflake Airflow connection -> dbt profile mapping." from .user_pass import SnowflakeUserPasswordProfileMapping +from .user_privatekey import SnowflakePrivateKeyPemProfileMapping -__all__ = ["SnowflakeUserPasswordProfileMapping"] +__all__ = ["SnowflakeUserPasswordProfileMapping", "SnowflakePrivateKeyPemProfileMapping"] diff --git a/cosmos/profiles/snowflake/user_privatekey.py b/cosmos/profiles/snowflake/user_privatekey.py new file mode 100644 index 000000000..451c3f953 --- /dev/null +++ b/cosmos/profiles/snowflake/user_privatekey.py @@ -0,0 +1,87 @@ +"Maps Airflow Snowflake connections to dbt profiles if they use a user/private key." 
+from __future__ import annotations + +import json +from typing import TYPE_CHECKING, Any + +from ..base import BaseProfileMapping + +if TYPE_CHECKING: + from airflow.models import Connection + + +class SnowflakePrivateKeyPemProfileMapping(BaseProfileMapping): + """ + Maps Airflow Snowflake connections to dbt profiles if they use a user/private key. + https://docs.getdbt.com/docs/core/connect-data-platform/snowflake-setup#key-pair-authentication + https://airflow.apache.org/docs/apache-airflow-providers-snowflake/stable/connections/snowflake.html + """ + + airflow_connection_type: str = "snowflake" + is_community: bool = True + + required_fields = [ + "account", + "user", + "database", + "warehouse", + "schema", + "private_key_content", + ] + secret_fields = [ + "private_key_content", + ] + airflow_param_mapping = { + "account": "extra.account", + "user": "login", + "database": "extra.database", + "warehouse": "extra.warehouse", + "schema": "schema", + "role": "extra.role", + "private_key_content": "extra.private_key_content", + } + + def __init__(self, conn: Connection, profile_args: dict[str, Any | None] | None = None) -> None: + """ + Snowflake can be odd because the fields used to be stored with keys in the format + 'extra__snowflake__account', but now are stored as 'account'. + + This standardizes the keys to be 'account', 'database', etc. + """ + conn_dejson = conn.extra_dejson + + if conn_dejson.get("extra__snowflake__account"): + conn_dejson = {key.replace("extra__snowflake__", ""): value for key, value in conn_dejson.items()} + + conn.extra = json.dumps(conn_dejson) + + self.conn = conn + self.profile_args = profile_args or {} + super().__init__(conn, profile_args) + + @property + def profile(self) -> dict[str, Any | None]: + "Gets profile." 
+ profile_vars = { + "type": "snowflake", + "account": self.account, + "user": self.user, + "schema": self.schema, + "database": self.database, + "role": self.conn.extra_dejson.get("role"), + "warehouse": self.conn.extra_dejson.get("warehouse"), + **self.profile_args, + # private_key should always get set as env var + "private_key_content": self.get_env_var_format("private_key_content"), + } + + # remove any null values + return self.filter_null(profile_vars) + + def transform_account(self, account: str) -> str: + "Transform the account to the format <account>.<region> if it's not already." + region = self.conn.extra_dejson.get("region") + if region and region not in account: + account = f"{account}.{region}" + + return str(account) diff --git a/docs/dbt/connections-profiles.rst b/docs/dbt/connections-profiles.rst index 59187b9d6..7b62e9bdc 100644 --- a/docs/dbt/connections-profiles.rst +++ b/docs/dbt/connections-profiles.rst @@ -144,6 +144,14 @@ Username and Password :members: +Username and Private Key +~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: cosmos.profiles.snowflake.SnowflakePrivateKeyPemProfileMapping + :undoc-members: + :members: + + Spark ----- diff --git a/tests/profiles/snowflake/test_snowflake_user_pass.py b/tests/profiles/snowflake/test_snowflake_user_pass.py index 17bcffb43..6efbb8f86 100644 --- a/tests/profiles/snowflake/test_snowflake_user_pass.py +++ b/tests/profiles/snowflake/test_snowflake_user_pass.py @@ -1,4 +1,4 @@ -"Tests for the Snowflake profile." +"Tests for the Snowflake user/password profile." import json from unittest.mock import patch @@ -18,7 +18,7 @@ def mock_snowflake_conn(): # type: ignore Sets the connection as an environment variable. 
""" conn = Connection( - conn_id="my_snowflake_connection", + conn_id="my_snowflake_password_connection", conn_type="snowflake", login="my_user", password="my_password", diff --git a/tests/profiles/snowflake/test_snowflake_user_privatekey.py b/tests/profiles/snowflake/test_snowflake_user_privatekey.py new file mode 100644 index 000000000..5b6983505 --- /dev/null +++ b/tests/profiles/snowflake/test_snowflake_user_privatekey.py @@ -0,0 +1,238 @@ +"Tests for the Snowflake user/private key profile." + +import json +from unittest.mock import patch + +import pytest +from airflow.models.connection import Connection + +from cosmos.profiles import get_profile_mapping +from cosmos.profiles.snowflake import ( + SnowflakePrivateKeyPemProfileMapping, +) + + +@pytest.fixture() +def mock_snowflake_conn(): # type: ignore + """ + Sets the connection as an environment variable. + """ + conn = Connection( + conn_id="my_snowflake_pk_connection", + conn_type="snowflake", + login="my_user", + schema="my_schema", + extra=json.dumps( + { + "account": "my_account", + "database": "my_database", + "warehouse": "my_warehouse", + "private_key_content": "my_private_key", + } + ), + ) + + with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): + yield conn + + +def test_connection_claiming() -> None: + """ + Tests that the Snowflake profile mapping claims the correct connection type. 
+ """ + # should only claim when: + # - conn_type == snowflake + # and the following exist: + # - user + # - private key + # - account + # - database + # - warehouse + # - schema + potential_values = { + "conn_type": "snowflake", + "login": "my_user", + "schema": "my_database", + "extra": json.dumps( + { + "account": "my_account", + "database": "my_database", + "warehouse": "my_warehouse", + "private_key_content": "my_private_key", + } + ), + } + + # if we're missing any of the values, it shouldn't claim + for key in potential_values: + values = potential_values.copy() + del values[key] + conn = Connection(**values) # type: ignore + + print("testing with", values) + + profile_mapping = SnowflakePrivateKeyPemProfileMapping( + conn, + ) + assert not profile_mapping.can_claim_connection() + + # test when we're missing the account + conn = Connection(**potential_values) # type: ignore + conn.extra = '{"database": "my_database", "warehouse": "my_warehouse", "private_key_content": "my_private_key"}' + print("testing with", conn.extra) + profile_mapping = SnowflakePrivateKeyPemProfileMapping(conn) + assert not profile_mapping.can_claim_connection() + + # test when we're missing the database + conn = Connection(**potential_values) # type: ignore + conn.extra = '{"account": "my_account", "warehouse": "my_warehouse", "private_key_content": "my_private_key"}' + print("testing with", conn.extra) + profile_mapping = SnowflakePrivateKeyPemProfileMapping(conn) + assert not profile_mapping.can_claim_connection() + + # test when we're missing the warehouse + conn = Connection(**potential_values) # type: ignore + conn.extra = '{"account": "my_account", "database": "my_database", "private_key_content": "my_private_key"}' + print("testing with", conn.extra) + profile_mapping = SnowflakePrivateKeyPemProfileMapping(conn) + assert not profile_mapping.can_claim_connection() + + # if we have them all, it should claim + conn = Connection(**potential_values) # type: ignore + profile_mapping 
= SnowflakePrivateKeyPemProfileMapping(conn) + assert profile_mapping.can_claim_connection() + + +def test_profile_mapping_selected( + mock_snowflake_conn: Connection, +) -> None: + """ + Tests that the correct profile mapping is selected. + """ + profile_mapping = get_profile_mapping( + mock_snowflake_conn.conn_id, + ) + assert isinstance(profile_mapping, SnowflakePrivateKeyPemProfileMapping) + + +def test_profile_args( + mock_snowflake_conn: Connection, +) -> None: + """ + Tests that the profile values get set correctly. + """ + profile_mapping = get_profile_mapping( + mock_snowflake_conn.conn_id, + ) + + assert profile_mapping.profile == { + "type": mock_snowflake_conn.conn_type, + "user": mock_snowflake_conn.login, + "private_key_content": "{{ env_var('COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY_CONTENT') }}", + "schema": mock_snowflake_conn.schema, + "account": mock_snowflake_conn.extra_dejson.get("account"), + "database": mock_snowflake_conn.extra_dejson.get("database"), + "warehouse": mock_snowflake_conn.extra_dejson.get("warehouse"), + } + + +def test_profile_args_overrides( + mock_snowflake_conn: Connection, +) -> None: + """ + Tests that you can override the profile values. + """ + profile_mapping = get_profile_mapping( + mock_snowflake_conn.conn_id, + profile_args={"database": "my_db_override"}, + ) + assert profile_mapping.profile_args == { + "database": "my_db_override", + } + + assert profile_mapping.profile == { + "type": mock_snowflake_conn.conn_type, + "user": mock_snowflake_conn.login, + "private_key_content": "{{ env_var('COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY_CONTENT') }}", + "schema": mock_snowflake_conn.schema, + "account": mock_snowflake_conn.extra_dejson.get("account"), + "database": "my_db_override", + "warehouse": mock_snowflake_conn.extra_dejson.get("warehouse"), + } + + +def test_profile_env_vars( + mock_snowflake_conn: Connection, +) -> None: + """ + Tests that the environment variables get set correctly. 
+ """ + profile_mapping = get_profile_mapping( + mock_snowflake_conn.conn_id, + ) + assert profile_mapping.env_vars == { + "COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY_CONTENT": mock_snowflake_conn.extra_dejson.get("private_key_content"), + } + + +def test_old_snowflake_format() -> None: + """ + Tests that the old format still works. + """ + conn = Connection( + conn_id="my_snowflake_connection", + conn_type="snowflake", + login="my_user", + schema="my_schema", + extra=json.dumps( + { + "extra__snowflake__account": "my_account", + "extra__snowflake__database": "my_database", + "extra__snowflake__warehouse": "my_warehouse", + "extra__snowflake__private_key_content": "my_private_key", + } + ), + ) + + profile_mapping = SnowflakePrivateKeyPemProfileMapping(conn) + assert profile_mapping.profile == { + "type": conn.conn_type, + "user": conn.login, + "private_key_content": "{{ env_var('COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY_CONTENT') }}", + "schema": conn.schema, + "account": conn.extra_dejson.get("account"), + "database": conn.extra_dejson.get("database"), + "warehouse": conn.extra_dejson.get("warehouse"), + } + + +def test_appends_region() -> None: + """ + Tests that region is appended to account if it doesn't already exist. 
+ """ + conn = Connection( + conn_id="my_snowflake_connection", + conn_type="snowflake", + login="my_user", + schema="my_schema", + extra=json.dumps( + { + "account": "my_account", + "region": "my_region", + "database": "my_database", + "warehouse": "my_warehouse", + "private_key_content": "my_private_key", + } + ), + ) + + profile_mapping = SnowflakePrivateKeyPemProfileMapping(conn) + assert profile_mapping.profile == { + "type": conn.conn_type, + "user": conn.login, + "private_key_content": "{{ env_var('COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY_CONTENT') }}", + "schema": conn.schema, + "account": f"{conn.extra_dejson.get('account')}.{conn.extra_dejson.get('region')}", + "database": conn.extra_dejson.get("database"), + "warehouse": conn.extra_dejson.get("warehouse"), + }