Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extract provider data properly (attempt 2) #148

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 22 additions & 15 deletions llama_stack/distribution/request_headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,36 @@

import json
import threading
from typing import Any, Dict, List
from typing import Any, Dict

from .utils.dynamic import instantiate_class_type

_THREAD_LOCAL = threading.local()


def get_request_provider_data() -> Any:
return getattr(_THREAD_LOCAL, "provider_data", None)
class NeedsRequestProviderData:
def get_request_provider_data(self) -> Any:
spec = self.__provider_spec__
assert spec, f"Provider spec not set on {self.__class__}"

provider_id = spec.provider_id
validator_class = spec.provider_data_validator
if not validator_class:
raise ValueError(f"Provider {provider_id} does not have a validator")

def set_request_provider_data(headers: Dict[str, str], validator_classes: List[str]):
if not validator_classes:
return
val = _THREAD_LOCAL.provider_data_header_value
if not val:
return None

validator = instantiate_class_type(validator_class)
try:
provider_data = validator(**val)
return provider_data
except Exception as e:
print("Error parsing provider data", e)


def set_request_provider_data(headers: Dict[str, str]):
keys = [
"X-LlamaStack-ProviderData",
"x-llamastack-providerdata",
Expand All @@ -39,12 +54,4 @@ def set_request_provider_data(headers: Dict[str, str], validator_classes: List[s
print("Provider data not encoded as a JSON object!", val)
return

for validator_class in validator_classes:
validator = instantiate_class_type(validator_class)
try:
provider_data = validator(**val)
if provider_data:
_THREAD_LOCAL.provider_data = provider_data
return
except Exception as e:
print("Error parsing provider data", e)
_THREAD_LOCAL.provider_data_header_value = val
19 changes: 3 additions & 16 deletions llama_stack/distribution/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,9 +207,7 @@ async def endpoint(request: Request):
return endpoint


def create_dynamic_typed_route(
func: Any, method: str, provider_data_validators: List[str]
):
def create_dynamic_typed_route(func: Any, method: str):
hints = get_type_hints(func)
response_model = hints.get("return")

Expand All @@ -224,7 +222,7 @@ def create_dynamic_typed_route(
async def endpoint(request: Request, **kwargs):
await start_trace(func.__name__)

set_request_provider_data(request.headers, provider_data_validators)
set_request_provider_data(request.headers)

async def sse_generator(event_gen):
try:
Expand Down Expand Up @@ -255,7 +253,7 @@ async def sse_generator(event_gen):
async def endpoint(request: Request, **kwargs):
await start_trace(func.__name__)

set_request_provider_data(request.headers, provider_data_validators)
set_request_provider_data(request.headers)

try:
return (
Expand Down Expand Up @@ -462,21 +460,10 @@ async def healthcheck():

impl_method = getattr(impl, endpoint.name)

validators = []
if isinstance(provider_spec, AutoRoutedProviderSpec):
inner_specs = specs[provider_spec.routing_table_api].inner_specs
for spec in inner_specs:
if spec.provider_data_validator:
validators.append(spec.provider_data_validator)
elif not isinstance(provider_spec, RoutingTableProviderSpec):
if provider_spec.provider_data_validator:
validators.append(provider_spec.provider_data_validator)

getattr(app, endpoint.method)(endpoint.route, response_model=None)(
create_dynamic_typed_route(
impl_method,
endpoint.method,
validators,
)
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from together import Together

from llama_stack.apis.inference import * # noqa: F403
from llama_stack.distribution.request_headers import get_request_provider_data
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.utils.inference.augment_messages import (
augment_messages_for_tools,
)
Expand All @@ -32,7 +32,7 @@
}


class TogetherInferenceAdapter(Inference):
class TogetherInferenceAdapter(Inference, NeedsRequestProviderData):
def __init__(self, config: TogetherImplConfig) -> None:
self.config = config
tokenizer = Tokenizer.get_instance()
Expand Down Expand Up @@ -103,7 +103,7 @@ async def chat_completion(
) -> AsyncGenerator:

together_api_key = None
provider_data = get_request_provider_data()
provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.together_api_key:
raise ValueError(
'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": <your api key>}'
Expand Down
6 changes: 3 additions & 3 deletions llama_stack/providers/adapters/safety/together/together.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
SafetyViolation,
ViolationLevel,
)
from llama_stack.distribution.request_headers import get_request_provider_data
from llama_stack.distribution.request_headers import NeedsRequestProviderData

from .config import TogetherSafetyConfig

Expand All @@ -40,7 +40,7 @@ def shield_type_to_model_name(shield_type: str) -> str:
return SAFETY_SHIELD_TYPES.get(model.descriptor(shorten_default_variant=True))


class TogetherSafetyImpl(Safety):
class TogetherSafetyImpl(Safety, NeedsRequestProviderData):
def __init__(self, config: TogetherSafetyConfig) -> None:
self.config = config

Expand All @@ -52,7 +52,7 @@ async def run_shield(
) -> RunShieldResponse:

together_api_key = None
provider_data = get_request_provider_data()
provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.together_api_key:
raise ValueError(
'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": <your api key>}'
Expand Down