Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add convenience method for authenticating with astra via db_id and token #1228

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
90 changes: 90 additions & 0 deletions cassandra/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@
import time
from threading import Lock, RLock, Thread, Event
import uuid
import os
import urllib.request
import json
from typing import Optional

import weakref
from weakref import WeakValueDictionary
Expand Down Expand Up @@ -1039,6 +1043,19 @@ def default_retry_policy(self, policy):
'use_default_tempdir': True # use the system temp dir for the zip extraction
}

(or)

{
# Astra DB cluster UUID, used if secure_connect_bundle is not provided
'db_id': 'db_id',

# required with db_id. Astra DB token
'token': 'AstraCS:change_me:change_me'

# optional with db_id & token. Astra DB region
'db_region': 'us-east1',
}

The zip file will be temporarily extracted in the same directory to
load the configuration and certificates.
"""
Expand Down Expand Up @@ -1169,6 +1186,30 @@ def __init__(self,

uses_twisted = TwistedConnection and issubclass(self.connection_class, TwistedConnection)
uses_eventlet = EventletConnection and issubclass(self.connection_class, EventletConnection)

# Check if we need to download the secure connect bundle
if all(akey in cloud for akey in ['db_id', 'token']):
# download SCB if necessary
if 'secure_connect_bundle' not in cloud:
bundle_path = f'astradb-scb-{cloud["db_id"]}'
if 'db_region' in cloud:
bundle_path += f'-{cloud["db_region"]}.zip'
else:
bundle_path += '.zip'
if not os.path.exists(bundle_path):
log.info('Downloading Secure Cloud Bundle...')
url = self._get_astra_bundle_url(cloud['db_id'], cloud['token'], cloud.get("db_region"))
try:
with urllib.request.urlopen(url) as r:
with open(bundle_path, 'wb') as f:
f.write(r.read())
except urllib.error.URLError as e:
raise Exception(f"Error downloading secure connect bundle: {str(e)}")
cloud['secure_connect_bundle'] = bundle_path
# Set up auth_provider if not provided
if auth_provider is None:
auth_provider = PlainTextAuthProvider('token', cloud['token'])

cloud_config = dscloud.get_cloud_config(cloud, create_pyopenssl_context=uses_twisted or uses_eventlet)

ssl_context = cloud_config.ssl_context
Expand Down Expand Up @@ -2184,6 +2225,55 @@ def get_control_connection_host(self):
endpoint = connection.endpoint if connection else None
return self.metadata.get_host(endpoint) if endpoint else None

@staticmethod
def _get_astra_bundle_url(db_id, token, db_region: Optional[str] = None):
"""
Retrieves the secure connect bundle URL for an Astra DB cluster based on the provided 'db_id',
'db_region' (optional) and 'token'.

Args:
db_id (str): The Astra DB cluster UUID.
token (str): The Astra security token.
db_region (optional str): The Astra DB cluster region.

Returns:
str: The secure connect bundle URL for the given inputs.

Raises:
ValueError: If the Astra DB API response is missing the download url or the specified db_region is not found.
Exception: If there's an error connecting to the Astra API or processing the response.
"""
url = f"https://api.astra.datastax.com/v2/databases/{db_id}/datacenters"
headers = {
"Authorization": f"Bearer {token}",
"Content-Type": "application/json"
}

req = urllib.request.Request(url, method="GET", headers=headers, data=b"")
try:
with urllib.request.urlopen(req) as response:
response_data = json.loads(response.read().decode())

# Convert list of responses to a dict keyed by region
datacenter_dict = {dc['region']: dc for dc in response_data if 'region' in dc}

# Pull out the specified region, or the first one if not specified
if not datacenter_dict:
raise ValueError("No valid datacenter information found in the Astra DB API response")
if db_region:
if db_region not in datacenter_dict:
raise ValueError(f"Astra DB region '{db_region}' not found in list of regions")
datacenter = datacenter_dict[db_region]
else:
# Use the first datacenter as the primary region
datacenter = next(iter(datacenter_dict.values()))

if 'secureBundleUrl' not in datacenter or not datacenter['secureBundleUrl']:
raise ValueError("'secureBundleUrl' is missing or empty in the Astra DB API response")
return datacenter['secureBundleUrl']
except urllib.error.URLError as e:
raise Exception(f"Error connecting to Astra API: {str(e)}")

def refresh_schema_metadata(self, max_schema_agreement_wait=None):
"""
Synchronously refresh all schema metadata.
Expand Down