Skip to content

Commit

Permalink
Add profile to support Snowflake connection with private key (#378)
Browse files Browse the repository at this point in the history
Added new profile mapping for Snowflake User/Private Key authentication. Currently only supports private_key_content from a Snowflake Airflow connection.

Closes #267
  • Loading branch information
patawan authored Jul 21, 2023
1 parent f52293e commit 8a40537
Show file tree
Hide file tree
Showing 6 changed files with 339 additions and 3 deletions.
2 changes: 2 additions & 0 deletions cosmos/profiles/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from .postgres.user_pass import PostgresUserPasswordProfileMapping
from .redshift.user_pass import RedshiftUserPasswordProfileMapping
from .snowflake.user_pass import SnowflakeUserPasswordProfileMapping
from .snowflake.user_privatekey import SnowflakePrivateKeyPemProfileMapping
from .spark.thrift import SparkThriftProfileMapping
from .trino.certificate import TrinoCertificateProfileMapping
from .trino.jwt import TrinoJWTProfileMapping
Expand All @@ -24,6 +25,7 @@
PostgresUserPasswordProfileMapping,
RedshiftUserPasswordProfileMapping,
SnowflakeUserPasswordProfileMapping,
SnowflakePrivateKeyPemProfileMapping,
SparkThriftProfileMapping,
ExasolUserPasswordProfileMapping,
TrinoLDAPProfileMapping,
Expand Down
3 changes: 2 additions & 1 deletion cosmos/profiles/snowflake/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"Snowflake Airflow connection -> dbt profile mapping."

from .user_pass import SnowflakeUserPasswordProfileMapping
from .user_privatekey import SnowflakePrivateKeyPemProfileMapping

__all__ = ["SnowflakeUserPasswordProfileMapping"]
__all__ = ["SnowflakeUserPasswordProfileMapping", "SnowflakePrivateKeyPemProfileMapping"]
87 changes: 87 additions & 0 deletions cosmos/profiles/snowflake/user_privatekey.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
"Maps Airflow Snowflake connections to dbt profiles if they use a user/private key."
from __future__ import annotations

import json
from typing import TYPE_CHECKING, Any

from ..base import BaseProfileMapping

if TYPE_CHECKING:
from airflow.models import Connection


class SnowflakePrivateKeyPemProfileMapping(BaseProfileMapping):
    """
    Maps Airflow Snowflake connections to dbt profiles if they use a user/private key.

    The private key is read from the ``private_key_content`` extra of the Airflow
    connection; ``role`` is mapped but is not a required field.

    https://docs.getdbt.com/docs/core/connect-data-platform/snowflake-setup#key-pair-authentication
    https://airflow.apache.org/docs/apache-airflow-providers-snowflake/stable/connections/snowflake.html
    """

    airflow_connection_type: str = "snowflake"
    is_community: bool = True

    # Fields listed as required for this mapping.
    required_fields = [
        "account",
        "user",
        "database",
        "warehouse",
        "schema",
        "private_key_content",
    ]
    # Fields that are referenced through an environment variable in the profile
    # (see `profile`) rather than written out in clear text.
    secret_fields = [
        "private_key_content",
    ]
    # dbt profile field -> location of the value on the Airflow connection.
    airflow_param_mapping = {
        "account": "extra.account",
        "user": "login",
        "database": "extra.database",
        "warehouse": "extra.warehouse",
        "schema": "schema",
        "role": "extra.role",
        "private_key_content": "extra.private_key_content",
    }

    def __init__(self, conn: Connection, profile_args: dict[str, Any | None] | None = None) -> None:
        """
        Normalize legacy Snowflake extras on the connection, then initialize the mapping.

        Snowflake can be odd because the fields used to be stored with keys in the format
        'extra__snowflake__account', but now are stored as 'account'.
        This standardizes the keys to be 'account', 'database', etc.

        Note that ``conn.extra`` is overwritten with the (possibly normalized) extras,
        so the connection object passed in is modified.

        :param conn: the Airflow connection to map.
        :param profile_args: optional overrides merged into the generated profile.
        """
        legacy_prefix = "extra__snowflake__"
        conn_dejson = conn.extra_dejson

        # The legacy layout is detected through the presence of a non-empty legacy account key.
        if conn_dejson.get(f"{legacy_prefix}account"):
            # Strip the marker only where it is an actual prefix, so that a key which
            # merely contains the marker somewhere in the middle is left untouched.
            conn_dejson = {
                (key[len(legacy_prefix) :] if key.startswith(legacy_prefix) else key): value
                for key, value in conn_dejson.items()
            }

        conn.extra = json.dumps(conn_dejson)

        self.conn = conn
        self.profile_args = profile_args or {}
        super().__init__(conn, profile_args)

    @property
    def profile(self) -> dict[str, Any | None]:
        "Gets profile. User-supplied profile_args override the values taken from the connection."
        extras = self.conn.extra_dejson
        profile_vars = {
            "type": "snowflake",
            "account": self.account,
            "user": self.user,
            "schema": self.schema,
            "database": self.database,
            "role": extras.get("role"),
            "warehouse": extras.get("warehouse"),
            **self.profile_args,
            # private_key should always get set as env var; it is placed after
            # profile_args so that an override cannot replace it
            "private_key_content": self.get_env_var_format("private_key_content"),
        }

        # remove any null values
        return self.filter_null(profile_vars)

    def transform_account(self, account: str) -> str:
        """
        Transform the account to the format <account>.<region> if it's not already.

        The region comes from the connection's ``region`` extra; it is appended only
        when it is set and does not already occur anywhere in ``account``.
        """
        region = self.conn.extra_dejson.get("region")
        if region and region not in account:
            account = f"{account}.{region}"

        return str(account)
8 changes: 8 additions & 0 deletions docs/dbt/connections-profiles.rst
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,14 @@ Username and Password
:members:


Username and Private Key
~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: cosmos.profiles.snowflake.SnowflakePrivateKeyPemProfileMapping
:undoc-members:
:members:


Spark
-----

Expand Down
4 changes: 2 additions & 2 deletions tests/profiles/snowflake/test_snowflake_user_pass.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"Tests for the Snowflake profile."
"Tests for the Snowflake user/password profile."

import json
from unittest.mock import patch
Expand All @@ -18,7 +18,7 @@ def mock_snowflake_conn(): # type: ignore
Sets the connection as an environment variable.
"""
conn = Connection(
conn_id="my_snowflake_connection",
conn_id="my_snowflake_password_connection",
conn_type="snowflake",
login="my_user",
password="my_password",
Expand Down
238 changes: 238 additions & 0 deletions tests/profiles/snowflake/test_snowflake_user_privatekey.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,238 @@
"Tests for the Snowflake user/private key profile."

import json
from unittest.mock import patch

import pytest
from airflow.models.connection import Connection

from cosmos.profiles import get_profile_mapping
from cosmos.profiles.snowflake import (
SnowflakePrivateKeyPemProfileMapping,
)


@pytest.fixture()
def mock_snowflake_conn():  # type: ignore
    """
    Yields a key-pair Snowflake connection while ``BaseHook.get_connection`` is patched to return it.
    """
    extras = {
        "account": "my_account",
        "database": "my_database",
        "warehouse": "my_warehouse",
        "private_key_content": "my_private_key",
    }
    snowflake_conn = Connection(
        conn_id="my_snowflake_pk_connection",
        conn_type="snowflake",
        login="my_user",
        schema="my_schema",
        extra=json.dumps(extras),
    )

    # Every connection lookup made during the test resolves to the connection built above.
    get_connection_patch = patch("airflow.hooks.base.BaseHook.get_connection", return_value=snowflake_conn)
    with get_connection_patch:
        yield snowflake_conn


def test_connection_claiming() -> None:
    """
    Checks that the mapping claims a connection only when it is a Snowflake
    connection carrying user, private key, account, database, warehouse and schema.
    """
    complete_extra = {
        "account": "my_account",
        "database": "my_database",
        "warehouse": "my_warehouse",
        "private_key_content": "my_private_key",
    }
    potential_values = {
        "conn_type": "snowflake",
        "login": "my_user",
        "schema": "my_database",
        "extra": json.dumps(complete_extra),
    }

    # dropping any single top-level attribute must prevent the claim
    for omitted in potential_values:
        values = {name: value for name, value in potential_values.items() if name != omitted}
        conn = Connection(**values)  # type: ignore

        print("testing with", values)

        mapping = SnowflakePrivateKeyPemProfileMapping(
            conn,
        )
        assert not mapping.can_claim_connection()

    # extras lacking, in turn, the account, the database and the warehouse
    incomplete_extras = (
        '{"database": "my_database", "warehouse": "my_warehouse", "private_key_content": "my_private_key"}',
        '{"account": "my_account", "warehouse": "my_warehouse", "private_key_content": "my_private_key"}',
        '{"account": "my_account", "database": "my_database", "private_key_content": "my_private_key"}',
    )
    for partial_extra in incomplete_extras:
        conn = Connection(**potential_values)  # type: ignore
        conn.extra = partial_extra
        print("testing with", conn.extra)
        mapping = SnowflakePrivateKeyPemProfileMapping(conn)
        assert not mapping.can_claim_connection()

    # with every value present the connection is claimed
    full_conn = Connection(**potential_values)  # type: ignore
    mapping = SnowflakePrivateKeyPemProfileMapping(full_conn)
    assert mapping.can_claim_connection()


def test_profile_mapping_selected(
    mock_snowflake_conn: Connection,
) -> None:
    """
    Checks that looking up the mocked connection yields the private key mapping.
    """
    selected = get_profile_mapping(mock_snowflake_conn.conn_id)

    assert isinstance(selected, SnowflakePrivateKeyPemProfileMapping)


def test_profile_args(
    mock_snowflake_conn: Connection,
) -> None:
    """
    Checks that the generated profile mirrors the values of the connection.
    """
    mapping = get_profile_mapping(mock_snowflake_conn.conn_id)
    actual_profile = mapping.profile

    extras = mock_snowflake_conn.extra_dejson
    expected_profile = {
        "type": mock_snowflake_conn.conn_type,
        "user": mock_snowflake_conn.login,
        "private_key_content": "{{ env_var('COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY_CONTENT') }}",
        "schema": mock_snowflake_conn.schema,
        "account": extras.get("account"),
        "database": extras.get("database"),
        "warehouse": extras.get("warehouse"),
    }

    assert actual_profile == expected_profile


def test_profile_args_overrides(
    mock_snowflake_conn: Connection,
) -> None:
    """
    Checks that values passed as profile_args replace the ones from the connection.
    """
    mapping = get_profile_mapping(
        mock_snowflake_conn.conn_id,
        profile_args={"database": "my_db_override"},
    )

    # the overrides are stored untouched on the mapping
    assert mapping.profile_args == {"database": "my_db_override"}

    actual_profile = mapping.profile
    extras = mock_snowflake_conn.extra_dejson
    expected_profile = {
        "type": mock_snowflake_conn.conn_type,
        "user": mock_snowflake_conn.login,
        "private_key_content": "{{ env_var('COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY_CONTENT') }}",
        "schema": mock_snowflake_conn.schema,
        "account": extras.get("account"),
        "database": "my_db_override",
        "warehouse": extras.get("warehouse"),
    }

    assert actual_profile == expected_profile


def test_profile_env_vars(
    mock_snowflake_conn: Connection,
) -> None:
    """
    Checks that the private key content is exported through the expected environment variable.
    """
    mapping = get_profile_mapping(mock_snowflake_conn.conn_id)
    actual_env = mapping.env_vars

    private_key = mock_snowflake_conn.extra_dejson.get("private_key_content")
    assert actual_env == {"COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY_CONTENT": private_key}


def test_old_snowflake_format() -> None:
    """
    Checks that extras stored with the legacy 'extra__snowflake__' key format still map correctly.
    """
    legacy_extras = {
        "extra__snowflake__account": "my_account",
        "extra__snowflake__database": "my_database",
        "extra__snowflake__warehouse": "my_warehouse",
        "extra__snowflake__private_key_content": "my_private_key",
    }
    conn = Connection(
        conn_id="my_snowflake_connection",
        conn_type="snowflake",
        login="my_user",
        schema="my_schema",
        extra=json.dumps(legacy_extras),
    )

    mapping = SnowflakePrivateKeyPemProfileMapping(conn)
    actual_profile = mapping.profile

    # read the extras only now: building the mapping rewrites conn.extra with the new-style keys
    extras = conn.extra_dejson
    expected_profile = {
        "type": conn.conn_type,
        "user": conn.login,
        "private_key_content": "{{ env_var('COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY_CONTENT') }}",
        "schema": conn.schema,
        "account": extras.get("account"),
        "database": extras.get("database"),
        "warehouse": extras.get("warehouse"),
    }

    assert actual_profile == expected_profile


def test_appends_region() -> None:
    """
    Checks that the region extra is appended to the account when the account lacks it.
    """
    extras_with_region = {
        "account": "my_account",
        "region": "my_region",
        "database": "my_database",
        "warehouse": "my_warehouse",
        "private_key_content": "my_private_key",
    }
    conn = Connection(
        conn_id="my_snowflake_connection",
        conn_type="snowflake",
        login="my_user",
        schema="my_schema",
        extra=json.dumps(extras_with_region),
    )

    mapping = SnowflakePrivateKeyPemProfileMapping(conn)
    actual_profile = mapping.profile

    extras = conn.extra_dejson
    account_with_region = f"{extras.get('account')}.{extras.get('region')}"
    expected_profile = {
        "type": conn.conn_type,
        "user": conn.login,
        "private_key_content": "{{ env_var('COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY_CONTENT') }}",
        "schema": conn.schema,
        "account": account_with_region,
        "database": extras.get("database"),
        "warehouse": extras.get("warehouse"),
    }

    assert actual_profile == expected_profile

0 comments on commit 8a40537

Please sign in to comment.