This repository has been archived by the owner on Apr 11, 2024. It is now read-only.
[Fetch Migration] Added total document count to report (opensearch-project#261)

This change allows the Index Configuration Tool report to print the total number of documents that will be migrated, based on the indices identified for creation. The change includes a new API call under index_operations.py and updated unit tests. This commit also includes some other minor changes:
* The output YAML file is now optional, allowing users to print a report without producing a YAML file. However, either --report or an output YAML file path is required; omitting both will result in a ValueError
* A new EndpointInfo dataclass has been introduced to encapsulate endpoint information (URL, auth, and SSL verification flag) for the source and target
* The default output of the tool now dumps the total document count followed by the list of indices to be created. Specify -r / --report to produce human-readable output. All other print statements have been removed.

---------

Signed-off-by: Kartik Ganesh <[email protected]>
kartg authored Aug 16, 2023
1 parent edfb16a commit 2c8e3ab
Showing 5 changed files with 121 additions and 59 deletions.
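As an illustration of the behavior described above, here is a minimal sketch of driving the tool programmatically, mirroring how the updated unit tests invoke it; the pipeline path is a placeholder, not a file from this commit:

import argparse

import main

# Mimic the Namespace that the CLI's argument parser would produce
args = argparse.Namespace()
args.config_file_path = "pipeline.yaml"  # placeholder Data Prepper pipeline YAML
args.output_file = ""                    # output YAML is now optional...
args.report = True                       # ...but then --report must be set, or run() raises ValueError
args.dryrun = True                       # skip index creation on the target cluster
main.run(args)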
8 changes: 8 additions & 0 deletions FetchMigration/index_configuration_tool/endpoint_info.py
@@ -0,0 +1,8 @@
from dataclasses import dataclass


@dataclass
class EndpointInfo:
url: str
auth: tuple = None
verify_ssl: bool = True
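A quick sketch of the new dataclass in use; the URLs and credentials here are illustrative:

from endpoint_info import EndpointInfo

# Anonymous endpoint: auth defaults to None, SSL verification to True
source = EndpointInfo("http://localhost:9200/")
# Authenticated endpoint with SSL verification disabled
target = EndpointInfo("https://target-cluster:9200/", ("admin", "admin"), False)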
28 changes: 18 additions & 10 deletions FetchMigration/index_configuration_tool/index_operations.py
@@ -1,19 +1,20 @@
import sys
from typing import Optional

import requests

from endpoint_info import EndpointInfo

# Constants
SETTINGS_KEY = "settings"
MAPPINGS_KEY = "mappings"
COUNT_KEY = "count"
__INDEX_KEY = "index"
__ALL_INDICES_ENDPOINT = "*"
__COUNT_ENDPOINT = "/_count"
__INTERNAL_SETTINGS_KEYS = ["creation_date", "uuid", "provided_name", "version", "store"]


def fetch_all_indices(endpoint: str, optional_auth: Optional[tuple] = None, verify: bool = True) -> dict:
actual_endpoint = endpoint + __ALL_INDICES_ENDPOINT
resp = requests.get(actual_endpoint, auth=optional_auth, verify=verify)
def fetch_all_indices(endpoint: EndpointInfo) -> dict:
actual_endpoint = endpoint.url + __ALL_INDICES_ENDPOINT
resp = requests.get(actual_endpoint, auth=endpoint.auth, verify=endpoint.verify_ssl)
# Remove internal settings
result = dict(resp.json())
for index in result:
@@ -24,14 +25,21 @@ def fetch_all_indices(endpoint: str, optional_auth: Optional[tuple] = None, verify: bool = True) -> dict:
return result
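The loop collapsed in the hunk above strips the keys listed in __INTERNAL_SETTINGS_KEYS from each index's settings; a usage sketch with an illustrative endpoint:

from endpoint_info import EndpointInfo
import index_operations

index_data = index_operations.fetch_all_indices(EndpointInfo("http://localhost:9200/"))
# Each entry keeps its "settings" and "mappings", with internal settings such as
# "creation_date", "uuid", "provided_name", "version" and "store" filtered out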


def create_indices(indices_data: dict, endpoint: str, auth_tuple: Optional[tuple]):
def create_indices(indices_data: dict, endpoint: EndpointInfo):
for index in indices_data:
actual_endpoint = endpoint + index
actual_endpoint = endpoint.url + index
data_dict = dict()
data_dict[SETTINGS_KEY] = indices_data[index][SETTINGS_KEY]
data_dict[MAPPINGS_KEY] = indices_data[index][MAPPINGS_KEY]
try:
resp = requests.put(actual_endpoint, auth=auth_tuple, json=data_dict)
resp = requests.put(actual_endpoint, auth=endpoint.auth, verify=endpoint.verify_ssl, json=data_dict)
resp.raise_for_status()
except requests.exceptions.RequestException as e:
print(f"Failed to create index [{index}] - {e!s}", file=sys.stderr)
raise RuntimeError(f"Failed to create index [{index}] - {e!s}")


def doc_count(indices: set, endpoint: EndpointInfo) -> int:
actual_endpoint = endpoint.url + ','.join(indices) + __COUNT_ENDPOINT
resp = requests.get(actual_endpoint, auth=endpoint.auth, verify=endpoint.verify_ssl)
result = dict(resp.json())
return int(result[COUNT_KEY])
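The new doc_count helper targets the standard _count API across a comma-separated list of indices; a usage sketch, again with an illustrative endpoint:

from endpoint_info import EndpointInfo
import index_operations

# Issues GET http://localhost:9200/index1,index2/_count and parses the "count" field
total = index_operations.doc_count({"index1", "index2"}, EndpointInfo("http://localhost:9200/"))
print(total)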
72 changes: 44 additions & 28 deletions FetchMigration/index_configuration_tool/main.py
@@ -6,6 +6,8 @@
import utils

# Constants
from endpoint_info import EndpointInfo

SUPPORTED_ENDPOINTS = ["opensearch", "elasticsearch"]
SOURCE_KEY = "source"
SINK_KEY = "sink"
@@ -36,19 +38,14 @@ def get_auth(input_data: dict) -> Optional[tuple]:
return input_data[USER_KEY], input_data[PWD_KEY]


def get_endpoint_info(plugin_config: dict) -> tuple:
def get_endpoint_info(plugin_config: dict) -> EndpointInfo:
# "hosts" can be a simple string, or an array of hosts for Logstash to hit.
# This tool needs one accessible host, so we pick the first entry in the latter case.
endpoint = plugin_config[HOSTS_KEY][0] if type(plugin_config[HOSTS_KEY]) is list else plugin_config[HOSTS_KEY]
endpoint += "/"
return endpoint, get_auth(plugin_config)


def fetch_all_indices_by_plugin(plugin_config: dict) -> dict:
endpoint, auth_tuple = get_endpoint_info(plugin_config)
url = plugin_config[HOSTS_KEY][0] if type(plugin_config[HOSTS_KEY]) is list else plugin_config[HOSTS_KEY]
url += "/"
# verify boolean will be the inverse of the insecure SSL key, if present
should_verify = not is_insecure(plugin_config)
return index_operations.fetch_all_indices(endpoint, auth_tuple, should_verify)
return EndpointInfo(url, get_auth(plugin_config), should_verify)
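A sketch of the refactored get_endpoint_info; the literal config keys "hosts" and "insecure" are assumptions based on the constants referenced here, since their definitions sit outside this diff:

import main

plugin_config = {"hosts": ["http://localhost:9200"], "insecure": True}  # assumed key names
info = main.get_endpoint_info(plugin_config)
# info.url == "http://localhost:9200/", info.auth is None, info.verify_ssl is False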


def check_supported_endpoint(config: dict) -> Optional[tuple]:
@@ -112,7 +109,6 @@ def write_output(yaml_data: dict, new_indices: set, output_path: str):
source_config[INDICES_KEY] = source_indices
with open(output_path, 'w') as out_file:
yaml.dump(yaml_data, out_file)
print("Wrote output YAML pipeline to: " + output_path)


# Computes differences in indices between source and target.
@@ -138,44 +134,64 @@ def get_index_differences(source: dict, target: dict) -> tuple[set, set, set]:

# The order of data in the tuple is:
# (indices to create), (identical indices), (indices with conflicts)
def print_report(index_differences: tuple[set, set, set]): # pragma no cover
def print_report(index_differences: tuple[set, set, set], count: int): # pragma no cover
print("Identical indices in the target cluster (no changes will be made): " +
utils.string_from_set(index_differences[1]))
print("Indices in target cluster with conflicting settings/mappings: " +
utils.string_from_set(index_differences[2]))
print("Indices to create: " + utils.string_from_set(index_differences[0]))
print("Total documents to be moved: " + str(count))


def dump_count_and_indices(count: int, indices: set): # pragma no cover
print(count)
for index_name in indices:
print(index_name)
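With --report omitted, this machine-readable default output would look like the following; the count and index names are illustrative:

42
index-a
index-b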


def compute_endpoint_and_fetch_indices(config: dict, key: str) -> tuple[EndpointInfo, dict]:
endpoint = get_supported_endpoint(config, key)
# Endpoint is a tuple of (type, config)
endpoint_info = get_endpoint_info(endpoint[1])
indices = index_operations.fetch_all_indices(endpoint_info)
return endpoint_info, indices


def run(args: argparse.Namespace) -> None:
# Sanity check
if not args.report and len(args.output_file) == 0:
raise ValueError("No output file specified")
# Parse and validate pipelines YAML file
with open(args.config_file_path, 'r') as pipeline_file:
dp_config = yaml.safe_load(pipeline_file)
# We expect the Data Prepper pipeline to only have a single top-level value
pipeline_config = next(iter(dp_config.values()))
validate_pipeline_config(pipeline_config)
# Endpoint is a tuple of (type, config)
endpoint = get_supported_endpoint(pipeline_config, SOURCE_KEY)
# Fetch all indices from source cluster
source_indices = fetch_all_indices_by_plugin(endpoint[1])
# Fetch all indices from target cluster
# TODO Refactor this to avoid duplication with fetch_all_indices_by_plugin
endpoint = get_supported_endpoint(pipeline_config, SINK_KEY)
target_endpoint, target_auth = get_endpoint_info(endpoint[1])
target_indices = index_operations.fetch_all_indices(target_endpoint, target_auth)
# Fetch EndpointInfo and indices
source_endpoint_info, source_indices = compute_endpoint_and_fetch_indices(pipeline_config, SOURCE_KEY)
target_endpoint_info, target_indices = compute_endpoint_and_fetch_indices(pipeline_config, SINK_KEY)
# Compute index differences and print report
diff = get_index_differences(source_indices, target_indices)
if args.report:
print_report(diff)
# The first element in the tuple is the set of indices to create
indices_to_create = diff[0]
doc_count = 0
if indices_to_create:
doc_count = index_operations.doc_count(indices_to_create, source_endpoint_info)
if args.report:
print_report(diff, doc_count)
if indices_to_create:
if not args.report:
dump_count_and_indices(doc_count, indices_to_create)
# Write output YAML
write_output(dp_config, indices_to_create, args.output_file)
if len(args.output_file) > 0:
write_output(dp_config, indices_to_create, args.output_file)
if args.report:
print("Wrote output YAML pipeline to: " + args.output_file)
if not args.dryrun:
index_data = dict()
for index_name in indices_to_create:
index_data[index_name] = source_indices[index_name]
index_operations.create_indices(index_data, target_endpoint, target_auth)
index_operations.create_indices(index_data, target_endpoint_info)


if __name__ == '__main__': # pragma no cover
@@ -191,20 +207,20 @@ def run(args: argparse.Namespace) -> None:
"along with indices that are identical or have conflicting settings/mappings.",
formatter_class=argparse.RawTextHelpFormatter
)
# Positional, required arguments
# Required positional argument
arg_parser.add_argument(
"config_file_path",
help="Path to the Data Prepper pipeline YAML file to parse for source and target endpoint information"
)
# Optional positional argument
arg_parser.add_argument(
"output_file",
nargs='?', default="",
help="Output path for the Data Prepper pipeline YAML file that will be generated"
)
# Optional arguments
# Flags
arg_parser.add_argument("--report", "-r", action="store_true",
help="Print a report of the index differences")
arg_parser.add_argument("--dryrun", action="store_true",
help="Skips the actual creation of indices on the target cluster")
print("\n##### Starting index configuration tool... #####\n")
run(arg_parser.parse_args())
print("\n##### Index configuration tool has completed! #####\n")
18 changes: 15 additions & 3 deletions FetchMigration/index_configuration_tool/tests/test_index_operations.py
@@ -6,6 +6,7 @@
from responses import matchers

import index_operations
from endpoint_info import EndpointInfo
from tests import test_constants


@@ -15,7 +16,7 @@ def test_fetch_all_indices(self):
# Set up GET response
responses.get(test_constants.SOURCE_ENDPOINT + "*", json=test_constants.BASE_INDICES_DATA)
# Now send request
index_data = index_operations.fetch_all_indices(test_constants.SOURCE_ENDPOINT)
index_data = index_operations.fetch_all_indices(EndpointInfo(test_constants.SOURCE_ENDPOINT))
self.assertEqual(3, len(index_data.keys()))
# Test that internal data has been filtered, but non-internal data is retained
index_settings = index_data[test_constants.INDEX1_NAME][test_constants.SETTINGS_KEY]
@@ -33,7 +34,7 @@ def test_create_indices(self):
match=[matchers.json_params_matcher(test_data[test_constants.INDEX2_NAME])])
responses.put(test_constants.TARGET_ENDPOINT + test_constants.INDEX3_NAME,
match=[matchers.json_params_matcher(test_data[test_constants.INDEX3_NAME])])
index_operations.create_indices(test_data, test_constants.TARGET_ENDPOINT, None)
index_operations.create_indices(test_data, EndpointInfo(test_constants.TARGET_ENDPOINT))

@responses.activate
def test_create_indices_exception(self):
@@ -43,7 +44,18 @@ def test_create_indices_exception(self):
del test_data[test_constants.INDEX3_NAME]
responses.put(test_constants.TARGET_ENDPOINT + test_constants.INDEX1_NAME,
body=requests.Timeout())
index_operations.create_indices(test_data, test_constants.TARGET_ENDPOINT, None)
self.assertRaises(RuntimeError, index_operations.create_indices, test_data,
EndpointInfo(test_constants.TARGET_ENDPOINT))

@responses.activate
def test_doc_count(self):
test_indices = {test_constants.INDEX1_NAME, test_constants.INDEX2_NAME}
expected_count_endpoint = test_constants.SOURCE_ENDPOINT + ",".join(test_indices) + "/_count"
mock_count_response = {"count": "10"}
responses.get(expected_count_endpoint, json=mock_count_response)
# Now send request
count_value = index_operations.doc_count(test_indices, EndpointInfo(test_constants.SOURCE_ENDPOINT))
self.assertEqual(10, count_value)


if __name__ == '__main__':
54 changes: 36 additions & 18 deletions FetchMigration/index_configuration_tool/tests/test_main.py
@@ -88,25 +88,26 @@ def test_get_endpoint_info(self):
# Simple base case
test_config = create_plugin_config([host_input])
result = main.get_endpoint_info(test_config)
self.assertEqual(expected_endpoint, result[0])
self.assertIsNone(result[1])
self.assertEqual(expected_endpoint, result.url)
self.assertIsNone(result.auth)
self.assertTrue(result.verify_ssl)
# Invalid auth config
test_config = create_plugin_config([host_input], test_user)
result = main.get_endpoint_info(test_config)
self.assertEqual(expected_endpoint, result[0])
self.assertIsNone(result[1])
self.assertEqual(expected_endpoint, result.url)
self.assertIsNone(result.auth)
# Valid auth config
test_config = create_plugin_config([host_input], user=test_user, password=test_password)
result = main.get_endpoint_info(test_config)
self.assertEqual(expected_endpoint, result[0])
self.assertEqual(test_user, result[1][0])
self.assertEqual(test_password, result[1][1])
self.assertEqual(expected_endpoint, result.url)
self.assertEqual(test_user, result.auth[0])
self.assertEqual(test_password, result.auth[1])
# Array of hosts uses the first entry
test_config = create_plugin_config([host_input, "other_host"], test_user, test_password)
result = main.get_endpoint_info(test_config)
self.assertEqual(expected_endpoint, result[0])
self.assertEqual(test_user, result[1][0])
self.assertEqual(test_password, result[1][1])
self.assertEqual(expected_endpoint, result.url)
self.assertEqual(test_user, result.auth[0])
self.assertEqual(test_password, result.auth[1])

def test_get_index_differences_empty(self):
# Base case should return an empty list
@@ -225,17 +226,18 @@ def test_validate_pipeline_config_happy_case(self):
test_config = next(iter(self.loaded_pipeline_config.values()))
main.validate_pipeline_config(test_config)

@patch('index_operations.doc_count')
@patch('main.write_output')
@patch('main.print_report')
@patch('index_operations.create_indices')
@patch('index_operations.fetch_all_indices')
# Note that mock objects are passed bottom-up from the patch order above
def test_run_report(self, mock_fetch_indices: MagicMock, mock_create_indices: MagicMock,
mock_print_report: MagicMock, mock_write_output: MagicMock):
mock_print_report: MagicMock, mock_write_output: MagicMock, mock_doc_count: MagicMock):
mock_doc_count.return_value = 1
index_to_create = test_constants.INDEX3_NAME
index_with_conflict = test_constants.INDEX2_NAME
index_exact_match = test_constants.INDEX1_NAME
expected_output_path = "dummy"
# Set up expected arguments to mocks so we can verify
expected_create_payload = {index_to_create: test_constants.BASE_INDICES_DATA[index_to_create]}
# Print report accepts a tuple. The elements of the tuple
@@ -252,21 +254,26 @@ def test_run_report(self, mock_fetch_indices: MagicMock, mock_create_indices: MagicMock,
# Set up test input
test_input = argparse.Namespace()
test_input.config_file_path = test_constants.PIPELINE_CONFIG_RAW_FILE_PATH
test_input.output_file = expected_output_path
# Default value for missing output file
test_input.output_file = ""
test_input.report = True
test_input.dryrun = False
main.run(test_input)
mock_create_indices.assert_called_once_with(expected_create_payload, test_constants.TARGET_ENDPOINT, ANY)
mock_print_report.assert_called_once_with(expected_diff)
mock_write_output.assert_called_once_with(self.loaded_pipeline_config, {index_to_create}, expected_output_path)
mock_create_indices.assert_called_once_with(expected_create_payload, ANY)
mock_doc_count.assert_called()
mock_print_report.assert_called_once_with(expected_diff, 1)
mock_write_output.assert_not_called()

@patch('index_operations.doc_count')
@patch('main.dump_count_and_indices')
@patch('main.print_report')
@patch('main.write_output')
@patch('index_operations.fetch_all_indices')
# Note that mock objects are passed bottom-up from the patch order above
def test_run_dryrun(self, mock_fetch_indices: MagicMock, mock_write_output: MagicMock,
mock_print_report: MagicMock):
mock_print_report: MagicMock, mock_dump: MagicMock, mock_doc_count: MagicMock):
index_to_create = test_constants.INDEX1_NAME
mock_doc_count.return_value = 1
expected_output_path = "dummy"
# Create mock data for indices on target
target_indices_data = copy.deepcopy(test_constants.BASE_INDICES_DATA)
@@ -281,8 +288,10 @@ def test_run_dryrun(self, mock_fetch_indices: MagicMock, mock_write_output: MagicMock,
test_input.report = False
main.run(test_input)
mock_write_output.assert_called_once_with(self.loaded_pipeline_config, {index_to_create}, expected_output_path)
# Report should not be printed
mock_doc_count.assert_called()
# Report should not be printed, but dump should be invoked
mock_print_report.assert_not_called()
mock_dump.assert_called_once_with(mock_doc_count.return_value, {index_to_create})

@patch('yaml.dump')
def test_write_output(self, mock_dump: MagicMock):
@@ -311,6 +320,15 @@ def test_write_output(self, mock_dump: MagicMock):
mock_open.assert_called_once_with(expected_output_path, 'w')
mock_dump.assert_called_once_with(expected_output_data, ANY)

def test_missing_output_file_non_report(self):
# Set up test input
test_input = argparse.Namespace()
test_input.config_file_path = test_constants.PIPELINE_CONFIG_RAW_FILE_PATH
# Default value for missing output file
test_input.output_file = ""
test_input.report = False
self.assertRaises(ValueError, main.run, test_input)


if __name__ == '__main__':
unittest.main()
