Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RetryHandlerSkeleton #152

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions deltacat/utils/ray_utils/retry_handler/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
This module represents a straggler detection and retry handler framework

Within retry_strategy_config.py, the client can provide 3 parameters to start_tasks_execution to perform retries and detect stragglers
Params:
1. ray_remote_task_info: A list of Ray task objects
2. scaling_strategy: Batch scaling parameters for how many tasks to execute per batch (Optional)
a. If not provided, a default AIMD (additive increase, multiplicative decrease) strategy will be assigned for retries
3. straggler_detection: Client-provided class that holds logic for how they want to detect straggler tasks (Optional)
a. Client algorithm must inherit the interface for detection which will be used in wait_and_get_results
ekaschaw marked this conversation as resolved.
Show resolved Hide resolved

Use cases:
1. Notifying progress
This will be done through ProgressNotifierInterface. The client can use has_heartbeat and send_heartbeat
to receive updates on task-level progress. This can be an SNSQueue or any type of indicator the client may choose.
2. Detecting stragglers
Given the straggler detection algorithm implemented through StragglerDetectionInterface, the method is_straggler will inform
the customer whether the current node is a straggler according to their own logic, providing them with TaskContext, the information
they might need to make that decision.
3. Retrying retryable exceptions
Within the failure directory, there are common errors that are retryable; when an exception is detected as an instance
of the retryable class, the task will be retried when the exception is caught. If the client would like
to create their own exceptions to be handled, they can create a class that extends retryable_error or
non_retryable_error and the framework should handle it based on the configuration strategy.




28 changes: 28 additions & 0 deletions deltacat/utils/ray_utils/retry_handler/batch_scaling_interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from typing import List, Any, Protocol
from deltacat.utils.ray_utils.retry_handler.task_info_object import TaskInfoObject
class BatchScalingInterface(Protocol):
    """
    Interface for a generic batch scaling strategy that the client can provide.

    Implementations decide how the full list of tasks is split into batches and
    how the batch size evolves between batches.
    """

    def init_tasks(self, initial_batch_size: int, max_batch_size: int, min_batch_size: int, task_infos: List[Any]) -> None:
        """
        Loads all tasks to be executed for retry and straggler detection.

        :param initial_batch_size: size of the first batch
        :param max_batch_size: upper bound for the batch size
        :param min_batch_size: lower bound for the batch size
        :param task_infos: every task that has to be executed
        """
        ...

    def has_next_batch(self) -> bool:
        """
        Returns True if there are tasks remaining in the overall list of tasks
        from which a new batch can be created.
        """
        ...

    def next_batch(self, task_info: TaskInfoObject) -> List:
        """
        Gets the next batch of tasks to execute.
        """
        ...

    def mark_task_complete(self, task_info: TaskInfoObject) -> List:
        """
        Records that the given task has completed, so that completed tasks can
        be told apart from tasks that still need to be executed.
        """
        ...

16 changes: 16 additions & 0 deletions deltacat/utils/ray_utils/retry_handler/exception_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from typing import List, Optional
from ray_manager.models.ray_remote_task_exception_retry_strategy_config import RayRemoteTaskExceptionRetryConfig
def get_retry_strategy_config_for_known_exception(exception: Exception,
                                                  exception_retry_strategy_configs: List[RayRemoteTaskExceptionRetryConfig]) -> Optional[RayRemoteTaskExceptionRetryConfig]:
    """
    Looks up the retry strategy config registered for the given exception.

    A config whose exception type is exactly the type of ``exception`` wins;
    otherwise the first config whose exception type is a base class of
    ``exception`` is used.

    :param exception: the exception raised by the task
    :param exception_retry_strategy_configs: candidate configs; each exposes an
        ``exception`` attribute holding either an exception instance or an
        exception class. May be empty or None.
    :return: the matching config, or None when the exception is not recognized
    """
    def _configured_type(config) -> type:
        # Accept both an exception instance and an exception class in the config.
        configured = config.exception
        return configured if isinstance(configured, type) else type(configured)

    first_subclass_match = None
    for config in exception_retry_strategy_configs or ():
        configured_type = _configured_type(config)
        if type(exception) is configured_type:
            # An exact type match always takes precedence over a base-class match.
            return config
        if first_subclass_match is None and isinstance(exception, configured_type):
            first_subclass_match = config

    return first_subclass_match
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from deltacat.utils.ray_utils.retry_handler.retryable_error.failures import RetryableError

class AWSSecurityTokenRateExceededException(RetryableError):
    """
    Retryable error type; being a RetryableError subclass marks it as retryable.

    NOTE(review): judging by the name this represents an AWS security token
    rate-limit failure - confirm against the code that raises it.
    """

    def __init__(self, *args: object) -> None:
        # Forward all positional arguments unchanged to RetryableError.
        super().__init__(*args)
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from deltacat.utils.ray_utils.retry_handler.retryable_error.failures import RetryableError

class CairnsClientException(RetryableError):
    """
    Retryable error type; being a RetryableError subclass marks it as retryable.

    NOTE(review): judging by the name this represents a failure reported by the
    Cairns client - confirm against the code that raises it.
    """

    def __init__(self, *args: object) -> None:
        # Forward all positional arguments unchanged to RetryableError.
        super().__init__(*args)
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from exceptions import Exception
class NonRetryableError(Exception):
    """
    Base class for errors that must not be retried.
    """

    def __init__(self, *args: object):
        # Forward all positional arguments unchanged to Exception.
        super().__init__(*args)
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from exceptions import Exception
class RetryableError(Exception):
    """
    Base class for errors that can be retried.
    """

    def __init__(self, *args: object) -> None:
        # Forward all positional arguments unchanged to Exception.
        super().__init__(*args)
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from typing import List, Protocol
from deltacat.utils.ray_utils.retry_handler.task_info_object import TaskInfoObject
class ProgressNotifierInterface(Protocol):
    """
    Interface for client injected progress notification system.
    """

    def has_heartbeat(self, task_info: TaskInfoObject) -> bool:
        """
        Tells the caller whether the given task has a heartbeat or not.
        """
        ...

    def send_heartbeat(self, parent_task_info: TaskInfoObject) -> bool:
        """
        Sends the progress (heartbeat) of the current task to the parent task.

        NOTE(review): the meaning of the boolean return value is not defined in
        this skeleton - presumably whether the heartbeat was delivered; confirm.
        """
        ...

Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from deltacat.utils.ray_utils.retry_handler.batch_scaling_interface import BatchScalingInterface
class RayRemoteTasksBatchScalingStrategy(BatchScalingInterface):
    """
    Default batch scaling strategy, used when the client does not provide their own.

    The scaling methods are still skeleton placeholders: they return None until
    the AIMD (additive increase, multiplicative decrease) logic is implemented.
    """

    def __init__(self, task_infos=None) -> None:
        """
        :param task_infos: optional list of tasks this strategy will batch; the
            list is copied so later changes by the caller do not affect it
        """
        self.task_infos = list(task_infos) if task_infos is not None else []

    def init_tasks(self) -> None:
        """
        Setup AIMD scaling for the batches as the default
        """
        pass

    # The builtin ``list`` is used for the annotation because ``typing.List``
    # is not imported in this module.
    def next_batch(self) -> list:
        """
        Returns the list of tasks included in the next batch, whose size is chosen by AIMD
        """
        pass

    def has_next_batch(self) -> bool:
        """
        Returns False when there are no more tasks left from which a batch can be created
        """
        pass
158 changes: 158 additions & 0 deletions deltacat/utils/ray_utils/retry_handler/ray_task_submission_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
from __future__ import annotations
from typing import Any, Dict, List, cast, Optional
from deltacat.utils.ray_utils.retry_handler.ray_remote_tasks_batch_scaling_strategy import RayRemoteTasksBatchScalingStrategy
import ray
import time
import logging
from deltacat.logs import configure_logger
from deltacat.utils.ray_utils.retry_handler.task_execution_error import RayRemoteTaskExecutionError
from deltacat.utils.ray_utils.retry_handler.task_info_object import TaskInfoObject
from deltacat.utils.ray_utils.retry_handler.retry_strategy_config import get_retry_strategy_config_for_known_exception

logger = configure_logger(logging.getLogger(__name__))

@ray.remote
def submit_single_task(taskObj: TaskInfoObject, TaskContext: Optional[Any] = None) -> Any:
    """
    Submits a single task for execution, handles any exceptions that may occur during execution,
    and applies appropriate retry strategies if they are defined.

    :param taskObj: task to run; its ``task_callable`` is invoked with ``task_input``
        and its ``attempt_count`` is incremented for this attempt
    :param TaskContext: optional task context; not used by this function yet
    :return: the value returned by the task callable, or a RayRemoteTaskExecutionError
        when the raised exception has a retry strategy config, so that the caller
        can decide whether to retry
    :raises UnexpectedRayTaskError: when the raised exception has no retry strategy config
    """
    taskObj.attempt_count += 1
    current_attempt_number = taskObj.attempt_count
    try:
        logger.debug(f"Executing the submitted Ray remote task as part of attempt number: {current_attempt_number}")
        return taskObj.task_callable(taskObj.task_input)
    except Exception as exception:
        exception_retry_strategy_config = get_retry_strategy_config_for_known_exception(exception, taskObj.exception_retry_strategy_configs)
        if exception_retry_strategy_config is not None:
            # Known exceptions are returned (not raised) together with the task
            # so that the submitting side can re-submit it.
            return RayRemoteTaskExecutionError(exception_retry_strategy_config.exception, taskObj)

        logger.error(f"The exception thrown by submitted Ray task during attempt number: {current_attempt_number} is non-retryable or unexpected, hence throwing Non retryable exception: {exception}")
        # NOTE(review): UnexpectedRayTaskError is not among this module's visible
        # imports - it must be imported or defined for this line to work.
        raise UnexpectedRayTaskError(str(exception)) from exception

class RayTaskSubmissionHandler:
    """
    Starts execution of a given list of Ray tasks with optional arguments: scaling strategy and straggler detection.

    Failed tasks whose exception has a retry strategy config are re-submitted
    until the configured number of attempts is exceeded.
    """

    def __init__(self) -> None:
        # Object refs of submitted tasks whose results have not been collected yet.
        self.unfinished_promises: List[Any] = []
        # str(object ref) -> task info, used to recover the task behind a failed promise.
        self.task_promise_obj_ref_to_task_info_map: Dict[str, Any] = {}
        # Tasks that are waiting to be submitted by _enqueue_new_tasks.
        self.remaining_ray_remote_task_infos: List[Any] = []
        self.num_of_submitted_tasks = 0
        self.num_of_submitted_tasks_completed = 0
        self.current_batch_size = 0
        # Optional client hooks, set by start_tasks_execution.
        self.straggler_detection = None
        self.task_context = None

    def start_tasks_execution(self,
                              ray_remote_task_infos: List[TaskInfoObject],
                              scaling_strategy: Optional[BatchScalingInterface] = None,
                              straggler_detection: Optional[StragglerDetectionInterface] = None,
                              task_context: Optional[TaskContext] = None) -> None:
        """
        Prepares and initiates the execution of a batch of tasks and can optionally support
        custom client batch scaling, straggler detection, and task context.

        :param ray_remote_task_infos: tasks to execute
        :param scaling_strategy: batch scaling strategy; a RayRemoteTasksBatchScalingStrategy
            built from ``ray_remote_task_infos`` is used when omitted
        :param straggler_detection: optional client-provided straggler detection
        :param task_context: optional context handed to the straggler detection
        """
        if scaling_strategy is None:
            scaling_strategy = RayRemoteTasksBatchScalingStrategy(ray_remote_task_infos)

        self.straggler_detection = straggler_detection
        self.task_context = task_context

        # Tasks are submitted whether or not straggler detection was supplied;
        # the detection itself is applied per task inside _submit_tasks.
        while scaling_strategy.has_next_batch():
            current_batch = scaling_strategy.next_batch()
            pending_before = len(self.unfinished_promises)
            self._submit_tasks(current_batch)
            num_of_new_tasks_submitted = len(self.unfinished_promises) - pending_before
            self.num_of_submitted_tasks += num_of_new_tasks_submitted
            self.current_batch_size += num_of_new_tasks_submitted

    def _wait_and_get_all_task_results(self, straggler_detection: Optional[StragglerDetectionInterface]) -> List[Any]:
        """
        Waits for and returns the results of every task submitted so far.
        """
        return self._get_task_results(self.num_of_submitted_tasks, straggler_detection)

    def _get_task_results(self, num_of_results: int, straggler_detection: Optional[StragglerDetectionInterface]) -> List[Any]:
        """
        Gets results from a list of tasks to be executed, and catches exceptions to manage the retry strategy.
        Optional: Given a StragglerDetectionInterface, can detect and handle straggler tasks according to the client logic

        :param num_of_results: number of successful results to collect; capped at
            the number of unfinished promises
        :param straggler_detection: optional client-provided straggler detection
        :return: the successful task results
        :raises Exception: the task's own exception when it has no retry strategy
            config or has exceeded the configured maximum number of attempts
        """
        if not self.unfinished_promises or num_of_results <= 0:
            return []
        num_of_results = min(num_of_results, len(self.unfinished_promises))

        # num_returns must be passed by keyword to ray.wait.
        finished_promises, self.unfinished_promises = ray.wait(self.unfinished_promises, num_returns=num_of_results)
        successful_results = []

        for finished_promise in finished_promises:
            promise_key = str(finished_promise)
            try:
                finished_result = ray.get(finished_promise)
            except Exception as exception:
                # Wrap the failure so it goes through the same retry path as
                # errors returned by the remote task itself.
                finished_result = self._handle_ray_exception(exception=exception, ray_remote_task_info=self.task_promise_obj_ref_to_task_info_map[promise_key])

            if isinstance(finished_result, RayRemoteTaskExecutionError):
                failed_task_info = finished_result.ray_remote_task_info

                if straggler_detection and straggler_detection.is_straggler(failed_task_info, self.task_context):
                    # NOTE(review): this promise has already finished, so the cancel
                    # presumably has no effect; detecting stragglers among in-flight
                    # tasks would need a timeout on ray.wait - confirm intended design.
                    ray.cancel(finished_promise)

                exception_retry_strategy_config = get_retry_strategy_config_for_known_exception(finished_result.exception,
                                                                                                failed_task_info.exception_retry_strategy_configs)
                # NOTE(review): attempt_count is the counter incremented by
                # submit_single_task - confirm TaskInfoObject exposes it under this name.
                if exception_retry_strategy_config is None or failed_task_info.attempt_count > exception_retry_strategy_config.max_retry_attempts:
                    logger.error(f"The submitted task has exhausted all the maximum retries configured and finally throws exception - {finished_result.exception}")
                    raise finished_result.exception

                self._update_ray_remote_task_options_on_exception(finished_result.exception, failed_task_info)
                self.unfinished_promises.append(self._invoke_ray_remote_task(ray_remote_task_info=failed_task_info))
            else:
                successful_results.append(finished_result)

            # A retry is tracked under the new promise's key, so the finished one can go.
            self.task_promise_obj_ref_to_task_info_map.pop(promise_key, None)

        num_of_successful_results = len(successful_results)
        self.num_of_submitted_tasks_completed += num_of_successful_results
        self.current_batch_size -= num_of_successful_results

        self._enqueue_new_tasks(num_of_successful_results)

        if num_of_successful_results < num_of_results:
            # Some tasks failed and were re-submitted: keep waiting for the remainder.
            successful_results.extend(self._get_task_results(num_of_results - num_of_successful_results, straggler_detection))
        return successful_results

    def _enqueue_new_tasks(self, num_of_tasks: int) -> None:
        """
        Helper method to submit a specified number of tasks
        """
        new_tasks = self.remaining_ray_remote_task_infos[:num_of_tasks]
        self.remaining_ray_remote_task_infos = self.remaining_ray_remote_task_infos[num_of_tasks:]

        pending_before = len(self.unfinished_promises)
        self._submit_tasks(new_tasks)
        num_of_new_tasks_submitted = len(self.unfinished_promises) - pending_before

        self.current_batch_size += num_of_new_tasks_submitted
        logger.info(f"Enqueued {num_of_new_tasks_submitted} new tasks. Current concurrency of tasks execution: {self.current_batch_size}, Current Task progress: {self.num_of_submitted_tasks_completed}/{self.num_of_submitted_tasks}")

    def _submit_tasks(self, ray_remote_task_infos: List[TaskInfoObject]) -> None:
        """
        Submits every given task to Ray, skipping tasks flagged by the straggler detection.
        """
        for ray_remote_task_info in ray_remote_task_infos:
            # Short pause between submissions (presumably to throttle submission rate).
            time.sleep(0.005)
            if self.straggler_detection and self.straggler_detection.is_straggler(ray_remote_task_info, self.task_context):
                # Nothing has been submitted for this task yet, so there is no
                # object ref to cancel: the task is simply not submitted.
                logger.warning("Skipping submission of a task flagged as straggler by the client straggler detection")
                continue
            self.unfinished_promises.append(self._invoke_ray_remote_task(ray_remote_task_info))

    #replace with ray.options
    def _invoke_ray_remote_task(self, ray_remote_task_info: TaskInfoObject) -> Any:
        """
        Submits one task through submit_single_task with the task's Ray options and
        records the returned object ref so the task can be recovered on failure.

        :return: the object ref of the submitted task
        """
        task_options = ray_remote_task_info.ray_remote_task_options
        ray_remote_task_options_arguments = {}
        # Only options that are set (truthy) are forwarded to Ray.
        for option_name in ("memory", "num_cpus", "placement_group"):
            option_value = getattr(task_options, option_name)
            if option_value:
                ray_remote_task_options_arguments[option_name] = option_value

        # Passed positionally: the remote function's first parameter is named taskObj.
        ray_remote_task_promise_obj_ref = submit_single_task.options(**ray_remote_task_options_arguments).remote(ray_remote_task_info)
        self.task_promise_obj_ref_to_task_info_map[str(ray_remote_task_promise_obj_ref)] = ray_remote_task_info

        return ray_remote_task_promise_obj_ref

    #replace with ray.options
    def _update_ray_remote_task_options_on_exception(self, exception: Exception, ray_remote_task_info: TaskInfoObject) -> None:
        """
        Scales the task's memory option by the factor configured for the given exception,
        when the exception has a retry strategy config and a memory option is set.
        """
        exception_retry_strategy_config = get_retry_strategy_config_for_known_exception(exception, ray_remote_task_info.exception_retry_strategy_configs)
        if exception_retry_strategy_config and ray_remote_task_info.ray_remote_task_options.memory:
            logger.info(f"Updating the Ray remote task options after encountering exception: {exception}")
            ray_remote_task_memory_multiply_factor = exception_retry_strategy_config.ray_remote_task_memory_multiply_factor
            ray_remote_task_info.ray_remote_task_options.memory *= ray_remote_task_memory_multiply_factor
            logger.info(f"Updated ray remote task options Memory: {ray_remote_task_info.ray_remote_task_options.memory}")

    #replace with own exceptions
    def _handle_ray_exception(self, exception: Exception, ray_remote_task_info: TaskInfoObject) -> RayRemoteTaskExecutionError:
        """
        Logs an exception raised while fetching a task result and wraps it, together
        with its task, in a RayRemoteTaskExecutionError for the retry logic.
        """
        logger.error(f"Ray remote task failed with {type(exception)} Ray exception: {exception}")
        # NOTE(review): the original body stopped at an unfinished comparison of the
        # exception type name; mapping Ray exceptions to the project's own exception
        # classes is still to be done.
        return RayRemoteTaskExecutionError(exception, ray_remote_task_info)
28 changes: 28 additions & 0 deletions deltacat/utils/ray_utils/retry_handler/retry_task_interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from typing import List, Protocol
from deltacat.utils.ray_utils.retry_handler.task_info_object import TaskInfoObject
import Exception

class RetryTaskInterface(Protocol):
    """
    Interface for the retry behavior applied to tasks that raise an exception.
    """

    def init_tasks(self, task_infos: List[TaskInfoObject]) -> None:
        """
        Loads all tasks to check for retries if exception occurs
        """
        ...

    def should_retry(self, task: TaskInfoObject, exception: Exception) -> bool:
        """
        Given a task and the exception it raised, determine whether it should be retried or not
        """
        ...

    def get_wait_time(self, task: TaskInfoObject) -> int:
        """
        Determines the wait time between retries

        NOTE(review): the time unit is not specified in this skeleton - confirm.
        """
        ...

    def retry(self, task: TaskInfoObject, exception: Exception) -> None:
        """
        Executes retry behavior for the given task and exception
        """
        ...
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from typing import Any, Protocol
from deltacat.utils.ray_utils.retry_handler.task_info_object import TaskInfoObject
from deltacat.utils.ray_utils.retry_handler.task_context import TaskContext
class StragglerDetectionInterface(Protocol):
    """
    Client-side straggler detection hook, driven by the information in TaskContext.
    """

    def is_straggler(self, task: TaskInfoObject, task_context: TaskContext) -> bool:
        """
        Decides, from the task and its context, whether this particular task is a straggler.
        """
        ...
11 changes: 11 additions & 0 deletions deltacat/utils/ray_utils/retry_handler/task_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from dataclasses import dataclass
from deltacat.utils.ray_utils.retry_handler.progress_notifier_interface import ProgressNotifierInterface
@dataclass
class TaskContext():
    """
    This class represents important info pertaining to the task that other interfaces like Straggler Detection
    can use to make decisions
    """
    # Declared as dataclass fields so that the generated __init__, __repr__ and
    # __eq__ actually take these values into account.
    # Client-injected notifier used to report and query task progress.
    progress_notifier: ProgressNotifierInterface
    # Timeout for the task. NOTE(review): the time unit is not specified - confirm.
    timeoutTime: float
ekaschaw marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import List
from deltacat.utils.ray_utils.retry_handler.task_constants import DEFAULT_RAY_REMOTE_TASK_BATCH_NEGATIVE_FEEDBACK_BATCH_SIZE_MULTIPLICATIVE_DECREASE_FACTOR, DEFAULT_RAY_REMOTE_TASK_BATCH_NEGATIVE_FEEDBACK_BACK_OFF_IN_MS, DEFAULT_RAY_REMOTE_TASK_BATCH_POSITIVE_FEEDBACK_BATCH_SIZE_ADDITIVE_INCREASE
class TaskExceptionRetryConfig:
    """
    Describes how one specific exception is handled and retried during task execution.
    """

    # NOTE(review): the DEFAULT_* constants used as defaults below are not among
    # the names imported from task_constants in this module - they must be
    # imported for this class to be definable.
    def __init__(self, exception: Exception,
                 max_retry_attempts: int = DEFAULT_MAX_RAY_REMOTE_TASK_RETRY_ATTEMPTS,
                 initial_back_off_in_ms: int = DEFAULT_RAY_REMOTE_TASK_RETRY_INITIAL_BACK_OFF_IN_MS,
                 back_off_factor: int = DEFAULT_RAY_REMOTE_TASK_RETRY_BACK_OFF_FACTOR,
                 ray_remote_task_memory_multiplication_factor: float = DEFAULT_RAY_REMOTE_TASK_MEMORY_MULTIPLICATION_FACTOR,
                 is_throttling_exception: bool = False) -> None:
        """
        :param exception: the exception this config applies to
        :param max_retry_attempts: maximum number of retry attempts
        :param initial_back_off_in_ms: initial back-off, in milliseconds
        :param back_off_factor: back-off factor
        :param ray_remote_task_memory_multiplication_factor: stored as
            ``ray_remote_task_memory_multiply_factor``
        :param is_throttling_exception: whether the exception is a throttling exception
        """
        # Exception identity and retry limits.
        self.exception = exception
        self.max_retry_attempts = max_retry_attempts
        self.is_throttling_exception = is_throttling_exception

        # Back-off settings.
        self.initial_back_off_in_ms = initial_back_off_in_ms
        self.back_off_factor = back_off_factor

        # Note the attribute name differs from the parameter name.
        self.ray_remote_task_memory_multiply_factor = ray_remote_task_memory_multiplication_factor
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
class RayRemoteTaskExecutionError:
    """
    An error class that denotes the Ray Remote Task Execution Failure

    It pairs the exception with the task that raised it so the task can be retried.
    """

    # The task type is written as a string (forward reference) because no task
    # type is imported in this module.
    def __init__(self, exception: Exception, ray_remote_task_info: "TaskInfoObject") -> None:
        """
        :param exception: the exception the task failed with
        :param ray_remote_task_info: the task that failed
        """
        self.exception = exception
        self.ray_remote_task_info = ray_remote_task_info
Loading