diff --git a/daft/delta_lake/delta_lake_scan.py b/daft/delta_lake/delta_lake_scan.py index 56bc60bf55..a5c7b94b29 100644 --- a/daft/delta_lake/delta_lake_scan.py +++ b/daft/delta_lake/delta_lake_scan.py @@ -3,6 +3,7 @@ import logging import os from typing import TYPE_CHECKING +from urllib.parse import urlparse from deltalake.table import DeltaTable @@ -40,31 +41,39 @@ def __init__( # # See: https://github.com/delta-io/delta-rs/issues/2117 deltalake_sdk_io_config = storage_config.config.io_config - if any([deltalake_sdk_io_config.s3.key_id is None, deltalake_sdk_io_config.s3.region_name is None]): - try: - s3_config_from_env = S3Config.from_env() - # Sometimes S3Config.from_env throws an error, for example on CI machines with weird metadata servers. - except daft.exceptions.DaftCoreException: - pass - else: - if ( - deltalake_sdk_io_config.s3.key_id is None - and deltalake_sdk_io_config.s3.access_key is None - and deltalake_sdk_io_config.s3.session_token is None - ): - deltalake_sdk_io_config = deltalake_sdk_io_config.replace( - s3=deltalake_sdk_io_config.s3.replace( - key_id=s3_config_from_env.key_id, - access_key=s3_config_from_env.access_key, - session_token=s3_config_from_env.session_token, + scheme = urlparse(table_uri).scheme + if scheme == "s3" or scheme == "s3a": + if any([deltalake_sdk_io_config.s3.key_id is None, deltalake_sdk_io_config.s3.region_name is None]): + try: + s3_config_from_env = S3Config.from_env() + # Sometimes S3Config.from_env throws an error, for example on CI machines with weird metadata servers. + except daft.exceptions.DaftCoreException: + pass + else: + if ( + deltalake_sdk_io_config.s3.key_id is None + and deltalake_sdk_io_config.s3.access_key is None + and deltalake_sdk_io_config.s3.session_token is None + ): + deltalake_sdk_io_config = deltalake_sdk_io_config.replace( + s3=deltalake_sdk_io_config.s3.replace( + key_id=s3_config_from_env.key_id, + access_key=s3_config_from_env.access_key, + session_token=s3_config_from_env.session_token, + ) ) - ) - if deltalake_sdk_io_config.s3.region_name is None: - deltalake_sdk_io_config = deltalake_sdk_io_config.replace( - s3=deltalake_sdk_io_config.s3.replace( - region_name=s3_config_from_env.region_name, + if deltalake_sdk_io_config.s3.region_name is None: + deltalake_sdk_io_config = deltalake_sdk_io_config.replace( + s3=deltalake_sdk_io_config.s3.replace( + region_name=s3_config_from_env.region_name, + ) ) - ) + elif scheme == "gcs" or scheme == "gs": + # TO-DO: Handle any key-value replacements in `io_config` if there are missing elements + pass + elif scheme == "az" or scheme == "abfs" or scheme == "abfss": + # TO-DO: Handle any key-value replacements in `io_config` if there are missing elements + pass self._table = DeltaTable( table_uri, storage_options=io_config_to_storage_options(deltalake_sdk_io_config, table_uri) diff --git a/daft/unity_catalog/unity_catalog.py b/daft/unity_catalog/unity_catalog.py index 9eafaee8dd..640e557627 100644 --- a/daft/unity_catalog/unity_catalog.py +++ b/daft/unity_catalog/unity_catalog.py @@ -2,10 +2,11 @@ import dataclasses from typing import Callable +from urllib.parse import urlparse import unitycatalog -from daft.io import IOConfig, S3Config +from daft.io import AzureConfig, IOConfig, S3Config @dataclasses.dataclass(frozen=True) @@ -96,18 +97,28 @@ def load_table(self, table_name: str) -> UnityCatalogTable: # Grab credentials from Unity catalog and place it into the Table temp_table_credentials = self._client.temporary_table_credentials.create(operation="READ", table_id=table_id) - aws_temp_credentials = temp_table_credentials.aws_temp_credentials - io_config = ( - IOConfig( - s3=S3Config( - key_id=aws_temp_credentials.access_key_id, - access_key=aws_temp_credentials.secret_access_key, - session_token=aws_temp_credentials.session_token, + + scheme = urlparse(storage_location).scheme + if scheme == "s3" or scheme == "s3a": + aws_temp_credentials = temp_table_credentials.aws_temp_credentials + io_config = ( + IOConfig( + s3=S3Config( + key_id=aws_temp_credentials.access_key_id, + access_key=aws_temp_credentials.secret_access_key, + session_token=aws_temp_credentials.session_token, + ) ) + if aws_temp_credentials is not None + else None + ) + elif scheme == "gcs" or scheme == "gs": + # TO-DO: gather GCS credential vending assets from Unity and construct 'io_config`` + pass + elif scheme == "az" or scheme == "abfs" or scheme == "abfss": + io_config = IOConfig( + azure=AzureConfig(sas_token=temp_table_credentials.azure_user_delegation_sas.get("sas_token")) ) - if aws_temp_credentials is not None - else None - ) return UnityCatalogTable( table_uri=storage_location,