diff --git a/pyproject.toml b/pyproject.toml index 8553f35..b07abcd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ authors = [ {name = "William Chen", email = "william_w_chen@trendmicro.com"}, ] dependencies = [ - "py-spring-core>=0.0.16", + "py-spring-core>=0.0.23", "opentelemetry-instrumentation-fastapi>=0.55b1", "opentelemetry-sdk>=1.34.1", "opentelemetry-exporter-otlp-proto-http>=1.34.1" diff --git a/pyspring_opentelemetry_exporter/__init__.py b/pyspring_opentelemetry_exporter/__init__.py index 15b3221..bf3c62b 100644 --- a/pyspring_opentelemetry_exporter/__init__.py +++ b/pyspring_opentelemetry_exporter/__init__.py @@ -1,10 +1,14 @@ -from ._exporter import provider_opentelemetry_exporter -from ._request_hook_handler import RequestHookHandler, provide_default_request_hook_handler +from pyspring_opentelemetry_exporter._response_trace_middleware import ResponseTraceMiddleware +from pyspring_opentelemetry_exporter._exporter import provider_opentelemetry_exporter +from pyspring_opentelemetry_exporter._request_hook_handler import RequestHookHandler, provide_default_request_hook_handler +from opentelemetry import trace __all__ = [ + "provide_default_request_hook_handler", "provider_opentelemetry_exporter", "RequestHookHandler", - "provide_default_request_hook_handler" + "ResponseTraceMiddleware", + "trace" ] __version__ = "0.0.1" \ No newline at end of file diff --git a/pyspring_opentelemetry_exporter/_exporter.py b/pyspring_opentelemetry_exporter/_exporter.py index 5d7e78f..9127c4c 100644 --- a/pyspring_opentelemetry_exporter/_exporter.py +++ b/pyspring_opentelemetry_exporter/_exporter.py @@ -1,7 +1,7 @@ - +import py_spring_core.core.utils as framework_utils from py_spring_core import ApplicationContextRequired, EntityProvider, Properties -from typing import ClassVar, Optional +from typing import ClassVar, Optional, Type from fastapi import FastAPI from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor from opentelemetry import trace @@ -26,8 +26,12 @@ def _get_tracer_properties(self) -> TracerExporterProperties: props = app_context.get_properties(TracerExporterProperties) assert props is not None return props - + def _check_unimplemented_methods_for_handler(self, handler: Type[RequestHookHandler]) -> None: + unimplemented_methods = framework_utils.get_unimplemented_abstract_methods(handler) + if unimplemented_methods: + raise NotImplementedError(f"Handler {handler.__name__} must implement the following methods: {unimplemented_methods}") + def provider_init(self) -> None: app_context = self.get_application_context() tracer_properties = self._get_tracer_properties() @@ -43,6 +47,8 @@ def provider_init(self) -> None: if self.__class__._handler is None: self.__class__._handler = provide_default_request_hook_handler() + self._check_unimplemented_methods_for_handler(self.__class__._handler.__class__) + trace.set_tracer_provider(provider) app_context.server.add_middleware(ResponseTraceMiddleware) self.inject_instrumentation(app_context.server, self.__class__._handler) diff --git a/pyspring_opentelemetry_exporter/_request_hook_handler.py b/pyspring_opentelemetry_exporter/_request_hook_handler.py index f9cb412..2a1b484 100644 --- a/pyspring_opentelemetry_exporter/_request_hook_handler.py +++ b/pyspring_opentelemetry_exporter/_request_hook_handler.py @@ -2,6 +2,7 @@ from typing import Any, Optional from fastapi import Request, Response from opentelemetry.trace.span import Span +import py_spring_core.core.utils as framework_utils class RequestHookHandler(ABC): @abstractmethod @@ -10,6 +11,9 @@ def server_request_hook(self, span: Span, scope: dict[str, Any]) -> None: ... @abstractmethod def client_request_hook(self, span: Span, scope: dict[str, Any], request: dict[str, Any]) -> None: ... + @abstractmethod + def server_response_hook(self, span: Span, scope: dict[str, Any], response: dict[str, Any]) -> None: ... + @abstractmethod def client_response_hook(self, span: Span, scope: dict[str, Any], response: dict[str, Any]) -> None: ... @@ -51,4 +55,7 @@ def client_response_hook(self, span: Span, scope: dict[str, Any], response: dict span.set_attribute("error.preview", error) def provide_default_request_hook_handler() -> RequestHookHandler: + unimplemented_methods = framework_utils.get_unimplemented_abstract_methods(DefaultRequestHookHandler) + if unimplemented_methods: + raise NotImplementedError(f"DefaultRequestHookHandler must implement the following methods: {unimplemented_methods}") return DefaultRequestHookHandler() \ No newline at end of file diff --git a/pyspring_opentelemetry_exporter/_response_trace_middleware.py b/pyspring_opentelemetry_exporter/_response_trace_middleware.py index 7b584f1..aceab10 100644 --- a/pyspring_opentelemetry_exporter/_response_trace_middleware.py +++ b/pyspring_opentelemetry_exporter/_response_trace_middleware.py @@ -1,13 +1,18 @@ +from py_spring_core import Middleware + import traceback from typing import Awaitable, Callable from fastapi import Request -from starlette.middleware.base import BaseHTTPMiddleware from starlette.responses import Response from starlette.middleware.base import _StreamingResponse from opentelemetry.trace import get_current_span -class ResponseTraceMiddleware(BaseHTTPMiddleware): +class ResponseTraceMiddleware(Middleware): + + async def process_request(self, request: Request) -> Response | None: + return None + async def dispatch(self, request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response: try: response = await call_next(request)