diff --git a/deltacat/utils/ray_utils/retry_handler/AIMD_based_batch_scaling_strategy.py b/deltacat/utils/ray_utils/retry_handler/AIMD_based_batch_scaling_strategy.py new file mode 100644 index 00000000..2b46f908 --- /dev/null +++ b/deltacat/utils/ray_utils/retry_handler/AIMD_based_batch_scaling_strategy.py @@ -0,0 +1,50 @@ +from typing import List, Any +from deltacat.utils.ray_utils.retry_handler.batch_scaling_interface import BatchScalingInterface +from deltacat.utils.ray_utils.retry_handler.task_info_object import TaskInfoObject +class AIMDBasedBatchScalingStrategy(BatchScalingInterface): + """ + Default batch scaling parameters for if the client does not provide their own batch_scaling parameters + """ + def __init__(self, additive_increase: int, multiplicative_decrease: float): + self.task_infos = [] + self.batch_index = 0 + self.batch_size = None + self.max_batch_size = None + self.min_batch_size = None + self.additive_increase = additive_increase + self.multiplicative_decrease = multiplicative_decrease + def init_tasks(self, initial_batch_size: int, max_batch_size: int, min_batch_size: int, task_infos: List[TaskInfoObject])-> None: + """ + Setup AIMD scaling for the batches as the default + """ + self.task_infos = task_infos + self.batch_size = initial_batch_size + self.max_batch_size = max_batch_size + self.min_batch_size = min_batch_size + + + def has_next_batch(self) -> bool: + """ + Returns True while there are tasks that have not yet been handed out in a batch + """ + return self.batch_index < len(self.task_infos) + + + def next_batch(self) -> List[TaskInfoObject]: + """ + Returns the next batch of at most batch_size tasks and advances the batch index past them + """ + batch_end = min(self.batch_index + self.batch_size, len(self.task_infos)) + batch = self.task_infos[self.batch_index:batch_end] + self.batch_index = batch_end + return batch + + def mark_task_complete(self, task_info: TaskInfoObject): + task_info.completed = True + + def increase_batch_size(self): + self.batch_size = min(self.batch_size + self.additive_increase,
self.max_batch_size) + + + def decrease_batch_size(self): + self.batch_size = max(int(self.batch_size * self.multiplicative_decrease), self.min_batch_size) \ No newline at end of file diff --git a/deltacat/utils/ray_utils/retry_handler/README.md b/deltacat/utils/ray_utils/retry_handler/README.md new file mode 100644 index 00000000..466d9d6d --- /dev/null +++ b/deltacat/utils/ray_utils/retry_handler/README.md @@ -0,0 +1,27 @@ +This module represents a straggler detection and retry handler framework + +Within retry_strategy_config.py, the client can provide 3 parameters to start_tasks_execution to perform retries and detect stragglers +Params: +1. ray_remote_task_info: A list of Ray task objects +2. scaling_strategy: Batch scaling parameters for how many tasks to execute per batch (Optional) + a. If not provided, a default AIMD (additive increase, multiplicative decrease) strategy will be assigned for retries +3. straggler_detection: Client-provided class that holds logic for how they want to detect straggler tasks (Optional) + a. Client algorithm must inherit the interface for detection which will be used in wait_and_get_results + +Use cases: +1. Notifying progress + This will be done through ProgressNotifierInterface. The client can implement has_heartbeat and send_heartbeat from the interface + to receive updates on task level progress. This can be an SNSQueue or any type of indicator the client may choose. +2. Detecting stragglers + Given the straggler detection algorithm implemented by StragglerDetectionInterface, the method is_straggler will inform + the customer if the current node is a straggler according to their own logic. In order to make their decision, we will provide them + with TaskContext that contains fields and data that the client can use to decide if a task is a straggler or not. +3.
Retrying retryable exceptions + Within the failures directory, there are common errors that are retryable and when detected as an instance + of the retryable class, will cause the task to be retried when the exception is caught. If the client would like + to create their own exceptions to be handled, they can create a class that is an extension of retryable_error or + non_retryable_error and the framework should handle it based on the configuration strategy. + + + + diff --git a/deltacat/utils/ray_utils/retry_handler/batch_scaling_interface.py b/deltacat/utils/ray_utils/retry_handler/batch_scaling_interface.py new file mode 100644 index 00000000..dca5c339 --- /dev/null +++ b/deltacat/utils/ray_utils/retry_handler/batch_scaling_interface.py @@ -0,0 +1,42 @@ +from typing import List, Any, Protocol +from deltacat.utils.ray_utils.retry_handler.task_info_object import TaskInfoObject +class BatchScalingInterface(Protocol): + """ + Interface for a generic batch scaling that the client can provide.
+ """ + def init_tasks(self, initial_batch_size: int, max_batch_size: int, min_batch_size: int, task_infos: List[TaskInfoObject]) -> None: + """ + Loads all tasks to be executed for retry batching + """ + pass + def has_next_batch(self) -> bool: + """ + Returns true if there are tasks remaining in the overall List of tasks to create a new batch + """ + pass + def next_batch(self) -> List[TaskInfoObject]: + """ + Gets the next batch of tasks to execute on + """ + pass + def mark_task_complete(self, task_info: TaskInfoObject) -> None: + """ + If the task has been completed, mark some field of it as true + so we know what tasks are completed and what need to be executed + """ + pass + + def increase_batch_size(self) -> None: + """ + Increase the batch size by some amount according to client specifications + :return: + """ + pass + + def decrease_batch_size(self) -> None: + """ + Decrease the batch size by some amount according to client specifications + :return: + """ + pass + diff --git a/deltacat/utils/ray_utils/retry_handler/exception_util.py b/deltacat/utils/ray_utils/retry_handler/exception_util.py new file mode 100644 index 00000000..53087187 --- /dev/null +++ b/deltacat/utils/ray_utils/retry_handler/exception_util.py @@ -0,0 +1,16 @@ +from typing import List, Optional +from ray_manager.models.ray_remote_task_exception_retry_strategy_config import RayRemoteTaskExceptionRetryConfig +def get_retry_strategy_config_for_known_exception(exception: Exception, + exception_retry_strategy_configs: List[RayRemoteTaskExceptionRetryConfig]) -> Optional[RayRemoteTaskExceptionRetryConfig]: + """ + Checks whether the exception seen is recognized as a retryable error or not + """ + for exception_retry_strategy_config in exception_retry_strategy_configs: + if type(exception) == type(exception_retry_strategy_config.exception): + return exception_retry_strategy_config + + for exception_retry_strategy_config in exception_retry_strategy_configs: + if isinstance(exception,
type(exception_retry_strategy_config.exception)): + return exception_retry_strategy_config + + return None \ No newline at end of file diff --git a/deltacat/utils/ray_utils/retry_handler/failures/aws_security_token_rate_exceeded_exception.py b/deltacat/utils/ray_utils/retry_handler/failures/aws_security_token_rate_exceeded_exception.py new file mode 100644 index 00000000..f64978c2 --- /dev/null +++ b/deltacat/utils/ray_utils/retry_handler/failures/aws_security_token_rate_exceeded_exception.py @@ -0,0 +1,6 @@ +from deltacat.utils.ray_utils.retry_handler.failures.retryable_error import RetryableError + +class AWSSecurityTokenRateExceededException(RetryableError): + + def __init__(self, *args: object) -> None: + super().__init__(*args) diff --git a/deltacat/utils/ray_utils/retry_handler/failures/cairns_client_exception.py b/deltacat/utils/ray_utils/retry_handler/failures/cairns_client_exception.py new file mode 100644 index 00000000..42e6a08b --- /dev/null +++ b/deltacat/utils/ray_utils/retry_handler/failures/cairns_client_exception.py @@ -0,0 +1,6 @@ +from deltacat.utils.ray_utils.retry_handler.failures.retryable_error import RetryableError + +class CairnsClientException(RetryableError): + + def __init__(self, *args: object) -> None: + super().__init__(*args) diff --git a/deltacat/utils/ray_utils/retry_handler/failures/non_retryable_error.py b/deltacat/utils/ray_utils/retry_handler/failures/non_retryable_error.py new file mode 100644 index 00000000..636fc8c1 --- /dev/null +++ b/deltacat/utils/ray_utils/retry_handler/failures/non_retryable_error.py @@ -0,0 +1,6 @@ +class NonRetryableError(Exception): + """ + Class represents a non-retryable error + """ + def __init__(self, *args: object): + super().__init__(*args) diff --git a/deltacat/utils/ray_utils/retry_handler/failures/retryable_error.py b/deltacat/utils/ray_utils/retry_handler/failures/retryable_error.py new file mode 100644 index 00000000..01124240 --- /dev/null +++
b/deltacat/utils/ray_utils/retry_handler/failures/retryable_error.py @@ -0,0 +1,6 @@ +class RetryableError(Exception): + """ + Class for errors that can be retried + """ + def __init__(self, *args: object) -> None: + super().__init__(*args) diff --git a/deltacat/utils/ray_utils/retry_handler/progress_notifier_interface.py b/deltacat/utils/ray_utils/retry_handler/progress_notifier_interface.py new file mode 100644 index 00000000..fac7c81d --- /dev/null +++ b/deltacat/utils/ray_utils/retry_handler/progress_notifier_interface.py @@ -0,0 +1,17 @@ +from typing import List, Protocol +from deltacat.utils.ray_utils.retry_handler.task_info_object import TaskInfoObject +class ProgressNotifierInterface(Protocol): + """ + Interface for client injected progress notification system. + """ + def has_heartbeat(self, task_info: TaskInfoObject) -> bool: + """ + Tells parent task if the current task has a heartbeat or not + """ + pass + def send_heartbeat(self, parent_task_info: TaskInfoObject) -> bool: + """ + Sends progress of current task to parent task + """ + pass + diff --git a/deltacat/utils/ray_utils/retry_handler/ray_task_submission_handler.py b/deltacat/utils/ray_utils/retry_handler/ray_task_submission_handler.py new file mode 100644 index 00000000..a03e6102 --- /dev/null +++ b/deltacat/utils/ray_utils/retry_handler/ray_task_submission_handler.py @@ -0,0 +1,196 @@ +from __future__ import annotations +from typing import Any, Dict, List, cast, Optional +from deltacat.utils.ray_utils.retry_handler.ray_remote_tasks_batch_scaling_strategy import RayRemoteTasksBatchScalingStrategy +import ray +import time +import logging +from deltacat.logs import configure_logger +from deltacat.utils.ray_utils.retry_handler.task_execution_error import RayRemoteTaskExecutionError +from deltacat.utils.ray_utils.retry_handler.task_info_object import TaskInfoObject +from deltacat.utils.ray_utils.retry_handler.retry_strategy_config import
get_retry_strategy_config_for_known_exception + +logger = configure_logger(logging.getLogger(__name__)) + +@ray.remote +def submit_single_task(taskObj: TaskInfoObject, TaskContext: Optional[Interface] = None) -> Any: + """ + Submits a single task for execution, handles any exceptions that may occur during execution, + and applies appropriate retry strategies if they are defined. + """ + try: + taskObj.attempt_count += 1 + current_attempt_number = taskObj.attempt_count + logger.debug(f"Executing the submitted Ray remote task as part of attempt number: {current_attempt_number}") + return taskObj.task_callable(taskObj.task_input) + except (Exception) as exception: + exception_retry_strategy_config = get_retry_strategy_config_for_known_exception(exception, taskObj.exception_retry_strategy_configs) + if exception_retry_strategy_config is not None: + return RayRemoteTaskExecutionError(exception_retry_strategy_config.exception, taskObj) + + logger.error(f"The exception thrown by submitted Ray task during attempt number: {current_attempt_number} is non-retryable or unexpected, hence throwing Non retryable exception: {exception}") + raise UnexpectedRayTaskError(str(exception)) + +class RayTaskSubmissionHandler: + """ + Starts execution of all given a list of Ray tasks with optional arguments: scaling strategy and straggler detection + """ + def start_tasks_execution(self, + ray_remote_task_infos: List[TaskInfoObject], + scaling_strategy: Optional[BatchScalingStrategy] = None, + straggler_detection: Optional[StragglerDetectionInterface] = None, + retry_strategy: Optional[RetryTaskInterface] = None, + task_context: Optional[TaskContext] = None) -> None: + """ + Prepares and initiates the execution of a batch of tasks and can optionally support + custom client batch scaling, straggler detection, and task context + """ + if scaling_strategy is None: + scaling_strategy = AIMDBasedBatchScalingStrategy(ray_remote_task_infos) + if retry_strategy is None: + retry_strategy = RetryTaskDefault(max_retries =
3) + + active_tasks = [] + + while scaling_strategy.has_next_batch(): + current_batch = scaling_strategy.next_batch() + for task in current_batch: + try: + self._submit_tasks([task]) + active_tasks.append(task) + except Exception as e: + if retry_strategy.should_retry(task, e): + retry_strategy.retry(task, e) + continue + else: + raise #? not sure what to do if the error isnt retryable + completed_tasks = self._wait_and_get_all_task_results(active_tasks) + + for task in completed_tasks: + scaling_strategy.mark_task_complete(task) + active_tasks.remove(task) + + if all(task.completed for task in current_batch): + scaling_strategy.increase_batch_size() + else: + scaling_strategy.decrease_batch_size() + + #handle strags + if straggler_detection is not None: + for task in list(active_tasks): #tasks that are still running; iterate a copy because stragglers are removed below + if straggler_detection.is_straggler(task, task_context): + ray.cancel(task) + active_tasks.remove(task) + #maybe we need to requeue the cancelled task? can add back to ray_remote_task_infos + + + #call wait_and_get_all ... + #when ray returns results mark as completed --> to mark as completed we want to give a bool field to the task info object and set to true, when gets marked to true + #if success, additive increase method to batchScaling + #if failure, MD on the batch size and continue until nothing remains + #check at least 1 is completed from current batch + #mark task as completed + + #wait some time period here ?
--> call to _wait_and_get_all_task_results so there is a period to collect completed tasks + #use result of wait and remove from active_tasks because it is completed + #use results of completed promises compared to total tasks in batch to determine batch scaling increase or decrease + + + def _wait_and_get_all_task_results(self, straggler_detection: Optional[StragglerDetectionInterface]) -> List[Any]: + return self._get_task_results(self.num_of_submitted_tasks, straggler_detection) + + def _get_task_results(self, num_of_results: int, straggler_detection: Optional[StragglerDetectionInterface]) -> List[Any]: + """ + Gets results from a list of tasks to be executed, and catches exceptions to manage the retry strategy. + Optional: Given a StragglerDetectionInterface, can detect and handle straggler tasks according to the client logic + """ + if not self.unfinished_promises or num_of_results == 0: + return [] + elif num_of_results > len(self.unfinished_promises): + num_of_results = len(self.unfinished_promises) + + finished, self.unfinished_promises = ray.wait(self.unfinished_promises, num_returns=num_of_results) + successful_results = [] + + for finished_promise in finished: + finished_result = None + try: + finished_result = ray.get(finished_promise) + except (Exception) as exception: + #if exception send to method handle_ray_exception to determine what to do and assign the corresp error + finished_result = self._handle_ray_exception(exception=exception, ray_remote_task_info=self.task_promise_obj_ref_to_task_info_map[str(finished_promise)] )#evaluate the exception and return the error + + if finished_result and type(finished_result) == RayRemoteTaskExecutionError: + finished_result = cast(RayRemoteTaskExecutionError, finished_result) + + if straggler_detection and straggler_detection.isStraggler(finished_result): + ray.cancel(finished_result) + exception_retry_strategy_config = get_retry_strategy_config_for_known_exception(finished_result.exception, +
finished_result.ray_remote_task_info.exception_retry_strategy_configs) + if (exception_retry_strategy_config is None or finished_result.ray_remote_task_info.num_of_attempts > exception_retry_strategy_config.max_retry_attempts): + logger.error(f"The submitted task has exhausted all the maximum retries configured and finally throws exception - {finished_result.exception}") + raise finished_result.exception + self._update_ray_remote_task_options_on_exception(finished_result.exception, finished_result.ray_remote_task_info) + self.unfinished_promises.append(self._invoke_ray_remote_task(ray_remote_task_info=finished_result.ray_remote_task_info)) + else: + successful_results.append(finished_result) + del self.task_promise_obj_ref_to_task_info_map[str(finished_promise)] + + num_of_successful_results = len(successful_results) + self.num_of_submitted_tasks_completed += num_of_successful_results + self.current_batch_size -= num_of_successful_results + + self._enqueue_new_tasks(num_of_successful_results) + + if num_of_successful_results < num_of_results: + successful_results.extend(self._get_task_results(num_of_results - num_of_successful_results, straggler_detection)) + return successful_results + else: + return successful_results + + + def _enqueue_new_tasks(self, num_of_tasks: int) -> None: + """ + Helper method to submit a specified number of tasks + """ + new_tasks_submitted = self.remaining_ray_remote_task_infos[:num_of_tasks] + num_of_new_tasks_submitted = len(new_tasks_submitted) + self._submit_tasks(new_tasks_submitted) + self.remaining_ray_remote_task_infos = self.remaining_ray_remote_task_infos[num_of_tasks:] + self.current_batch_size += num_of_new_tasks_submitted + logger.info(f"Enqueued {num_of_new_tasks_submitted} new tasks.
Current concurrency of tasks execution: {self.current_batch_size}, Current Task progress: {self.num_of_submitted_tasks_completed}/{self.num_of_submitted_tasks}") + + def _submit_tasks(self, info_objs: List[TaskInfoObject]) -> None: + for info_obj in info_objs: + time.sleep(0.005) + self.unfinished_promises.append(self._invoke_ray_remote_task(info_obj)) + #replace with ray.options + def _invoke_ray_remote_task(self, ray_remote_task_info: RayRemoteTaskInfo) -> Any: + #change to using ray.options + ray_remote_task_options_arguments = dict() + + if ray_remote_task_info.ray_remote_task_options.memory: + ray_remote_task_options_arguments['memory'] = ray_remote_task_info.ray_remote_task_options.memory + + if ray_remote_task_info.ray_remote_task_options.num_cpus: + ray_remote_task_options_arguments['num_cpus'] = ray_remote_task_info.ray_remote_task_options.num_cpus + + if ray_remote_task_info.ray_remote_task_options.placement_group: + ray_remote_task_options_arguments['placement_group'] = ray_remote_task_info.ray_remote_task_options.placement_group + + ray_remote_task_promise_obj_ref = submit_single_task.options(**ray_remote_task_options_arguments).remote(ray_remote_task_info) + self.task_promise_obj_ref_to_task_info_map[str(ray_remote_task_promise_obj_ref)] = ray_remote_task_info + + return ray_remote_task_promise_obj_ref + + #replace with ray.options + def _update_ray_remote_task_options_on_exception(self, exception: Exception, ray_remote_task_info: RayRemoteTaskInfo): + exception_retry_strategy_config = get_retry_strategy_config_for_known_exception(exception, ray_remote_task_info.exception_retry_strategy_configs) + if exception_retry_strategy_config and ray_remote_task_info.ray_remote_task_options.memory: + logger.info(f"Updating the Ray remote task options after encountering exception: {exception}") + ray_remote_task_memory_multiply_factor = exception_retry_strategy_config.ray_remote_task_memory_multiply_factor +
ray_remote_task_info.ray_remote_task_options.memory *= ray_remote_task_memory_multiply_factor + logger.info(f"Updated ray remote task options Memory: {ray_remote_task_info.ray_remote_task_options.memory}") + #replace with own exceptions + def _handle_ray_exception(self, exception: Exception, ray_remote_task_info: RayRemoteTaskInfo) -> RayRemoteTaskExecutionError: + logger.error(f"Ray remote task failed with {type(exception)} Ray exception: {exception}") + if type(exception).__name__ == "AWSSecurityTokenRateExceededException(RetryableError)" \ No newline at end of file diff --git a/deltacat/utils/ray_utils/retry_handler/retry_task_default.py b/deltacat/utils/ray_utils/retry_handler/retry_task_default.py new file mode 100644 index 00000000..6c2bdf5c --- /dev/null +++ b/deltacat/utils/ray_utils/retry_handler/retry_task_default.py @@ -0,0 +1,32 @@ +from typing import List, Protocol +from deltacat.utils.ray_utils.retry_handler.task_info_object import TaskInfoObject +from deltacat.utils.ray_utils.retry_handler.retry_task_interface import RetryTaskInterface +from deltacat.utils.ray_utils.retry_handler.failures.retryable_error import RetryableError +import time + +class RetryTaskDefault(RetryTaskInterface): + def __init__(self, max_retries: int): + self.max_retries = max_retries + def should_retry(self, task: TaskInfoObject, exception: Exception) -> bool: + """ + Given a task, determine whether it should be retried or not based on if its an instance of the RetryableError + """ + return isinstance(exception, RetryableError) + + + def get_wait_time(self, task: TaskInfoObject): + """ + Determines the wait time between retries + """ + pass + + def retry(self, task: TaskInfoObject, exception: Exception): + """ + Executes retry behavior for the given exception + """ + task_id = task.task_id + if self.should_retry(task, exception): + wait_time = self.get_wait_time(task) + time.sleep(wait_time) + #increase retry count here + self.execute_task(task) \ No newline at end of file diff --git a/deltacat/utils/ray_utils/retry_handler/retry_task_interface.py b/deltacat/utils/ray_utils/retry_handler/retry_task_interface.py new file mode 100644 index
00000000..2e661457 --- /dev/null +++ b/deltacat/utils/ray_utils/retry_handler/retry_task_interface.py @@ -0,0 +1,21 @@ +from typing import List, Protocol +from deltacat.utils.ray_utils.retry_handler.task_info_object import TaskInfoObject + +class RetryTaskInterface(Protocol): + def should_retry(self, task: TaskInfoObject, exception: Exception) -> bool: + """ + Given a task, determine whether it should be retried or not based on if its an instance of the RetryableError + """ + pass + + def get_wait_time(self, task: TaskInfoObject) -> int: + """ + Determines the wait time between retries + """ + pass + + def retry(self, task: TaskInfoObject, exception: Exception) -> None: + """ + Executes retry behavior for the given exception + """ + pass \ No newline at end of file diff --git a/deltacat/utils/ray_utils/retry_handler/straggler_detection_interface.py b/deltacat/utils/ray_utils/retry_handler/straggler_detection_interface.py new file mode 100644 index 00000000..ef986b3b --- /dev/null +++ b/deltacat/utils/ray_utils/retry_handler/straggler_detection_interface.py @@ -0,0 +1,12 @@ +from typing import Any, Protocol +from deltacat.utils.ray_utils.retry_handler.task_info_object import TaskInfoObject +from deltacat.utils.ray_utils.retry_handler.task_context import TaskContext +class StragglerDetectionInterface(Protocol): + """ + Using TaskContext, handles the client-side implementation for straggler detection + """ + def is_straggler(self, task: TaskInfoObject, task_context: TaskContext) -> bool: + """ + Given all the info, returns whether this specific task is a straggler or not + """ + pass \ No newline at end of file diff --git a/deltacat/utils/ray_utils/retry_handler/task_context.py b/deltacat/utils/ray_utils/retry_handler/task_context.py new file mode 100644 index 00000000..b41e68e5 --- /dev/null +++ b/deltacat/utils/ray_utils/retry_handler/task_context.py @@ -0,0 +1,11 @@ +from dataclasses import dataclass +from
deltacat.utils.ray_utils.retry_handler.progress_notifier_interface import ProgressNotifierInterface +@dataclass +class TaskContext(): + """ + This class represents important info pertaining to the task that other interfaces like Straggler Detection + can use to make decisions + """ + def __init__(self, progress_notifier: ProgressNotifierInterface, timeout: float): + self.progress_notifier = progress_notifier + self.timeout = timeout diff --git a/deltacat/utils/ray_utils/retry_handler/task_exception_retry_config.py b/deltacat/utils/ray_utils/retry_handler/task_exception_retry_config.py new file mode 100644 index 00000000..76d99d9e --- /dev/null +++ b/deltacat/utils/ray_utils/retry_handler/task_exception_retry_config.py @@ -0,0 +1,20 @@ +from __future__ import annotations +from dataclasses import dataclass +from typing import List +from deltacat.utils.ray_utils.retry_handler.task_constants import DEFAULT_RAY_REMOTE_TASK_BATCH_NEGATIVE_FEEDBACK_BATCH_SIZE_MULTIPLICATIVE_DECREASE_FACTOR, DEFAULT_RAY_REMOTE_TASK_BATCH_NEGATIVE_FEEDBACK_BACK_OFF_IN_MS, DEFAULT_RAY_REMOTE_TASK_BATCH_POSITIVE_FEEDBACK_BATCH_SIZE_ADDITIVE_INCREASE, DEFAULT_MAX_RAY_REMOTE_TASK_RETRY_ATTEMPTS, DEFAULT_RAY_REMOTE_TASK_RETRY_INITIAL_BACK_OFF_IN_MS, DEFAULT_RAY_REMOTE_TASK_RETRY_BACK_OFF_FACTOR, DEFAULT_RAY_REMOTE_TASK_MEMORY_MULTIPLICATION_FACTOR +class TaskExceptionRetryConfig: + """ + Determines how to handle and retry specific exceptions during task executions + NOTE(review): the DEFAULT_* retry constants used below are assumed to be defined in task_constants - confirm + """ + def __init__(self, exception: Exception, + max_retry_attempts: int = DEFAULT_MAX_RAY_REMOTE_TASK_RETRY_ATTEMPTS, + initial_back_off_in_ms: int = DEFAULT_RAY_REMOTE_TASK_RETRY_INITIAL_BACK_OFF_IN_MS, + back_off_factor: int = DEFAULT_RAY_REMOTE_TASK_RETRY_BACK_OFF_FACTOR, + ray_remote_task_memory_multiplication_factor: float = DEFAULT_RAY_REMOTE_TASK_MEMORY_MULTIPLICATION_FACTOR, + is_throttling_exception: bool = False) -> None: + self.exception = exception + self.max_retry_attempts = max_retry_attempts + self.initial_back_off_in_ms = initial_back_off_in_ms + self.back_off_factor = back_off_factor + self.ray_remote_task_memory_multiply_factor = ray_remote_task_memory_multiplication_factor + self.is_throttling_exception =
is_throttling_exception diff --git a/deltacat/utils/ray_utils/retry_handler/task_execution_error.py b/deltacat/utils/ray_utils/retry_handler/task_execution_error.py new file mode 100644 index 00000000..4e3e1d30 --- /dev/null +++ b/deltacat/utils/ray_utils/retry_handler/task_execution_error.py @@ -0,0 +1,7 @@ +class RayRemoteTaskExecutionError: + """ + An error class that denotes the Ray Remote Task Execution Failure + """ + def __init__(self, exception: Exception, ray_remote_task_info: "TaskInfoObject") -> None: + self.exception = exception + self.ray_remote_task_info = ray_remote_task_info \ No newline at end of file diff --git a/deltacat/utils/ray_utils/retry_handler/task_info_object.py b/deltacat/utils/ray_utils/retry_handler/task_info_object.py new file mode 100644 index 00000000..3f852460 --- /dev/null +++ b/deltacat/utils/ray_utils/retry_handler/task_info_object.py @@ -0,0 +1,25 @@ +from dataclasses import dataclass +from typing import Any, Callable, List, Optional +from deltacat.utils.ray_utils.retry_handler.task_exception_retry_config import TaskExceptionRetryConfig +from deltacat.utils.ray_utils.retry_handler.task_options import RayRemoteTaskOptions + +@dataclass(eq=False) # no declared fields: a generated __eq__ would make every instance compare equal and break list.remove +class TaskInfoObject: + """ + Dataclass holding important fields representing the Task as an object + """ + def __init__(self, + task_id: str, + task_callable: Callable[[Any], Any], + task_input: Any, + ray_remote_task_options: Optional[RayRemoteTaskOptions] = None, + task_exception_retry_config: Optional[List[TaskExceptionRetryConfig]] = None): + self.task_complete = False + self.completed = False # flag set by the batch scaling strategy and read by the submission handler + self.attempt_count = 0 # incremented by submit_single_task on every attempt + self.task_id = task_id + self.task_callable = task_callable + self.task_input = task_input + self.ray_remote_task_options = ray_remote_task_options if ray_remote_task_options is not None else RayRemoteTaskOptions() # fresh options per task; they are mutated on retry + self.task_exception_retry_config = task_exception_retry_config if task_exception_retry_config is not None else [] + diff --git a/deltacat/utils/ray_utils/retry_handler/task_options.py b/deltacat/utils/ray_utils/retry_handler/task_options.py new file mode 100644 index 00000000..4383d605 --- /dev/null +++
b/deltacat/utils/ray_utils/retry_handler/task_options.py @@ -0,0 +1,15 @@ +from dataclasses import dataclass +from typing import Any, Optional + +@dataclass +class RayRemoteTaskOptions(): + """ + Represents the options corresponding to Ray remote task + """ + def __init__(self, + memory: Optional[float] = None, + num_cpus: Optional[int] = None, + placement_group: Optional[Any] = None) -> None: + self.memory = memory + self.num_cpus = num_cpus + self.placement_group = placement_group \ No newline at end of file