Skip to content
This repository has been archived by the owner on Apr 11, 2024. It is now read-only.

Commit

Permalink
AWS SigV4 support in Fetch Migration (opensearch-project#394)
Browse files Browse the repository at this point in the history
This commit enables support for SigV4 signed requests to AWS OpenSearch / OpenSearch Serverless endpoints in Fetch Migration, using the requests-aws4auth library. To achieve this, this change comprises several parts:

1 - All “endpoint” related logic is now encapsulated in the EndpointInfo class, including logic to construct API paths instead of having callers compute this.

2 - Construction of EndpointInfo instances from the Data Prepper pipeline configuration (and its plugin configuration sub-sections) has been moved out of metadata_migration.py to a new endpoint_utils.py file for better abstraction.

3 - Use of SigV4 is inferred from the supplied DP pipeline (separately for source and sink), including detection of the serverless key to change the service name (‘es’ vs. ‘aoss’)

4 - Since AWS4Auth requires a region argument, Fetch Migration first checks the plugin configuration for an explicitly defined region. If this is not present (since it is an optional parameter), the code attempts to derive the region based on the service endpoint URL (since generated endpoint URLs usually include the region). If a region cannot be inferred, a ValueError is thrown.

Unit tests for all of these components have been added (or existing ones updated). A minor refactoring of logging in migration_monitor.py is also included, which improves unit test code coverage.

---------

Signed-off-by: Kartik Ganesh <[email protected]>
  • Loading branch information
kartg authored Nov 10, 2023
1 parent 353ab11 commit 8989f00
Show file tree
Hide file tree
Showing 11 changed files with 554 additions and 296 deletions.
5 changes: 4 additions & 1 deletion FetchMigration/python/dev-requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,5 @@
coverage>=7.3.2
pur>=7.3.1
moto>=4.2.7
# Transitive dependency from moto, explicit version needed to mitigate CVE-2023-46136
werkzeug>=3.0.1
42 changes: 37 additions & 5 deletions FetchMigration/python/endpoint_info.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,40 @@
from dataclasses import dataclass
from typing import Optional

from requests_aws4auth import AWS4Auth

# Class that encapsulates endpoint information for an OpenSearch/Elasticsearch cluster
class EndpointInfo:
    """Holds the base URL, auth mechanism, and SSL-verification flag for a cluster endpoint.

    The auth member may be a (username, password) tuple for basic auth,
    an AWS4Auth instance for SigV4-signed requests, or None for no auth.
    """
    # Private member variables
    __url: str
    __auth: tuple | AWS4Auth | None
    __verify_ssl: bool

    def __init__(self, url: str, auth: tuple | AWS4Auth | None = None, verify_ssl: bool = True):
        self.__url = url
        # Normalize url value to have trailing slash so add_path() can simply concatenate
        if not url.endswith("/"):
            self.__url += "/"
        self.__auth = auth
        self.__verify_ssl = verify_ssl

    def __eq__(self, obj) -> bool:
        # Value equality over all members; instances of other types never compare equal
        return isinstance(obj, EndpointInfo) and \
            self.__url == obj.__url and \
            self.__auth == obj.__auth and \
            self.__verify_ssl == obj.__verify_ssl

    def add_path(self, path: str) -> str:
        """Return the full URL for an API path, avoiding a doubled slash at the join."""
        # Remove leading slash if present; the base URL already ends with one
        if path.startswith("/"):
            path = path[1:]
        return self.__url + path

    def get_url(self) -> str:
        """Return the normalized (trailing-slash) base URL."""
        return self.__url

    def get_auth(self) -> tuple | AWS4Auth | None:
        """Return the auth object: basic-auth tuple, AWS4Auth, or None."""
        return self.__auth

    def is_verify_ssl(self) -> bool:
        """Return True when SSL certificates should be verified on requests."""
        return self.__verify_ssl
156 changes: 156 additions & 0 deletions FetchMigration/python/endpoint_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
import re
from typing import Optional

from requests_aws4auth import AWS4Auth
from botocore.session import Session

from endpoint_info import EndpointInfo

# Constants
# Top-level section keys of a Data Prepper pipeline definition
SOURCE_KEY = "source"
SINK_KEY = "sink"
# Plugin types this tool knows how to read endpoint config from
SUPPORTED_PLUGINS = ["opensearch", "elasticsearch"]
HOSTS_KEY = "hosts"
# SSL / connection-related plugin config keys
INSECURE_KEY = "insecure"
CONNECTION_KEY = "connection"
# Authentication-related plugin config keys
DISABLE_AUTH_KEY = "disable_authentication"
USER_KEY = "username"
PWD_KEY = "password"
AWS_SIGV4_KEY = "aws_sigv4"
AWS_REGION_KEY = "aws_region"
# Nested "aws" config block and its keys
AWS_CONFIG_KEY = "aws"
AWS_CONFIG_REGION_KEY = "region"
IS_SERVERLESS_KEY = "serverless"
# AWS service names for SigV4 signing: managed OpenSearch ("es") vs. OpenSearch Serverless ("aoss")
ES_SERVICE_NAME = "es"
AOSS_SERVICE_NAME = "aoss"
# Captures the region from endpoint URLs shaped like *.<region>.<es|aoss>.amazonaws.com
URL_REGION_PATTERN = re.compile(r"([\w-]*)\.(es|aoss)\.amazonaws\.com")


def __get_url(plugin_config: dict) -> str:
    """Return a single reachable host URL from the plugin configuration."""
    hosts = plugin_config[HOSTS_KEY]
    # "hosts" can be a simple string, or an array of hosts for Logstash to hit.
    # This tool needs one accessible host, so we pick the first entry in the latter case.
    if isinstance(hosts, list):
        return hosts[0]
    return hosts


# Helper function that attempts to extract the AWS region from a URL,
# assuming it is of the form *.<region>.<service>.amazonaws.com
def __derive_aws_region_from_url(url: str) -> Optional[str]:
    # group(0) would be the whole match; group(1) is the captured region segment
    if (region_match := URL_REGION_PATTERN.search(url)) is not None:
        return region_match.group(1)
    return None


def get_aws_region(plugin_config: dict) -> str:
    """Return the AWS region to use for SigV4 signing.

    Preference order: explicit "aws_region" (with the aws_sigv4 flag), then
    "region" inside the nested "aws" config block, then a region derived from
    the host URL. Raises ValueError when no region can be determined, or when
    the "aws" config block is not a dict.
    """
    if plugin_config.get(AWS_SIGV4_KEY, False) and plugin_config.get(AWS_REGION_KEY) is not None:
        return plugin_config[AWS_REGION_KEY]
    aws_config = plugin_config.get(AWS_CONFIG_KEY)
    if aws_config is not None:
        if not isinstance(aws_config, dict):
            raise ValueError("Unexpected value for 'aws' configuration")
        if aws_config.get(AWS_CONFIG_REGION_KEY) is not None:
            return aws_config[AWS_CONFIG_REGION_KEY]
    # Region not explicitly defined, attempt to derive from URL
    derived_region = __derive_aws_region_from_url(__get_url(plugin_config))
    if derived_region is None:
        raise ValueError("No region configured for AWS SigV4 auth, or derivable from host URL")
    return derived_region


def __check_supported_endpoint(section_config: dict) -> Optional[tuple]:
    """Return (plugin_name, plugin_config) for the first supported plugin, else None."""
    return next(
        ((name, section_config[name]) for name in SUPPORTED_PLUGINS if name in section_config),
        None,
    )


# This config key may be either directly in the main dict (for sink)
# or inside a nested dict (for source). The default value is False.
def is_insecure(plugin_config: dict) -> bool:
    """Return the value of the "insecure" flag, defaulting to False when absent."""
    if INSECURE_KEY in plugin_config:
        return plugin_config[INSECURE_KEY]
    if CONNECTION_KEY in plugin_config:
        connection_config = plugin_config[CONNECTION_KEY]
        if INSECURE_KEY in connection_config:
            return connection_config[INSECURE_KEY]
    return False


def validate_pipeline(pipeline: dict):
    """Ensure the pipeline config contains both a source and a sink section."""
    required_sections = {
        SOURCE_KEY: "Missing source configuration in Data Prepper pipeline YAML",
        SINK_KEY: "Missing sink configuration in Data Prepper pipeline YAML",
    }
    # Dict insertion order preserves the original check order (source first)
    for section_key, error_message in required_sections.items():
        if section_key not in pipeline:
            raise ValueError(error_message)


def validate_auth(plugin_name: str, plugin_config: dict):
    """Validate the auth portion of a plugin config, raising ValueError when invalid."""
    # If auth is disabled, no further validation is required
    if plugin_config.get(DISABLE_AUTH_KEY, False):
        return
    # If AWS SigV4 is configured, validate region
    if plugin_config.get(AWS_SIGV4_KEY, False) or AWS_CONFIG_KEY in plugin_config:
        # Raises a ValueError if region cannot be derived
        get_aws_region(plugin_config)
        return
    # Otherwise basic auth is expected: both username and password must be present
    if USER_KEY not in plugin_config:
        raise ValueError("Invalid auth configuration (no username) for plugin: " + plugin_name)
    if PWD_KEY not in plugin_config:
        raise ValueError("Invalid auth configuration (no password for username) for plugin: " + plugin_name)


def get_supported_endpoint_config(pipeline_config: dict, section_key: str) -> tuple:
    """Return (plugin_name, plugin_config) for the first supported endpoint in a section.

    The section value may be a single plugin config (dict) or a list of plugin
    configs. Raises ValueError when no supported endpoint is found.
    """
    section = pipeline_config[section_key]
    endpoint_match = None
    if isinstance(section, dict):
        endpoint_match = __check_supported_endpoint(section)
    elif isinstance(section, list):
        for entry in section:
            endpoint_match = __check_supported_endpoint(entry)
            # Stop at the first supported plugin type
            if endpoint_match:
                break
    if not endpoint_match:
        raise ValueError("Could not find any supported endpoints in section: " + section_key)
    # First tuple value is the plugin name, second value is the plugin config dict
    return endpoint_match[0], endpoint_match[1]


def get_aws_sigv4_auth(region: str, is_serverless: bool = False) -> AWS4Auth:
    """Build an AWS4Auth signer for the given region from ambient session credentials.

    Raises ValueError when no AWS credentials can be resolved.
    """
    credentials = Session().get_credentials()
    if not credentials:
        raise ValueError("Unable to fetch AWS session credentials for SigV4 auth")
    # OpenSearch Serverless signs with the "aoss" service name instead of "es"
    service_name = AOSS_SERVICE_NAME if is_serverless else ES_SERVICE_NAME
    return AWS4Auth(region=region, service=service_name, refreshable_credentials=credentials)


def get_auth(plugin_config: dict) -> Optional[tuple] | AWS4Auth:
    """Return the auth object for a plugin config: basic-auth tuple, AWS4Auth, or None."""
    # Basic auth takes precedence when both credentials are present
    if USER_KEY in plugin_config and PWD_KEY in plugin_config:
        return plugin_config[USER_KEY], plugin_config[PWD_KEY]
    if plugin_config.get(AWS_SIGV4_KEY, False) or AWS_CONFIG_KEY in plugin_config:
        # OpenSearch Serverless uses a different service name
        aws_config = plugin_config.get(AWS_CONFIG_KEY)
        is_serverless = isinstance(aws_config, dict) and bool(aws_config.get(IS_SERVERLESS_KEY, False))
        region = get_aws_region(plugin_config)
        return get_aws_sigv4_auth(region, is_serverless)
    return None


def get_endpoint_info_from_plugin_config(plugin_config: dict) -> EndpointInfo:
    """Build an EndpointInfo from a plugin config's host, auth, and SSL settings."""
    url = __get_url(plugin_config)
    auth = get_auth(plugin_config)
    # verify boolean will be the inverse of the insecure SSL key, if present
    verify_ssl = not is_insecure(plugin_config)
    return EndpointInfo(url, auth, verify_ssl)


def get_endpoint_info_from_pipeline_config(pipeline_config: dict, section_key: str) -> EndpointInfo:
    """Extract and validate an EndpointInfo from a pipeline section ("source" or "sink").

    Raises ValueError when no supported endpoint is found, no hosts are
    defined, or the auth configuration is invalid.
    """
    # Raises a ValueError if no supported endpoints are found
    plugin_name, plugin_config = get_supported_endpoint_config(pipeline_config, section_key)
    if HOSTS_KEY not in plugin_config:
        raise ValueError("No hosts defined for plugin: " + plugin_name)
    # Raises a ValueError if there is an error in the auth configuration
    validate_auth(plugin_name, plugin_config)
    return get_endpoint_info_from_plugin_config(plugin_config)
11 changes: 6 additions & 5 deletions FetchMigration/python/fetch_orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import yaml

import endpoint_utils
import metadata_migration
import migration_monitor
from fetch_orchestrator_params import FetchOrchestratorParams
Expand Down Expand Up @@ -40,12 +41,12 @@ def update_target_host(dp_config: dict, target_host: str):
# We expect the Data Prepper pipeline to only have a single top-level value
pipeline_config = next(iter(dp_config.values()))
# The entire pipeline will be validated later
if metadata_migration.SINK_KEY in pipeline_config:
if endpoint_utils.SINK_KEY in pipeline_config:
# throws ValueError if no supported endpoints are found
plugin_name, plugin_config = metadata_migration.get_supported_endpoint(pipeline_config,
metadata_migration.SINK_KEY)
plugin_config[metadata_migration.HOSTS_KEY] = [target_with_protocol]
pipeline_config[metadata_migration.SINK_KEY] = [{plugin_name: plugin_config}]
plugin_name, plugin_config = endpoint_utils.get_supported_endpoint_config(pipeline_config,
endpoint_utils.SINK_KEY)
plugin_config[endpoint_utils.HOSTS_KEY] = [target_with_protocol]
pipeline_config[endpoint_utils.SINK_KEY] = [{plugin_name: plugin_config}]


def write_inline_pipeline(pipeline_file_path: str, inline_pipeline: str, inline_target_host: Optional[str]):
Expand Down
14 changes: 8 additions & 6 deletions FetchMigration/python/index_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@


def fetch_all_indices(endpoint: EndpointInfo) -> dict:
actual_endpoint = endpoint.url + __ALL_INDICES_ENDPOINT
resp = requests.get(actual_endpoint, auth=endpoint.auth, verify=endpoint.verify_ssl)
all_indices_url: str = endpoint.add_path(__ALL_INDICES_ENDPOINT)
resp = requests.get(all_indices_url, auth=endpoint.get_auth(), verify=endpoint.is_verify_ssl())
result = dict(resp.json())
for index in list(result.keys()):
# Remove system indices
Expand All @@ -31,19 +31,21 @@ def fetch_all_indices(endpoint: EndpointInfo) -> dict:

def create_indices(indices_data: dict, endpoint: EndpointInfo):
    """Create each index on the target endpoint with its settings and mappings.

    Raises RuntimeError when any index-creation request fails.
    """
    # NOTE: stale pre-refactor statements (endpoint.url / endpoint.auth) removed;
    # the EndpointInfo accessor API (add_path/get_auth/is_verify_ssl) is used throughout.
    for index in indices_data:
        index_endpoint = endpoint.add_path(index)
        # Only settings and mappings are carried over to the new index
        data_dict = dict()
        data_dict[SETTINGS_KEY] = indices_data[index][SETTINGS_KEY]
        data_dict[MAPPINGS_KEY] = indices_data[index][MAPPINGS_KEY]
        try:
            resp = requests.put(index_endpoint, auth=endpoint.get_auth(), verify=endpoint.is_verify_ssl(),
                                json=data_dict)
            resp.raise_for_status()
        except requests.exceptions.RequestException as e:
            raise RuntimeError(f"Failed to create index [{index}] - {e!s}")


def doc_count(indices: set, endpoint: EndpointInfo) -> int:
    """Return the total document count across the given indices on the endpoint."""
    # NOTE: stale pre-refactor statements (endpoint.url / endpoint.auth) removed;
    # the path is built via the EndpointInfo accessor API.
    # Comma-separated index list lets a single _count call cover all indices
    count_endpoint_suffix: str = ','.join(indices) + __COUNT_ENDPOINT
    doc_count_endpoint: str = endpoint.add_path(count_endpoint_suffix)
    resp = requests.get(doc_count_endpoint, auth=endpoint.get_auth(), verify=endpoint.is_verify_ssl())
    result = dict(resp.json())
    return int(result[COUNT_KEY])
Loading

0 comments on commit 8989f00

Please sign in to comment.