diff --git a/llama_stack/distribution/request_headers.py b/llama_stack/distribution/request_headers.py index 27b8b53..5ed04a1 100644 --- a/llama_stack/distribution/request_headers.py +++ b/llama_stack/distribution/request_headers.py @@ -6,21 +6,36 @@ import json import threading -from typing import Any, Dict, List +from typing import Any, Dict from .utils.dynamic import instantiate_class_type _THREAD_LOCAL = threading.local() -def get_request_provider_data() -> Any: - return getattr(_THREAD_LOCAL, "provider_data", None) +class NeedsRequestProviderData: + def get_request_provider_data(self) -> Any: + spec = self.__provider_spec__ + assert spec, f"Provider spec not set on {self.__class__}" + provider_id = spec.provider_id + validator_class = spec.provider_data_validator + if not validator_class: + raise ValueError(f"Provider {provider_id} does not have a validator") -def set_request_provider_data(headers: Dict[str, str], validator_classes: List[str]): - if not validator_classes: - return + val = getattr(_THREAD_LOCAL, "provider_data_header_value", None) + if not val: + return None + + validator = instantiate_class_type(validator_class) + try: + provider_data = validator(**val) + return provider_data + except Exception as e: + print("Error parsing provider data", e) + +def set_request_provider_data(headers: Dict[str, str]): keys = [ "X-LlamaStack-ProviderData", "x-llamastack-providerdata", @@ -39,12 +54,4 @@ def set_request_provider_data(headers: Dict[str, str], validator_classes: List[s print("Provider data not encoded as a JSON object!", val) return - for validator_class in validator_classes: - validator = instantiate_class_type(validator_class) - try: - provider_data = validator(**val) - if provider_data: - _THREAD_LOCAL.provider_data = provider_data - return - except Exception as e: - print("Error parsing provider data", e) + _THREAD_LOCAL.provider_data_header_value = val diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index
a32c470..9cebe9b 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -207,9 +207,7 @@ async def endpoint(request: Request): return endpoint -def create_dynamic_typed_route( - func: Any, method: str, provider_data_validators: List[str] -): +def create_dynamic_typed_route(func: Any, method: str): hints = get_type_hints(func) response_model = hints.get("return") @@ -224,7 +222,7 @@ def create_dynamic_typed_route( async def endpoint(request: Request, **kwargs): await start_trace(func.__name__) - set_request_provider_data(request.headers, provider_data_validators) + set_request_provider_data(request.headers) async def sse_generator(event_gen): try: @@ -255,7 +253,7 @@ async def sse_generator(event_gen): async def endpoint(request: Request, **kwargs): await start_trace(func.__name__) - set_request_provider_data(request.headers, provider_data_validators) + set_request_provider_data(request.headers) try: return ( @@ -462,21 +460,10 @@ async def healthcheck(): impl_method = getattr(impl, endpoint.name) - validators = [] - if isinstance(provider_spec, AutoRoutedProviderSpec): - inner_specs = specs[provider_spec.routing_table_api].inner_specs - for spec in inner_specs: - if spec.provider_data_validator: - validators.append(spec.provider_data_validator) - elif not isinstance(provider_spec, RoutingTableProviderSpec): - if provider_spec.provider_data_validator: - validators.append(provider_spec.provider_data_validator) - getattr(app, endpoint.method)(endpoint.route, response_model=None)( create_dynamic_typed_route( impl_method, endpoint.method, - validators, ) ) diff --git a/llama_stack/providers/adapters/inference/together/together.py b/llama_stack/providers/adapters/inference/together/together.py index 0737868..7053834 100644 --- a/llama_stack/providers/adapters/inference/together/together.py +++ b/llama_stack/providers/adapters/inference/together/together.py @@ -15,7 +15,7 @@ from together import Together from 
llama_stack.apis.inference import * # noqa: F403 -from llama_stack.distribution.request_headers import get_request_provider_data +from llama_stack.distribution.request_headers import NeedsRequestProviderData from llama_stack.providers.utils.inference.augment_messages import ( augment_messages_for_tools, ) @@ -32,7 +32,7 @@ } -class TogetherInferenceAdapter(Inference): +class TogetherInferenceAdapter(Inference, NeedsRequestProviderData): def __init__(self, config: TogetherImplConfig) -> None: self.config = config tokenizer = Tokenizer.get_instance() @@ -103,7 +103,7 @@ async def chat_completion( ) -> AsyncGenerator: together_api_key = None - provider_data = get_request_provider_data() + provider_data = self.get_request_provider_data() if provider_data is None or not provider_data.together_api_key: raise ValueError( 'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": }' diff --git a/llama_stack/providers/adapters/safety/together/together.py b/llama_stack/providers/adapters/safety/together/together.py index 8e552fb..24fcc63 100644 --- a/llama_stack/providers/adapters/safety/together/together.py +++ b/llama_stack/providers/adapters/safety/together/together.py @@ -13,7 +13,7 @@ SafetyViolation, ViolationLevel, ) -from llama_stack.distribution.request_headers import get_request_provider_data +from llama_stack.distribution.request_headers import NeedsRequestProviderData from .config import TogetherSafetyConfig @@ -40,7 +40,7 @@ def shield_type_to_model_name(shield_type: str) -> str: return SAFETY_SHIELD_TYPES.get(model.descriptor(shorten_default_variant=True)) -class TogetherSafetyImpl(Safety): +class TogetherSafetyImpl(Safety, NeedsRequestProviderData): def __init__(self, config: TogetherSafetyConfig) -> None: self.config = config @@ -52,7 +52,7 @@ async def run_shield( ) -> RunShieldResponse: together_api_key = None - provider_data = get_request_provider_data() + provider_data = self.get_request_provider_data() if provider_data is 
None or not provider_data.together_api_key: raise ValueError( 'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": }'