From d33626cebb0d5228774471350ae6489a34849bbe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniele=20Trifir=C3=B2?= Date: Thu, 25 Jul 2024 10:28:19 +0200 Subject: [PATCH 1/9] get rid of upstream code, refactor for vLLM==0.5.4 --- src/vllm_tgis_adapter/__main__.py | 359 +++------------------- src/vllm_tgis_adapter/grpc/grpc_server.py | 61 ++-- src/vllm_tgis_adapter/http.py | 69 +++++ tests/conftest.py | 166 ++++------ tests/test_grpc_server.py | 13 +- tests/test_http_server.py | 6 +- tests/utils.py | 18 +- 7 files changed, 223 insertions(+), 469 deletions(-) create mode 100644 src/vllm_tgis_adapter/http.py diff --git a/src/vllm_tgis_adapter/__main__.py b/src/vllm_tgis_adapter/__main__.py index 7841c74..87f00c8 100644 --- a/src/vllm_tgis_adapter/__main__.py +++ b/src/vllm_tgis_adapter/__main__.py @@ -1,313 +1,67 @@ from __future__ import annotations import asyncio -import importlib -import inspect -import re import signal -from contextlib import asynccontextmanager -from http import HTTPStatus from typing import TYPE_CHECKING -import fastapi +import uvloop import vllm -from fastapi import APIRouter -from fastapi.exceptions import RequestValidationError -from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import JSONResponse, Response, StreamingResponse -from prometheus_client import make_asgi_app -from starlette.routing import Mount -from uvicorn import Config as UvicornConfig -from uvicorn import Server as UvicornServer -from vllm import envs -from vllm.engine.arg_utils import AsyncEngineArgs -from vllm.engine.async_llm_engine import AsyncLLMEngine -from vllm.entrypoints.logger import RequestLogger -from vllm.entrypoints.openai.cli_args import make_arg_parser -from vllm.entrypoints.openai.protocol import ( # noqa: TCH002 # pydantic needs to access these annotations - ChatCompletionRequest, - ChatCompletionResponse, - CompletionRequest, - DetokenizeRequest, - DetokenizeResponse, - EmbeddingRequest, - ErrorResponse, - TokenizeRequest, - 
TokenizeResponse, -) -from vllm.entrypoints.openai.serving_chat import OpenAIServingChat -from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion -from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding -from vllm.entrypoints.openai.serving_tokenization import ( - OpenAIServingTokenization, +from vllm.entrypoints.openai.api_server import ( + build_async_engine_client, ) -from vllm.usage.usage_lib import UsageContext +from vllm.entrypoints.openai.cli_args import make_arg_parser from vllm.utils import FlexibleArgumentParser from .grpc import run_grpc_server +from .http import run_http_server from .logging import init_logger from .tgis_utils.args import EnvVarArgumentParser, add_tgis_args, postprocess_tgis_args if TYPE_CHECKING: import argparse - from collections.abc import AsyncGenerator - - from vllm.config import ModelConfig - - -TIMEOUT_KEEP_ALIVE = 5 # seconds - -openai_serving_chat: OpenAIServingChat -openai_serving_completion: OpenAIServingCompletion -openai_serving_embedding: OpenAIServingEmbedding -openai_serving_tokenization: OpenAIServingTokenization - -logger = init_logger(__name__) - -_running_tasks: set[asyncio.Task] = set() +logger = init_logger("vllm-tgis-adapter") -router = APIRouter() +async def start_servers(args: argparse.Namespace) -> None: + loop = asyncio.get_running_loop() -def mount_metrics(app: fastapi.FastAPI) -> None: - # Add prometheus asgi middleware to route /metrics requests - metrics_route = Mount("/metrics", make_asgi_app()) - # Workaround for 307 Redirect for /metrics - metrics_route.path_regex = re.compile("^/metrics(?P.*)$") - app.routes.append(metrics_route) - - -@router.get("/health") -async def health() -> Response: - """Health check.""" - await openai_serving_chat.engine.check_health() - return Response(status_code=200) - - -@router.post("/tokenize") -async def tokenize(request: TokenizeRequest) -> JSONResponse: - generator = await openai_serving_tokenization.create_tokenize(request) 
- if isinstance(generator, ErrorResponse): - return JSONResponse( - content=generator.model_dump(), - status_code=generator.code, + tasks: list[asyncio.Task] = [] + async with build_async_engine_client(args) as engine: + http_server_task = loop.create_task( + run_http_server(args, engine), + name="http_server", ) - assert isinstance(generator, TokenizeResponse) - return JSONResponse(content=generator.model_dump()) - - -@router.post("/detokenize") -async def detokenize(request: DetokenizeRequest) -> JSONResponse: - generator = await openai_serving_tokenization.create_detokenize(request) - if isinstance(generator, ErrorResponse): - return JSONResponse( - content=generator.model_dump(), - status_code=generator.code, - ) - - assert isinstance(generator, DetokenizeResponse) - return JSONResponse(content=generator.model_dump()) - - -@router.get("/v1/models") -async def show_available_models() -> JSONResponse: - models = await openai_serving_completion.show_available_models() - return JSONResponse(content=models.model_dump()) - - -@router.get("/version") -async def show_version() -> JSONResponse: - ver = {"version": vllm.__version__, "commit": vllm.__commit__} - return JSONResponse(content=ver) - - -@router.post("/v1/chat/completions") -async def create_chat_completion( - request: ChatCompletionRequest, - raw_request: fastapi.Request, -) -> JSONResponse: - generator = await openai_serving_chat.create_chat_completion( - request, - raw_request, - ) - if isinstance(generator, ErrorResponse): - return JSONResponse(content=generator.model_dump(), status_code=generator.code) - if request.stream: - return StreamingResponse(content=generator, media_type="text/event-stream") - - assert isinstance(generator, ChatCompletionResponse) - return JSONResponse(content=generator.model_dump()) - - -@router.post("/v1/completions") -async def create_completion(request: CompletionRequest, raw_request: fastapi.Request): # noqa: ANN201 - generator = await 
openai_serving_completion.create_completion(request, raw_request) - if isinstance(generator, ErrorResponse): - return JSONResponse( - content=generator.model_dump(), - status_code=generator.code, - ) - if request.stream: - return StreamingResponse(content=generator, media_type="text/event-stream") - return JSONResponse(content=generator.model_dump()) - - -@router.post("/v1/embeddings") -async def create_embedding(request: EmbeddingRequest, raw_request: fastapi.Request): # noqa: ANN201 - generator = await openai_serving_embedding.create_embedding(request, raw_request) - if isinstance(generator, ErrorResponse): - return JSONResponse(content=generator.model_dump(), status_code=generator.code) - return JSONResponse(content=generator.model_dump()) - - -def build_app( # noqa: C901 - engine: AsyncLLMEngine, args: argparse.Namespace -) -> fastapi.FastAPI: - @asynccontextmanager - async def lifespan(app: fastapi.FastAPI) -> AsyncGenerator: # noqa: ARG001 - async def _force_log(): # noqa: ANN202 - while True: - await asyncio.sleep(10) - await engine.do_log_stats() - - if not args.disable_log_stats: - task = asyncio.create_task(_force_log()) - _running_tasks.add(task) - task.add_done_callback(_running_tasks.remove) + tasks.append(http_server_task) - yield - - app = fastapi.FastAPI(lifespan=lifespan) - app.include_router(router) - app.root_path = args.root_path - - mount_metrics(app) - - app.add_middleware( - CORSMiddleware, - allow_origins=args.allowed_origins, - allow_credentials=args.allow_credentials, - allow_methods=args.allowed_methods, - allow_headers=args.allowed_headers, - ) - - @app.exception_handler(RequestValidationError) - async def validation_exception_handler(_, exc): # noqa: ANN001, ANN202 - err = openai_serving_chat.create_error_response(message=str(exc)) - return JSONResponse( - err.model_dump(), - status_code=HTTPStatus.BAD_REQUEST, + grpc_server_task = loop.create_task( + run_grpc_server(args, engine), + name="grpc_server", ) + 
tasks.append(grpc_server_task) - if token := envs.VLLM_API_KEY or args.api_key: - - @app.middleware("http") - async def authentication(request: fastapi.Request, call_next): # noqa: ANN001, ANN202 - root_path = "" if args.root_path is None else args.root_path - if request.method == "OPTIONS": - return await call_next(request) - if not request.url.path.startswith(f"{root_path}/v1"): - return await call_next(request) - if request.headers.get("Authorization") != "Bearer " + token: - return JSONResponse( - content={"error": "Unauthorized"}, - status_code=401, - ) - return await call_next(request) - - for middleware in args.middleware: - module_path, object_name = middleware.rsplit(".", 1) - imported = getattr(importlib.import_module(module_path), object_name) - if inspect.isclass(imported): - app.add_middleware(imported) - elif inspect.iscoroutinefunction(imported): - app.middleware("http")(imported) - else: - raise ValueError( - f"Invalid middleware {middleware}. " f"Must be a function or a class." 
- ) - - return app - + def signal_handler() -> None: + # prevents the uvicorn signal handler to exit early + for task in tasks: + task.cancel() -async def run_http_server( - engine: AsyncLLMEngine, - args: argparse.Namespace, - model_config: ModelConfig, -) -> None: - app = build_app(engine, args) + async def override_signal_handler() -> None: + loop = asyncio.get_running_loop() - if args.served_model_name is not None: - served_model_names = args.served_model_name - else: - served_model_names = [args.model] + for sig in (signal.SIGINT, signal.SIGTERM): + loop.add_signal_handler(sig, signal_handler) - if args.disable_log_requests: - request_logger = None - else: - request_logger = RequestLogger(max_log_len=args.max_log_len) + tasks.append(loop.create_task(override_signal_handler())) - global openai_serving_chat # noqa: PLW0603 - global openai_serving_completion # noqa: PLW0603 - global openai_serving_embedding # noqa: PLW0603 - global openai_serving_tokenization # noqa: PLW0603 - - openai_serving_chat = OpenAIServingChat( - engine, - model_config, - served_model_names, - args.response_role, - lora_modules=args.lora_modules, - prompt_adapters=args.prompt_adapters, - request_logger=request_logger, - chat_template=args.chat_template, - ) - openai_serving_completion = OpenAIServingCompletion( - engine, - model_config, - served_model_names, - lora_modules=args.lora_modules, - prompt_adapters=args.prompt_adapters, - request_logger=request_logger, - ) - openai_serving_embedding = OpenAIServingEmbedding( - engine, - model_config, - served_model_names, - request_logger=request_logger, - ) - openai_serving_tokenization = OpenAIServingTokenization( - engine, - model_config, - served_model_names, - lora_modules=args.lora_modules, - request_logger=request_logger, - chat_template=args.chat_template, - ) - app.root_path = args.root_path - config = UvicornConfig( - app, - host=args.host, - port=args.port, - log_level=args.uvicorn_log_level, - timeout_keep_alive=TIMEOUT_KEEP_ALIVE, 
- ssl_keyfile=args.ssl_keyfile, - ssl_certfile=args.ssl_certfile, - ssl_ca_certs=args.ssl_ca_certs, - ssl_cert_reqs=args.ssl_cert_reqs, - ) - - server = UvicornServer(config) - try: - await server.serve() - except asyncio.CancelledError: - print("Gracefully stopping http server") # noqa: T201 - await server.shutdown() + try: + await asyncio.wait(tasks) + except asyncio.CancelledError: + for task in tasks: + task.cancel() if __name__ == "__main__": - parser = FlexibleArgumentParser("vLLM TGIS GRPC + OpenAI Rest api server") + parser = FlexibleArgumentParser("vLLM TGIS GRPC + OpenAI REST api server") # convert to our custom env var arg parser parser = EnvVarArgumentParser(parser=make_arg_parser(parser)) parser = add_tgis_args(parser) @@ -322,47 +76,6 @@ async def run_http_server( logger.info("vLLM version %s", version_info) logger.info("args: %s", args) - engine_args = AsyncEngineArgs.from_cli_args(args) - engine = AsyncLLMEngine.from_engine_args( - engine_args, # type: ignore[arg-type] - usage_context=UsageContext.OPENAI_API_SERVER, - ) - - event_loop: asyncio.AbstractEventLoop | None - try: - event_loop = asyncio.get_running_loop() - except RuntimeError: - event_loop = None - - if event_loop is not None and event_loop.is_running(): - # If the current is instanced by Ray Serve, - # there is already a running event loop - model_config = event_loop.run_until_complete(engine.get_model_config()) - else: - # When using single vLLM without engine_use_ray - model_config = asyncio.run(engine.get_model_config()) - - if event_loop is None: - event_loop = asyncio.new_event_loop() - - async def run() -> None: - loop = asyncio.get_running_loop() - - http_server_task = loop.create_task(run_http_server(engine, args, model_config)) - grpc_server_task = loop.create_task( - run_grpc_server( - engine, args, disable_log_stats=engine_args.disable_log_stats - ) - ) - - def signal_handler() -> None: - # prevents the uvicorn signal handler to exit early - grpc_server_task.cancel() - 
http_server_task.cancel() - - loop.add_signal_handler(signal.SIGINT, signal_handler) - loop.add_signal_handler(signal.SIGTERM, signal_handler) - - await asyncio.gather(grpc_server_task, http_server_task) - - event_loop.run_until_complete(run()) + asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) + loop = asyncio.new_event_loop() + loop.run_until_complete(start_servers(args)) diff --git a/src/vllm_tgis_adapter/grpc/grpc_server.py b/src/vllm_tgis_adapter/grpc/grpc_server.py index a461559..bd1e9c8 100644 --- a/src/vllm_tgis_adapter/grpc/grpc_server.py +++ b/src/vllm_tgis_adapter/grpc/grpc_server.py @@ -17,8 +17,8 @@ from grpc._cython.cygrpc import AbortError from grpc_health.v1 import health, health_pb2, health_pb2_grpc from grpc_reflection.v1alpha import reflection -from vllm import AsyncLLMEngine, SamplingParams -from vllm.engine.async_llm_engine import _AsyncLLMEngine +from vllm import SamplingParams +from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.entrypoints.openai.serving_completion import merge_async_iterators from vllm.inputs import LLMInputs from vllm.tracing import ( @@ -72,6 +72,7 @@ from transformers import PreTrainedTokenizer from vllm import CompletionOutput, RequestOutput from vllm.config import ModelConfig + from vllm.engine.protocol import AsyncEngineClient from vllm.lora.request import LoRARequest from vllm.sequence import Logprob @@ -84,18 +85,13 @@ SingleGenerationRequest, ) - try: - from .adapters import PromptAdapterRequest - except ImportError: - pass - _T = TypeVar("_T") _F = TypeVar("_F", Callable, Coroutine) logger = init_logger(__name__) service_metrics = ServiceMetrics() -ADD_SPECIAL_TOKENS = os.getenv("ADD_SPECIAL_TOKENS") +ADD_SPECIAL_TOKENS: str | None = os.getenv("ADD_SPECIAL_TOKENS") if ADD_SPECIAL_TOKENS is not None: ADD_SPECIAL_TOKENS = ADD_SPECIAL_TOKENS.lower() not in (0, "false") @@ -173,11 +169,11 @@ class TextGenerationService(generation_pb2_grpc.GenerationServiceServicer): def __init__( self, - engine: 
AsyncLLMEngine, + engine: AsyncEngineClient | AsyncLLMEngine, args: argparse.Namespace, health_servicer: health.HealthServicer, ): - self.engine: AsyncLLMEngine = engine + self.engine: AsyncEngineClient = engine # This is set in post_init() self.config: ModelConfig | None = None @@ -201,12 +197,18 @@ def __init__( async def post_init(self) -> None: self.config = await self.engine.get_model_config() - # Swap in the special TGIS stats logger - tgis_stats_logger = TGISStatLogger( - vllm_stat_logger=self.engine.engine.stat_loggers["prometheus"], - max_sequence_len=self.config.max_model_len, - ) - self.engine.engine.stat_loggers["prometheus"] = tgis_stats_logger + if not isinstance(self.engine, AsyncLLMEngine): + logger.warning( + "TGIS Metrics currently disabled in decoupled front-end mode, " + "set DISABLE_FRONTEND_MULTIPROCESSING=True to enable" + ) + else: + # Swap in the special TGIS stats logger + tgis_stats_logger = TGISStatLogger( + vllm_stat_logger=self.engine.engine.stat_loggers["prometheus"], + max_sequence_len=self.config.max_model_len, + ) + self.engine.engine.stat_loggers["prometheus"] = tgis_stats_logger self.health_servicer.set( self.SERVICE_NAME, @@ -876,19 +878,9 @@ async def ModelInfo( async def start_grpc_server( - engine: AsyncLLMEngine, args: argparse.Namespace + args: argparse.Namespace, + engine: AsyncLLMEngine | AsyncEngineClient, ) -> aio.Server: - # Log memory summary after model is loaded - from torch.cuda import memory_summary - - assert isinstance(engine, AsyncLLMEngine) - assert isinstance(engine.engine, _AsyncLLMEngine) - - if (device_type := engine.engine.device_config.device.type) == "cuda": - logger.info(memory_summary(engine.engine.device_config.device)) - else: - logger.warning("Cannot print device usage for device type: %s", device_type) - server = aio.server() health_servicer = health.HealthServicer() @@ -951,20 +943,17 @@ async def start_grpc_server( async def run_grpc_server( - engine: AsyncLLMEngine, args: argparse.Namespace, - 
*, - disable_log_stats: bool, + engine: AsyncEngineClient | AsyncLLMEngine, ) -> None: - assert args is not None - - server = await start_grpc_server(engine, args) + server = await start_grpc_server( + args, + engine, + ) try: while True: await asyncio.sleep(10) - if not disable_log_stats: - await engine.do_log_stats() except asyncio.CancelledError: print("Gracefully stopping gRPC server") # noqa: T201 await server.stop(30) # TODO configurable grace diff --git a/src/vllm_tgis_adapter/http.py b/src/vllm_tgis_adapter/http.py new file mode 100644 index 0000000..85df9f8 --- /dev/null +++ b/src/vllm_tgis_adapter/http.py @@ -0,0 +1,69 @@ +from __future__ import annotations + +import asyncio +from typing import TYPE_CHECKING + +import uvicorn +from vllm.entrypoints.openai.api_server import ( + init_app, +) +from vllm.logger import init_logger + +if TYPE_CHECKING: + import argparse + + from fastapi import FastAPI + from vllm.engine.async_llm_engine import AsyncLLMEngine + from vllm.engine.protocol import AsyncEngineClient + +TIMEOUT_KEEP_ALIVE = 5 # seconds + +logger = init_logger(__name__) + + +async def serve_http( + app: FastAPI, + **uvicorn_kwargs, # noqa: ANN003 +) -> None: + logger.info("Available routes are:") + for route in app.routes: + methods = getattr(route, "methods", None) + path = getattr(route, "path", None) + + if methods is None or path is None: + continue + + logger.info("Route: %s, Methods: %s", path, ", ".join(methods)) + + config = uvicorn.Config(app, **uvicorn_kwargs) + server = uvicorn.Server(config) + + try: + await server.serve() + except asyncio.CancelledError: + logger.info("Gracefully stopping http server") + await server.shutdown() + + +async def run_http_server( + args: argparse.Namespace, + engine: AsyncLLMEngine | AsyncEngineClient, + **uvicorn_kwargs, # noqa: ANN003 +) -> None: + # modified copy of vllm.entrypoints.openai.api_server.run_server that + # allows passing of the engine + + app = await init_app(engine, args) # type: 
ignore[arg-type] + + await serve_http( + app, + host=args.host, + port=args.port, + log_level=args.uvicorn_log_level, + timeout_keep_alive=TIMEOUT_KEEP_ALIVE, + ssl_keyfile=args.ssl_keyfile, + ssl_certfile=args.ssl_certfile, + ssl_ca_certs=args.ssl_ca_certs, + ssl_cert_reqs=args.ssl_cert_reqs, + **uvicorn_kwargs, + ) diff --git a/tests/conftest.py b/tests/conftest.py index 96b16ac..9b7d5bd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,19 +3,15 @@ import asyncio import sys import threading -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Annotated, TypeVar import pytest import requests import vllm -from vllm.engine.arg_utils import AsyncEngineArgs -from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.entrypoints.openai.cli_args import make_arg_parser -from vllm.usage.usage_lib import UsageContext from vllm.utils import FlexibleArgumentParser -from vllm_tgis_adapter.__main__ import run_http_server -from vllm_tgis_adapter.grpc import run_grpc_server +from vllm_tgis_adapter.__main__ import start_servers from vllm_tgis_adapter.grpc.grpc_server import TextGenerationService from vllm_tgis_adapter.healthcheck import health_check from vllm_tgis_adapter.tgis_utils.args import ( @@ -24,36 +20,34 @@ postprocess_tgis_args, ) -from .utils import get_random_port, wait_until +from .utils import TaskFailedError, get_random_port, wait_until if TYPE_CHECKING: import argparse + from collections.abc import Generator - from vllm.config import ModelConfig + T = TypeVar("T") + YieldFixture = Generator[T, None, None] + ArgFixture = Annotated[T, pytest.fixture] -@pytest.fixture(scope="session") -def monkeysession(): - with pytest.MonkeyPatch.context() as mp: - yield mp - -@pytest.fixture(scope="session") +@pytest.fixture() def lora_available() -> bool: # lora does not work on cpu return not vllm.config.is_cpu() -@pytest.fixture(scope="session") -def lora_adapter_name(request: pytest.FixtureRequest): +@pytest.fixture() +def 
lora_adapter_name(request: pytest.FixtureRequest) -> str: if not request.getfixturevalue("lora_available"): pytest.skip("Lora is not available with this configuration") return "lora-test" -@pytest.fixture(scope="session") -def lora_adapter_path(request: pytest.FixtureRequest): +@pytest.fixture() +def lora_adapter_path(request: pytest.FixtureRequest) -> str: if not request.getfixturevalue("lora_available"): pytest.skip("Lora is not available with this configuration") @@ -63,13 +57,13 @@ def lora_adapter_path(request: pytest.FixtureRequest): return path -@pytest.fixture(scope="session") +@pytest.fixture() def args( request: pytest.FixtureRequest, - monkeysession, - grpc_server_thread_port, - http_server_thread_port, - lora_available, + monkeypatch, + grpc_server_port: ArgFixture[int], + http_server_port: ArgFixture[int], + lora_available: ArgFixture[bool], ) -> argparse.Namespace: """Return parsed CLI arguments for the adapter/vLLM.""" # avoid parsing pytest arguments as vllm/vllm_tgis_adapter arguments @@ -81,13 +75,13 @@ def args( extra_args.extend(("--enable-lora", f"--lora-modules={name}={path}")) - monkeysession.setattr( + monkeypatch.setattr( sys, "argv", [ "__main__.py", - f"--grpc-port={grpc_server_thread_port}", - f"--port={http_server_thread_port}", + f"--grpc-port={grpc_server_port}", + f"--port={http_server_port}", *extra_args, ], ) @@ -100,104 +94,78 @@ def args( return args -@pytest.fixture(scope="session") -def engine_args(args) -> AsyncEngineArgs: - """Return AsyncEngineArgs from cli args.""" - return AsyncEngineArgs.from_cli_args(args) - - -@pytest.fixture(scope="session") -def engine(engine_args) -> AsyncLLMEngine: - """Return a vLLM engine from the engine args.""" - engine = AsyncLLMEngine.from_engine_args( - engine_args, # type: ignore[arg-type] - usage_context=UsageContext.OPENAI_API_SERVER, - ) - return engine - - -@pytest.fixture(scope="session") -def model_config(engine) -> ModelConfig: - """Return a vLLM ModelConfig.""" - return 
asyncio.run(engine.get_model_config()) - - -@pytest.fixture(scope="session") -def grpc_server_thread_port() -> int: +@pytest.fixture() +def grpc_server_port() -> int: """Port for the grpc server.""" return get_random_port() -@pytest.fixture(scope="session") -def grpc_server_url(grpc_server_thread_port) -> str: - """Url for the grpc server.""" - return f"localhost:{grpc_server_thread_port}" - - -@pytest.fixture(scope="session") -def _grpc_server(engine, args, grpc_server_url) -> None: - """Spins up the grpc server in a background thread.""" - - def _health_check(): - assert health_check( - server_url=grpc_server_url, - insecure=True, - timeout=1, - service=TextGenerationService.SERVICE_NAME, - ) +@pytest.fixture() +def grpc_server_address(grpc_server_port: ArgFixture[int]) -> str: + """Address for the grpc server.""" + return f"localhost:{grpc_server_port}" - loop = asyncio.new_event_loop() - task: asyncio.Task | None = None - def target(): - nonlocal task +@pytest.fixture() +def http_server_port() -> int: + """Port for the http server.""" + return get_random_port() - task = loop.create_task(run_grpc_server(engine, args, disable_log_stats=False)) - loop.run_until_complete(task) - t = threading.Thread(target=target) - t.start() +@pytest.fixture() +def http_server_url(http_server_port: ArgFixture[int]) -> str: + """Url for the http server.""" + return f"http://localhost:{http_server_port}" - try: - wait_until(_health_check) - yield - finally: - task.cancel() - t.join() +@pytest.fixture() +def _servers( + args: ArgFixture[argparse.Namespace], + grpc_server_address: ArgFixture[str], + http_server_url: ArgFixture[str], + monkeypatch, +) -> YieldFixture[None]: + """Run the servers in an asyncio loop in a background thread.""" + global server # noqa: PLW0602 -@pytest.fixture(scope="session") -def http_server_thread_port(scope="session") -> int: - """Port for the http server.""" - return get_random_port() + loop = asyncio.new_event_loop() + task: asyncio.Task | None = None 
-@pytest.fixture(scope="session") -def http_server_url(http_server_thread_port) -> str: - """Url for the http server.""" - return f"http://localhost:{http_server_thread_port}" + def _health_check() -> None: + if not task: + raise TaskFailedError + if task.done(): + exc = task.exception() + if exc: + raise TaskFailedError from exc -@pytest.fixture(scope="session") -def _http_server(engine, model_config, engine_args, args, http_server_url) -> None: - """Spins up the http server in a background thread.""" + raise TaskFailedError - def _health_check() -> None: requests.get( f"{http_server_url}/health", timeout=1, ).raise_for_status() - global server # noqa: PLW0602 + assert health_check( + server_url=grpc_server_address, + insecure=True, + timeout=1, + service=TextGenerationService.SERVICE_NAME, + ) - loop = asyncio.new_event_loop() + # patch the add_signal_handler method so that instantiating the servers + # does not try to modify signal handlers in a child thread, which cannot be done + def dummy_signal_handler(*args, **kwargs): + pass - task: asyncio.Task | None = None + monkeypatch.setattr(loop, "add_signal_handler", dummy_signal_handler) def target(): nonlocal task - task = loop.create_task(run_http_server(engine, args, model_config)) + task = loop.create_task(start_servers(args)) loop.run_until_complete(task) t = threading.Thread(target=target) @@ -207,5 +175,7 @@ def target(): wait_until(_health_check) yield finally: - task.cancel() + if task: + task.cancel() + t.join() diff --git a/tests/test_grpc_server.py b/tests/test_grpc_server.py index 7a6d401..6d4e9fe 100644 --- a/tests/test_grpc_server.py +++ b/tests/test_grpc_server.py @@ -4,17 +4,18 @@ @pytest.fixture() -def grpc_client(grpc_server_thread_port, _grpc_server): +def grpc_client(grpc_server_address, _servers): """Return a grpc client connected to the grpc server.""" + host, port = grpc_server_address.split(":") with GrpcClient( - host="localhost", - port=grpc_server_thread_port, + host=host, + port=port, 
insecure=True, ) as client: yield client -def test_generation_request(grpc_client, grpc_server_thread_port): +def test_generation_request(grpc_client): response = grpc_client.make_request( "The answer to life the universe and everything is " ) @@ -24,7 +25,7 @@ def test_generation_request(grpc_client, grpc_server_thread_port): assert response.stop_reason is not None -def test_generation_request_stream(grpc_client, grpc_server_thread_port): +def test_generation_request_stream(grpc_client): streaming_response = grpc_client.make_request_stream( "The answer to life the universe and everything is ", ) @@ -36,7 +37,7 @@ def test_generation_request_stream(grpc_client, grpc_server_thread_port): assert "".join(text_chunks) -def test_batched_generation_request(grpc_client, grpc_server_thread_port): +def test_batched_generation_request(grpc_client): responses = list( grpc_client.make_request( [ diff --git a/tests/test_http_server.py b/tests/test_http_server.py index a664aef..a8a7114 100644 --- a/tests/test_http_server.py +++ b/tests/test_http_server.py @@ -1,12 +1,12 @@ import requests -def test_startup(http_server_url, _http_server): +def test_startup(http_server_url, _servers): """Test that the http_server fixture starts up properly.""" requests.get(f"{http_server_url}/health").raise_for_status() -def test_completions(http_server_url, _http_server): +def test_completions(http_server_url, _servers): response = requests.get(f"{http_server_url}/v1/models") response.raise_for_status() @@ -29,6 +29,6 @@ def test_completions(http_server_url, _http_server): assert generated_text -def test_metrics(http_server_url, _http_server): +def test_metrics(http_server_url, _servers): response = requests.get(f"{http_server_url}/metrics") response.raise_for_status() diff --git a/tests/utils.py b/tests/utils.py index 1ab56bf..d353e38 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -30,6 +30,10 @@ _T = TypeVar("_T") +class TaskFailedError(Exception): + pass + + def get_random_port(): with 
closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: s.bind(("", 0)) @@ -40,8 +44,8 @@ def get_random_port(): def wait_until( pred: Callable[..., _T], - timeout: float = 30, - pause: float = 0.5, + timeout: float = 60, + pause: float = 5, ) -> _T: start = time.perf_counter() exc = None @@ -49,10 +53,14 @@ def wait_until( while (time.perf_counter() - start) < timeout: try: value = pred() + except TaskFailedError: + raise except Exception as e: # noqa: BLE001 + print(f"Got {e=}") exc = e else: return value + time.sleep(pause) raise TimeoutError("timed out waiting") from exc @@ -72,7 +80,7 @@ def get_server_certificate(host: str, port: int) -> str: # https://github.com/python/cpython/pull/16820 return ssl.get_server_certificate((host, port)) - context = ssl.SSLContext() + context = ssl.SSLContext() # type: ignore[unreachable] # false positive for python>=3.10 with ( socket.create_connection((host, port)) as sock, @@ -119,6 +127,8 @@ def make_request( max_new_tokens: int = 10, adapter_id: str | None = None, ) -> GenerationResponse | Sequence[GenerationResponse]: + # assert model_id # FIXME: is model_id required? + if single_request := isinstance(text, str): text = [text] @@ -146,6 +156,8 @@ def make_request_stream( model_id: str | None = None, max_new_tokens: int = 10, ) -> Generator[GenerationResponse, None, None]: + # assert model_id # FIXME: is model_id required? 
+ request = SingleGenerationRequest( model_id=model_id, request=GenerationRequest(text=text), From d16f7ccc0fbfb90172e0ba5219b8141484b85d11 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniele=20Trifir=C3=B2?= Date: Tue, 6 Aug 2024 16:20:30 +0200 Subject: [PATCH 2/9] pyproject: disable -W Error --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index f10a6f0..dd72dd7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -82,7 +82,7 @@ vllm_tgis_adapter = [ ] [tool.pytest.ini_options] -addopts = "-ra -W error::grpc.experimental.ExperimentalApiWarning" +addopts = "-ra" [tool.coverage.run] branch = true From 10f6db95a0564e79d121855000b92efa9b9a50fa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniele=20Trifir=C3=B2?= Date: Tue, 6 Aug 2024 16:29:35 +0200 Subject: [PATCH 3/9] pyproject: bump minimum vLLM version to 0.5.4 --- .github/workflows/tests.yaml | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index ac55031..9a14a6c 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -32,7 +32,7 @@ jobs: pyv: ["3.11"] vllm_version: # - "" # skip the pypi version as it will not work on CPU - - "git+https://github.com/vllm-project/vllm@v0.5.3.post1" + - "git+https://github.com/vllm-project/vllm@v0.5.4" - "git+https://github.com/vllm-project/vllm@main" - "git+https://github.com/opendatahub-io/vllm@main" diff --git a/pyproject.toml b/pyproject.toml index dd72dd7..a0f2479 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ classifiers = [ requires-python = ">=3.9" dynamic = ["version"] dependencies = [ - "vllm>=0.5.3.post1", + "vllm>=0.5.4", "prometheus_client==0.20.0", "grpcio==1.62.2", "grpcio-health-checking==1.62.2", From daf601e7a257088bf3c94a5c2c1dcd8cf6334901 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniele=20Trifir=C3=B2?= Date: Tue, 6 Aug 2024 18:27:00 +0200 Subject: [PATCH 
4/9] cleanup prompt adapters imports --- src/vllm_tgis_adapter/grpc/adapters.py | 2 +- src/vllm_tgis_adapter/grpc/grpc_server.py | 22 ++++------------------ 2 files changed, 5 insertions(+), 19 deletions(-) diff --git a/src/vllm_tgis_adapter/grpc/adapters.py b/src/vllm_tgis_adapter/grpc/adapters.py index f5a9144..655fb41 100644 --- a/src/vllm_tgis_adapter/grpc/adapters.py +++ b/src/vllm_tgis_adapter/grpc/adapters.py @@ -156,6 +156,6 @@ def _reject_bad_adapter_id(adapter_id: str) -> None: if not VALID_ADAPTER_ID_PATTERN.fullmatch(adapter_id): TGISValidationError.InvalidAdapterID.error(adapter_id) - cwd = Path().resolve() + cwd = Path().cwd() if not Path(adapter_id).resolve().is_relative_to(cwd): TGISValidationError.InvalidAdapterID.error(adapter_id) diff --git a/src/vllm_tgis_adapter/grpc/grpc_server.py b/src/vllm_tgis_adapter/grpc/grpc_server.py index bd1e9c8..4ae88da 100644 --- a/src/vllm_tgis_adapter/grpc/grpc_server.py +++ b/src/vllm_tgis_adapter/grpc/grpc_server.py @@ -42,6 +42,7 @@ TGISStatLogger, ) +from .adapters import AdapterStore, validate_adapters from .pb import generation_pb2_grpc from .pb.generation_pb2 import DESCRIPTOR as _GENERATION_DESCRIPTOR from .pb.generation_pb2 import ( @@ -56,14 +57,6 @@ ) from .validation import validate_input, validate_params -try: - from .adapters import AdapterStore, validate_adapters -except ImportError: - adapters_available = False -else: - adapters_available = True - - if TYPE_CHECKING: import argparse from collections.abc import AsyncIterator, MutableSequence @@ -76,6 +69,7 @@ from vllm.lora.request import LoRARequest from vllm.sequence import Logprob + from .adapters import PromptAdapterRequest from .pb.generation_pb2 import ( BatchedGenerationRequest, BatchedTokenizeRequest, @@ -224,11 +218,7 @@ async def Generate( start_time = time.time() service_metrics.count_generate_request(len(request.requests)) request_id = self.request_id(context) - adapter_kwargs = ( - await self._validate_adapters(request, context) - if 
adapters_available - else {} - ) + adapter_kwargs = await self._validate_adapters(request, context) tokenizer = await self._get_tokenizer(adapter_kwargs) sampling_params, deadline = await self._validate_and_convert_params( @@ -326,11 +316,7 @@ async def GenerateStream( start_time = time.time() service_metrics.count_generate_request() request_id = self.request_id(context) - adapter_kwargs = ( - await self._validate_adapters(request, context) - if adapters_available - else {} - ) + adapter_kwargs = await self._validate_adapters(request, context) tokenizer = await self._get_tokenizer(adapter_kwargs) sampling_params, deadline = await self._validate_and_convert_params( From 5821b05a8353f81a1773e6e982ac5e3c5350a7c3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniele=20Trifir=C3=B2?= Date: Tue, 6 Aug 2024 19:42:04 +0200 Subject: [PATCH 5/9] speed up grpc_healthcheck startup --- src/vllm_tgis_adapter/healthcheck.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/vllm_tgis_adapter/healthcheck.py b/src/vllm_tgis_adapter/healthcheck.py index 1dd5180..2f804e7 100644 --- a/src/vllm_tgis_adapter/healthcheck.py +++ b/src/vllm_tgis_adapter/healthcheck.py @@ -9,8 +9,6 @@ from grpc_health.v1.health_pb2 import HealthCheckRequest from grpc_health.v1.health_pb2_grpc import Health -from vllm_tgis_adapter.grpc.grpc_server import TextGenerationService - warnings.simplefilter( action="ignore", category=grpc.experimental.ExperimentalApiWarning ) @@ -85,7 +83,10 @@ def parse_args() -> argparse.Namespace: type=str, help="Name of the service to check", required=False, - default=TextGenerationService.SERVICE_NAME, + # the value below should match: + # vllm_tgis_adapter.grpc.grpc_server.TextGenerationService.SERVICE_NAME + # which we do not import here to avoid import overhead + default="fmaas.GenerationService", ) return parser.parse_args() From 53ab050c996f0cd5fdd79bb33ce0176e1d16acfc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniele=20Trifir=C3=B2?= Date: Tue, 6 Aug 
2024 20:27:19 +0200 Subject: [PATCH 6/9] __main__: improve error handling in start_servers() --- src/vllm_tgis_adapter/__main__.py | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/src/vllm_tgis_adapter/__main__.py b/src/vllm_tgis_adapter/__main__.py index 87f00c8..c7164fc 100644 --- a/src/vllm_tgis_adapter/__main__.py +++ b/src/vllm_tgis_adapter/__main__.py @@ -2,6 +2,7 @@ import asyncio import signal +from concurrent.futures import FIRST_EXCEPTION from typing import TYPE_CHECKING import uvloop @@ -53,11 +54,23 @@ async def override_signal_handler() -> None: tasks.append(loop.create_task(override_signal_handler())) - try: - await asyncio.wait(tasks) - except asyncio.CancelledError: - for task in tasks: - task.cancel() + done, pending = await asyncio.wait( + tasks, + return_when=FIRST_EXCEPTION, + ) + for task in pending: + task.cancel() + + while done: + task = done.pop() + exc = task.exception() + if not exc: + continue + + name = task.get_name() + coro_name = task.get_coro().__name__ + + raise RuntimeError(f"task={name} ({coro_name})") from exc if __name__ == "__main__": From 958a88debe8487e9c721e515c269a1b9b047da2c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniele=20Trifir=C3=B2?= Date: Tue, 6 Aug 2024 20:51:55 +0200 Subject: [PATCH 7/9] add check_for_failed_tasks --- src/vllm_tgis_adapter/__main__.py | 25 ++++++++++--------------- src/vllm_tgis_adapter/utils.py | 20 ++++++++++++++++++++ 2 files changed, 30 insertions(+), 15 deletions(-) create mode 100644 src/vllm_tgis_adapter/utils.py diff --git a/src/vllm_tgis_adapter/__main__.py b/src/vllm_tgis_adapter/__main__.py index c7164fc..695a1ba 100644 --- a/src/vllm_tgis_adapter/__main__.py +++ b/src/vllm_tgis_adapter/__main__.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import contextlib import signal from concurrent.futures import FIRST_EXCEPTION from typing import TYPE_CHECKING @@ -17,6 +18,7 @@ from .http import run_http_server from .logging 
import init_logger from .tgis_utils.args import EnvVarArgumentParser, add_tgis_args, postprocess_tgis_args +from .utils import check_for_failed_tasks if TYPE_CHECKING: import argparse @@ -54,23 +56,16 @@ async def override_signal_handler() -> None: tasks.append(loop.create_task(override_signal_handler())) - done, pending = await asyncio.wait( - tasks, - return_when=FIRST_EXCEPTION, - ) - for task in pending: - task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await asyncio.wait( + tasks, + return_when=FIRST_EXCEPTION, + ) - while done: - task = done.pop() - exc = task.exception() - if not exc: - continue - - name = task.get_name() - coro_name = task.get_coro().__name__ + for task in tasks: + task.cancel() - raise RuntimeError(f"task={name} ({coro_name})") from exc + check_for_failed_tasks(tasks) if __name__ == "__main__": diff --git a/src/vllm_tgis_adapter/utils.py b/src/vllm_tgis_adapter/utils.py new file mode 100644 index 0000000..3abc64d --- /dev/null +++ b/src/vllm_tgis_adapter/utils.py @@ -0,0 +1,20 @@ +import asyncio +from collections.abc import Iterable + + +def check_for_failed_tasks(tasks: Iterable[asyncio.Task]) -> None: + """Check a sequence of tasks exceptions and raise the exception.""" + for task in tasks: + try: + exc = task.exception() + except asyncio.InvalidStateError: + # no exception is set + continue + + if not exc: + continue + + name = task.get_name() + coro_name = task.get_coro().__name__ + + raise RuntimeError(f"task={name} ({coro_name})") from exc From b63b6c570fe2d9736d2a6d449b738db7f96ba5d7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniele=20Trifir=C3=B2?= Date: Tue, 6 Aug 2024 21:38:26 +0200 Subject: [PATCH 8/9] tests: parametrize over --disable-frontend-multiprocessing --- tests/conftest.py | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index 9b7d5bd..95210eb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -57,13 +57,25 @@ 
def lora_adapter_path(request: pytest.FixtureRequest) -> str: return path +@pytest.fixture( + params=[ + pytest.param(True, id="disable-frontend-multiprocessing=True"), + pytest.param(False, id="disable-frontend-multiprocessing=False"), + ] +) +def disable_frontend_multiprocessing(request): + """Enable or disable the frontend-multiprocessing feature.""" + return request.param + + @pytest.fixture() -def args( +def args( # noqa: PLR0913 request: pytest.FixtureRequest, monkeypatch, grpc_server_port: ArgFixture[int], http_server_port: ArgFixture[int], lora_available: ArgFixture[bool], + disable_frontend_multiprocessing, ) -> argparse.Namespace: """Return parsed CLI arguments for the adapter/vLLM.""" # avoid parsing pytest arguments as vllm/vllm_tgis_adapter arguments @@ -75,6 +87,9 @@ def args( extra_args.extend(("--enable-lora", f"--lora-modules={name}={path}")) + if disable_frontend_multiprocessing: + extra_args.append("--disable-frontend-multiprocessing") + monkeypatch.setattr( sys, "argv", @@ -179,3 +194,13 @@ def target(): task.cancel() t.join() + + # workaround: Instantiating the TGISStatLogger multiple times creates + # multiple Gauges etc which can only be instantiated once. + # By unregistering the Collectors from the REGISTRY we can + # work around this problem. 
+ + from prometheus_client.registry import REGISTRY + + for name in list(REGISTRY._collector_to_names.keys()): # noqa: SLF001 + REGISTRY.unregister(name) From 7cb62343c09c629515a593600c009781821f3ecd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniele=20Trifir=C3=B2?= Date: Tue, 6 Aug 2024 22:58:34 +0200 Subject: [PATCH 9/9] tests: temporarily disable frontend multiprocessing parametrization --- tests/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index 95210eb..7e4d18e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -59,7 +59,7 @@ def lora_adapter_path(request: pytest.FixtureRequest) -> str: @pytest.fixture( params=[ - pytest.param(True, id="disable-frontend-multiprocessing=True"), + # pytest.param(True, id="disable-frontend-multiprocessing=True"), pytest.param(False, id="disable-frontend-multiprocessing=False"), ] )