diff --git a/cassandra/cluster.py b/cassandra/cluster.py index d5f80290a..4fd99a004 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -37,6 +37,10 @@ import time from threading import Lock, RLock, Thread, Event import uuid +import os +import urllib.request +import json +from typing import Optional import weakref from weakref import WeakValueDictionary @@ -1039,6 +1043,19 @@ def default_retry_policy(self, policy): 'use_default_tempdir': True # use the system temp dir for the zip extraction } + (or) + + { + # Astra DB cluster UUID, used if secure_connect_bundle is not provided + 'db_id': 'db_id', + + # required with db_id. Astra DB token + 'token': 'AstraCS:change_me:change_me' + + # optional with db_id & token. Astra DB region + 'db_region': 'us-east1', + } + The zip file will be temporarily extracted in the same directory to load the configuration and certificates. """ @@ -1169,6 +1186,30 @@ def __init__(self, uses_twisted = TwistedConnection and issubclass(self.connection_class, TwistedConnection) uses_eventlet = EventletConnection and issubclass(self.connection_class, EventletConnection) + + # Check if we need to download the secure connect bundle + if all(akey in cloud for akey in ['db_id', 'token']): + # download SCB if necessary + if 'secure_connect_bundle' not in cloud: + bundle_path = f'astradb-scb-{cloud["db_id"]}' + if 'db_region' in cloud: + bundle_path += f'-{cloud["db_region"]}.zip' + else: + bundle_path += '.zip' + if not os.path.exists(bundle_path): + log.info('Downloading Secure Cloud Bundle...') + url = self._get_astra_bundle_url(cloud['db_id'], cloud['token'], cloud.get("db_region")) + try: + with urllib.request.urlopen(url) as r: + with open(bundle_path, 'wb') as f: + f.write(r.read()) + except urllib.error.URLError as e: + raise Exception(f"Error downloading secure connect bundle: {str(e)}") + cloud['secure_connect_bundle'] = bundle_path + # Set up auth_provider if not provided + if auth_provider is None: + auth_provider = PlainTextAuthProvider('token', cloud['token']) + cloud_config = dscloud.get_cloud_config(cloud, create_pyopenssl_context=uses_twisted or uses_eventlet) ssl_context = cloud_config.ssl_context @@ -2184,6 +2225,55 @@ def get_control_connection_host(self): endpoint = connection.endpoint if connection else None return self.metadata.get_host(endpoint) if endpoint else None + @staticmethod + def _get_astra_bundle_url(db_id, token, db_region: Optional[str] = None): + """ + Retrieves the secure connect bundle URL for an Astra DB cluster based on the provided 'db_id', + 'db_region' (optional) and 'token'. + + Args: + db_id (str): The Astra DB cluster UUID. + token (str): The Astra security token. + db_region (optional str): The Astra DB cluster region. + + Returns: + str: The secure connect bundle URL for the given inputs. + + Raises: + ValueError: If the Astra DB API response is missing the download url or the specified db_region is not found. + Exception: If there's an error connecting to the Astra API or processing the response. + """ + url = f"https://api.astra.datastax.com/v2/databases/{db_id}/datacenters" + headers = { + "Authorization": f"Bearer {token}", + "Content-Type": "application/json" + } + + req = urllib.request.Request(url, method="GET", headers=headers, data=b"") + try: + with urllib.request.urlopen(req) as response: + response_data = json.loads(response.read().decode()) + + # Convert list of responses to a dict keyed by region + datacenter_dict = {dc['region']: dc for dc in response_data if 'region' in dc} + + # Pull out the specified region, or the first one if not specified + if not datacenter_dict: + raise ValueError("No valid datacenter information found in the Astra DB API response") + if db_region: + if db_region not in datacenter_dict: + raise ValueError(f"Astra DB region '{db_region}' not found in list of regions") + datacenter = datacenter_dict[db_region] + else: + # Use the first datacenter as the primary region + datacenter = next(iter(datacenter_dict.values())) + + if 'secureBundleUrl' not in datacenter or not datacenter['secureBundleUrl']: + raise ValueError("'secureBundleUrl' is missing or empty in the Astra DB API response") + return datacenter['secureBundleUrl'] + except urllib.error.URLError as e: + raise Exception(f"Error connecting to Astra API: {str(e)}") + def refresh_schema_metadata(self, max_schema_agreement_wait=None): """ Synchronously refresh all schema metadata.