-
Notifications
You must be signed in to change notification settings - Fork 7.4k
fix(serve): Fix Ray Serve LLM embeddings endpoint for pooling models #61959
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,7 @@ | ||
| import argparse | ||
| import dataclasses | ||
| import inspect | ||
| import json | ||
| import typing | ||
| from typing import TYPE_CHECKING, Any, AsyncGenerator, List, Optional, Tuple, Union | ||
|
|
||
|
|
@@ -329,7 +330,7 @@ async def start(self) -> None: | |
| self._oai_models = getattr(state, "openai_serving_models", None) | ||
| self._oai_serving_chat = getattr(state, "openai_serving_chat", None) | ||
| self._oai_serving_completion = getattr(state, "openai_serving_completion", None) | ||
| self._oai_serving_embedding = getattr(state, "openai_serving_embedding", None) | ||
| self._oai_serving_embedding = getattr(state, "serving_embedding", None) | ||
| self._oai_serving_transcription = getattr( | ||
| state, "openai_serving_transcription", None | ||
| ) | ||
|
|
@@ -579,7 +580,8 @@ async def embeddings( | |
| raw_request: Optional[Request] = RawRequestInfo.to_starlette_request_optional( | ||
| raw_request_info | ||
| ) | ||
| embedding_response = await self._oai_serving_embedding.create_embedding( # type: ignore[attr-defined] | ||
| # vLLM's ServingEmbedding is a callable, not a class with create_embedding | ||
| embedding_response = await self._oai_serving_embedding( | ||
| request, | ||
| raw_request=raw_request, | ||
| ) | ||
|
|
@@ -588,8 +590,14 @@ async def embeddings( | |
| yield ErrorResponse( | ||
| error=ErrorInfo(**embedding_response.error.model_dump()) | ||
| ) | ||
| return | ||
|
|
||
| # vLLM returns a Starlette Response object, extract the JSON content | ||
| if hasattr(embedding_response, 'body'): | ||
| content = json.loads(embedding_response.body) | ||
| yield EmbeddingResponse(**content) | ||
| else: | ||
| yield EmbeddingResponse(**embedding_response.model_dump()) | ||
| yield embedding_response | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Error responses mishandled as embedding success responses — Medium Severity. vLLM's |
||
|
|
||
| async def transcriptions( | ||
| self, | ||
|
|
||


There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Using `hasattr` for duck-typing can be brittle. A more robust way to check if `embedding_response` is a Starlette-like response object is to also check the type of the `body` attribute. Starlette `Response` objects have a `body` attribute of type `bytes`. This avoids potential issues if another type of object with a `body` attribute of a different type is returned in the future.