diff --git a/py_spring_core/__init__.py b/py_spring_core/__init__.py index 2340789..5e01580 100644 --- a/py_spring_core/__init__.py +++ b/py_spring_core/__init__.py @@ -10,13 +10,15 @@ PutMapping, ) from py_spring_core.core.entities.entity_provider import EntityProvider +from py_spring_core.core.entities.middlewares.middleware import Middleware +from py_spring_core.core.entities.middlewares.middleware_registry import MiddlewareRegistry from py_spring_core.core.entities.properties.properties import Properties from py_spring_core.core.interfaces.application_context_required import ApplicationContextRequired from py_spring_core.event.application_event_publisher import ApplicationEventPublisher from py_spring_core.event.commons import ApplicationEvent from py_spring_core.event.application_event_handler_registry import EventListener -__version__ = "0.0.18" +__version__ = "0.0.19" __all__ = [ "PySpringApplication", @@ -35,4 +37,6 @@ "ApplicationEventPublisher", "ApplicationEvent", "EventListener", + "Middleware", + "MiddlewareRegistry", ] \ No newline at end of file diff --git a/py_spring_core/core/application/context/application_context.py b/py_spring_core/core/application/context/application_context.py index 729c528..d1e0e50 100644 --- a/py_spring_core/core/application/context/application_context.py +++ b/py_spring_core/core/application/context/application_context.py @@ -4,6 +4,7 @@ from inspect import isclass from typing import ( Annotated, + Any, Callable, Mapping, Optional, @@ -25,6 +26,7 @@ from py_spring_core.core.entities.bean_collection import ( BeanCollection, BeanConflictError, + BeanView, InvalidBeanError, ) from py_spring_core.core.entities.component import Component, ComponentScope @@ -362,7 +364,7 @@ def _inject_bean_collection_dependencies(self, bean_collection_cls: Type[BeanCol ) self.dependency_injector.inject_dependencies(bean_collection_cls) - def _validate_bean_view(self, view, collection_name: str) -> None: + def _validate_bean_view(self, view: BeanView, collection_name: str) -> None: """Validate a bean view before adding it to the container.""" if view.bean_name in self.container_manager.singleton_bean_instance_container: raise BeanConflictError( @@ -564,6 +566,10 @@ def init_ioc_container(self) -> None: # Initialize singleton beans self.bean_manager.init_singleton_beans() + def inject_dependencies_for_external_object(self, object: Type[Any]) -> None: + """Inject dependencies for an external object.""" + self.dependency_injector.inject_dependencies(object) + def inject_dependencies_for_app_entities(self) -> None: """Inject dependencies for all registered app entities.""" containers: list[Mapping[str, Type[AppEntities]]] = [ diff --git a/py_spring_core/core/application/py_spring_application.py b/py_spring_core/core/application/py_spring_application.py index 3ceda7c..391bb57 100644 --- a/py_spring_core/core/application/py_spring_application.py +++ b/py_spring_core/core/application/py_spring_application.py @@ -28,6 +28,7 @@ from py_spring_core.core.entities.controllers.rest_controller import RestController from py_spring_core.core.entities.controllers.route_mapping import RouteMapping from py_spring_core.core.entities.entity_provider import EntityProvider +from py_spring_core.core.entities.middlewares.middleware_registry import MiddlewareRegistry from py_spring_core.core.entities.properties.properties import Properties from py_spring_core.core.interfaces.application_context_required import ApplicationContextRequired from py_spring_core.event.application_event_handler_registry import ApplicationEventHandlerRegistry @@ -212,8 +213,25 @@ def __init_controllers(self) -> None: controller._register_decorated_routes(routes) router = controller.get_router() self.fastapi.include_router(router) - controller.register_middlewares() - + self.__init_middlewares() + logger.debug(f"[CONTROLLER INIT] Controller {name} initialized") + def __init_middlewares(self) -> None: + logger.debug("[MIDDLEWARE INIT] Initialize middlewares...") + self_defined_registry_cls = MiddlewareRegistry.get_subclass() + if self_defined_registry_cls is None: + logger.debug("[MIDDLEWARE INIT] No self defined registry class found") + return + logger.debug(f"[MIDDLEWARE INIT] Self defined registry class: {self_defined_registry_cls.__name__}") + logger.debug(f"[MIDDLEWARE INIT] Inject dependencies for external object: {self_defined_registry_cls.__name__}") + self.app_context.inject_dependencies_for_external_object(self_defined_registry_cls) + registry = self_defined_registry_cls() + + middleware_classes = registry.get_middleware_classes() + for middleware_class in middleware_classes: + logger.debug(f"[MIDDLEWARE INIT] Inject dependencies for middleware: {middleware_class.__name__}") + self.app_context.inject_dependencies_for_external_object(middleware_class) + registry.apply_middlewares(self.fastapi) + logger.debug("[MIDDLEWARE INIT] Middlewares initialized") def __configure_uvicorn_logging(self): """Configure Uvicorn to use Loguru instead of default logging.""" # Configure Uvicorn to use Loguru @@ -239,9 +257,6 @@ def emit(self, record): logging.basicConfig(handlers=[InterceptHandler()], level=log_level, force=True) def __run_server(self) -> None: - - - # Run uvicorn server uvicorn.run( self.fastapi, diff --git a/py_spring_core/core/entities/controllers/rest_controller.py b/py_spring_core/core/entities/controllers/rest_controller.py index 52ce6a5..93e8933 100644 --- a/py_spring_core/core/entities/controllers/rest_controller.py +++ b/py_spring_core/core/entities/controllers/rest_controller.py @@ -3,6 +3,7 @@ from functools import partial from py_spring_core.core.entities.controllers.route_mapping import RouteRegistration +from py_spring_core.core.entities.middlewares.middleware import Middleware class RestController: @@ -15,7 +16,7 @@ class RestController: - Providing access to the FastAPI `APIRouter` and `FastAPI` app instances - Exposing the controller's configuration, including the URL prefix - Subclasses of `RestController` should override the `register_routes` and `register_middlewares` methods to add their own routes and middleware to the controller. + Subclasses of `RestController` should override the `register_routes` methods to add their own routes and middleware to the controller. """ app: FastAPI @@ -53,8 +54,6 @@ def _register_decorated_routes(self, routes: Iterable[RouteRegistration]) -> Non name=route.name, ) - def register_middlewares(self) -> None: ... - def get_router(self) -> APIRouter: return self.router diff --git a/py_spring_core/core/entities/middlewares/middleware.py b/py_spring_core/core/entities/middlewares/middleware.py new file mode 100644 index 0000000..7eab8ac --- /dev/null +++ b/py_spring_core/core/entities/middlewares/middleware.py @@ -0,0 +1,40 @@ +from abc import abstractmethod +from typing import Awaitable, Callable +from fastapi import Request, Response +from starlette.middleware.base import BaseHTTPMiddleware + + + +class Middleware(BaseHTTPMiddleware): + """ + Middleware base class, inherits from FastAPI's BaseHTTPMiddleware + Simpler to use, only need to implement the process_request method + """ + + @abstractmethod + async def process_request(self, request: Request) -> Response | None: + """ + Method to process requests + + Args: + request: FastAPI request object + + Returns: + Response | None: If Response is returned, it will be directly returned to the client + If None is returned, continue to execute the next middleware or route handler + """ + pass + + async def dispatch(self, request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response: + """ + Middleware dispatch method, automatically called by FastAPI + """ + # First execute custom request processing logic + response = await self.process_request(request) + + # If a response is returned, return it directly + if response is not None: + return response + + # Otherwise continue to execute the next middleware or route handler + return await call_next(request) \ No newline at end of file diff --git a/py_spring_core/core/entities/middlewares/middleware_registry.py b/py_spring_core/core/entities/middlewares/middleware_registry.py new file mode 100644 index 0000000..ecff80c --- /dev/null +++ b/py_spring_core/core/entities/middlewares/middleware_registry.py @@ -0,0 +1,57 @@ + + +from abc import ABC, abstractmethod +from typing import Type +from fastapi import FastAPI +from py_spring_core.core.entities.middlewares.middleware import Middleware +from py_spring_core.core.interfaces.single_inheritance_required import SingleInheritanceRequired + + + + +class MiddlewareRegistry(SingleInheritanceRequired["MiddlewareRegistry"], ABC): + """ + Middleware registry for managing all middlewares + + This registry pattern eliminates the need for manual middleware registration. + The framework automatically handles middleware registration and execution order. + + Multiple middleware execution order: + When multiple middlewares are registered through this registry, they are automatically + applied to the FastAPI application in the order they are returned by get_middleware_classes(). + Each middleware wraps the application, forming a stack. The last middleware added is the outermost, + and the first is the innermost. + + On the request path, the outermost middleware runs first. + On the response path, it runs last. + + For example, if get_middleware_classes() returns [MiddlewareA, MiddlewareB] + This results in the following execution order: + Request: MiddlewareB → MiddlewareA → route + Response: route → MiddlewareA → MiddlewareB + This stacking behavior ensures that middlewares are executed in a predictable and controllable order. + """ + + @abstractmethod + def get_middleware_classes(self) -> list[Type[Middleware]]: + """ + Get all registered middleware classes + + Returns: + List[Type[Middleware]]: List of middleware classes + """ + pass + + def apply_middlewares(self, app: FastAPI) -> FastAPI: + """ + Apply middlewares to FastAPI application + + Args: + app: FastAPI application instance + + Returns: + FastAPI: FastAPI instance with applied middlewares + """ + for middleware_class in self.get_middleware_classes(): + app.add_middleware(middleware_class) + return app \ No newline at end of file diff --git a/py_spring_core/core/interfaces/single_inheritance_required.py b/py_spring_core/core/interfaces/single_inheritance_required.py new file mode 100644 index 0000000..03150be --- /dev/null +++ b/py_spring_core/core/interfaces/single_inheritance_required.py @@ -0,0 +1,34 @@ + + +from abc import ABC +from typing import Generic, Optional, Type, TypeVar, cast + +T = TypeVar('T') + +class SingleInheritanceRequired(Generic[T], ABC): + """ + A singleton component is a component that only allow subclasses to be inherited. + """ + + @classmethod + def check_only_one_subclass_allowed(cls) -> None: + """ + Check if the subclass is allowed to be inherited. + """ + class_dict: dict[str, Type[SingleInheritanceRequired[T]]] = {} + for subclass in cls.__subclasses__(): + if subclass.__name__ in class_dict: + continue + class_dict[subclass.__name__] = subclass + if len(class_dict) > 1: + raise ValueError(f"Only one subclass is allowed for {cls.__name__}, but {len(class_dict)} subclasses: {[subclass.__name__ for subclass in class_dict.values()]} found") + + @classmethod + def get_subclass(cls) -> Optional[Type[T]]: + """ + Get the subclass of the component. + """ + cls.check_only_one_subclass_allowed() + if len(cls.__subclasses__()) == 0: + return + return cast(Type[T], cls.__subclasses__()[0]) \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 6506f02..f896924 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -73,6 +73,7 @@ dev = [ "isort>=5.13.2", "pytest>=8.3.2", "pytest-mock>=3.14.0", + "pytest-asyncio>=1.1.0", "types-PyYAML>=6.0.12.20240917", "types-cachetools>=5.5.0.20240820", "mypy>=1.11.2" diff --git a/tests/test_middleware.py b/tests/test_middleware.py new file mode 100644 index 0000000..bb5f368 --- /dev/null +++ b/tests/test_middleware.py @@ -0,0 +1,389 @@ +import pytest +from unittest.mock import Mock, AsyncMock, patch +from fastapi import FastAPI, Request, Response +from fastapi.testclient import TestClient +from starlette.middleware.base import BaseHTTPMiddleware + +from py_spring_core.core.entities.middlewares.middleware import Middleware +from py_spring_core.core.entities.middlewares.middleware_registry import MiddlewareRegistry + + +class TestMiddleware: + """Test suite for the Middleware base class.""" + + @pytest.fixture + def mock_request(self): + """Fixture that provides a mock FastAPI request.""" + request = Mock(spec=Request) + request.method = "GET" + request.url = "http://test.com/api" + return request + + @pytest.fixture + def mock_call_next(self): + """Fixture that provides a mock call_next function.""" + return AsyncMock() + + def test_middleware_inherits_from_base_http_middleware(self): + """ + Test that Middleware class inherits from BaseHTTPMiddleware. + + This test verifies that: + 1. Middleware is a subclass of BaseHTTPMiddleware + 2. The inheritance relationship is correctly established + """ + assert issubclass(Middleware, BaseHTTPMiddleware) + + def test_middleware_is_abstract(self): + """ + Test that Middleware class is abstract and cannot be instantiated directly. + + This test verifies that: + 1. Middleware is an abstract base class + 2. Attempting to instantiate it directly raises an error + """ + # Test that Middleware is abstract by checking it has abstract methods + assert hasattr(Middleware, 'process_request') + assert Middleware.process_request.__isabstractmethod__ + + def test_process_request_is_abstract(self): + """ + Test that process_request method is abstract and must be implemented. + + This test verifies that: + 1. process_request is an abstract method + 2. Subclasses must implement this method + """ + # Create a concrete subclass without implementing process_request + class ConcreteMiddleware(Middleware): + pass + + # Test that the class is abstract by checking it has abstract methods + assert hasattr(ConcreteMiddleware, 'process_request') + # The method should still be abstract since it wasn't implemented + assert ConcreteMiddleware.process_request.__isabstractmethod__ + + @pytest.mark.asyncio + async def test_dispatch_continues_when_process_request_returns_none(self, mock_request, mock_call_next): + """ + Test that dispatch continues to next middleware when process_request returns None. + + This test verifies that: + 1. When process_request returns None, dispatch continues to call_next + 2. The call_next function is called with the correct request + 3. The response from call_next is returned + """ + expected_response = Response(content="test response", status_code=200) + mock_call_next.return_value = expected_response + + class TestMiddleware(Middleware): + async def process_request(self, request: Request) -> Response | None: + return None + + middleware = TestMiddleware(app=Mock()) + result = await middleware.dispatch(mock_request, mock_call_next) + + mock_call_next.assert_called_once_with(mock_request) + assert result == expected_response + + @pytest.mark.asyncio + async def test_dispatch_returns_response_when_process_request_returns_response(self, mock_request, mock_call_next): + """ + Test that dispatch returns response directly when process_request returns a response. + + This test verifies that: + 1. When process_request returns a Response, dispatch returns it directly + 2. call_next is not called when process_request returns a response + 3. The response from process_request is returned unchanged + """ + middleware_response = Response(content="middleware response", status_code=403) + + class TestMiddleware(Middleware): + async def process_request(self, request: Request) -> Response | None: + return middleware_response + + middleware = TestMiddleware(app=Mock()) + result = await middleware.dispatch(mock_request, mock_call_next) + + mock_call_next.assert_not_called() + assert result == middleware_response + + @pytest.mark.asyncio + async def test_dispatch_passes_request_to_process_request(self, mock_request, mock_call_next): + """ + Test that dispatch passes the request to process_request method. + + This test verifies that: + 1. The request object is correctly passed to process_request + 2. The process_request method receives the exact same request object + """ + received_request = None + + class TestMiddleware(Middleware): + async def process_request(self, request: Request) -> Response | None: + nonlocal received_request + received_request = request + return None + + middleware = TestMiddleware(app=Mock()) + await middleware.dispatch(mock_request, mock_call_next) + + assert received_request == mock_request + + +class TestMiddlewareRegistry: + """Test suite for the MiddlewareRegistry abstract class.""" + + @pytest.fixture + def fastapi_app(self): + """Fixture that provides a fresh FastAPI application instance.""" + return FastAPI() + + def test_middleware_registry_is_abstract(self): + """ + Test that MiddlewareRegistry class is abstract and cannot be instantiated directly. + + This test verifies that: + 1. MiddlewareRegistry is an abstract base class + 2. Attempting to instantiate it directly raises an error + """ + # This test verifies that MiddlewareRegistry is abstract + # We can't test direct instantiation because it's abstract + # Instead, we test that it has the abstract method + assert hasattr(MiddlewareRegistry, 'get_middleware_classes') + assert MiddlewareRegistry.get_middleware_classes.__isabstractmethod__ + + def test_get_middleware_classes_is_abstract(self): + """ + Test that get_middleware_classes method is abstract and must be implemented. + + This test verifies that: + 1. get_middleware_classes is an abstract method + 2. Subclasses must implement this method + """ + # Create a concrete subclass without implementing get_middleware_classes + class ConcreteRegistry(MiddlewareRegistry): # type: ignore[abstract] + pass + + with pytest.raises(TypeError): + ConcreteRegistry() # type: ignore[abstract] + + def test_apply_middlewares_adds_middleware_to_app(self, fastapi_app): + """ + Test that apply_middlewares correctly adds middleware classes to FastAPI app. + + This test verifies that: + 1. Middleware classes are added to the FastAPI application + 2. The add_middleware method is called for each middleware class + 3. The app is returned unchanged + """ + class TestMiddleware1(Middleware): + async def process_request(self, request: Request) -> Response | None: + return None + + class TestMiddleware2(Middleware): + async def process_request(self, request: Request) -> Response | None: + return None + + class TestRegistry(MiddlewareRegistry): + def get_middleware_classes(self) -> list[type[Middleware]]: + return [TestMiddleware1, TestMiddleware2] + + # Mock the add_middleware method + with patch.object(fastapi_app, 'add_middleware') as mock_add_middleware: + registry = TestRegistry() + result = registry.apply_middlewares(fastapi_app) + + # Verify add_middleware was called for each middleware class + assert mock_add_middleware.call_count == 2 + mock_add_middleware.assert_any_call(TestMiddleware1) + mock_add_middleware.assert_any_call(TestMiddleware2) + + # Verify the app is returned + assert result == fastapi_app + + def test_apply_middlewares_with_empty_list(self, fastapi_app): + """ + Test that apply_middlewares handles empty middleware list correctly. + + This test verifies that: + 1. When no middlewares are registered, no middleware is added + 2. The app is returned unchanged + 3. No errors occur with empty middleware list + """ + class EmptyRegistry(MiddlewareRegistry): + def get_middleware_classes(self) -> list[type[Middleware]]: + return [] + + with patch.object(fastapi_app, 'add_middleware') as mock_add_middleware: + registry = EmptyRegistry() + result = registry.apply_middlewares(fastapi_app) + + # Verify add_middleware was not called + mock_add_middleware.assert_not_called() + + # Verify the app is returned + assert result == fastapi_app + + def test_apply_middlewares_preserves_app_state(self, fastapi_app): + """ + Test that apply_middlewares preserves the FastAPI app state. + + This test verifies that: + 1. The original app object is returned (same reference) + 2. No app properties are modified during middleware application + """ + class TestMiddleware(Middleware): + async def process_request(self, request: Request) -> Response | None: + return None + + class TestRegistry(MiddlewareRegistry): + def get_middleware_classes(self) -> list[type[Middleware]]: + return [TestMiddleware] + + # Store original app state + original_app_id = id(fastapi_app) + + registry = TestRegistry() + result = registry.apply_middlewares(fastapi_app) + + # Verify same app object is returned + assert id(result) == original_app_id + assert result is fastapi_app + + +class TestMiddlewareIntegration: + """Integration tests for middleware functionality.""" + + @pytest.fixture + def fastapi_app(self): + """Fixture that provides a fresh FastAPI application instance.""" + return FastAPI() + + @pytest.mark.asyncio + async def test_middleware_chain_execution(self, fastapi_app): + """ + Test that multiple middlewares execute in the correct order. + + This test verifies that: + 1. Middlewares are executed in the order they are added + 2. Each middleware can process the request + 3. The chain continues correctly when middlewares return None + """ + execution_order = [] + + class FirstMiddleware(Middleware): + async def process_request(self, request: Request) -> Response | None: + execution_order.append("first") + return None + + class SecondMiddleware(Middleware): + async def process_request(self, request: Request) -> Response | None: + execution_order.append("second") + return None + + class TestRegistry(MiddlewareRegistry): + def get_middleware_classes(self) -> list[type[Middleware]]: + return [FirstMiddleware, SecondMiddleware] + + registry = TestRegistry() + app = registry.apply_middlewares(fastapi_app) + + # Create a test client to trigger middleware execution + from fastapi.testclient import TestClient + + @app.get("/test") + async def test_endpoint(): + return {"message": "test"} + + client = TestClient(app) + response = client.get("/test") + + # Verify middlewares were executed in order (FastAPI uses LIFO - Last In, First Out) + assert execution_order == ["second", "first"] + assert response.status_code == 200 + + @pytest.mark.asyncio + async def test_middleware_early_return(self, fastapi_app): + """ + Test that middleware can return early and prevent further execution. + + This test verifies that: + 1. When a middleware returns a response, subsequent middlewares are not executed + 2. The route handler is not called when middleware returns early + 3. The response from the middleware is returned to the client + """ + execution_order = [] + + class BlockingMiddleware(Middleware): + async def process_request(self, request: Request) -> Response | None: + execution_order.append("blocking") + return Response(content="blocked", status_code=403) + + class SecondMiddleware(Middleware): + async def process_request(self, request: Request) -> Response | None: + execution_order.append("second") + return None + + class TestRegistry(MiddlewareRegistry): + def get_middleware_classes(self) -> list[type[Middleware]]: + return [BlockingMiddleware, SecondMiddleware] + + registry = TestRegistry() + app = registry.apply_middlewares(fastapi_app) + + @app.get("/test") + async def test_endpoint(): + execution_order.append("handler") + return {"message": "test"} + + client = TestClient(app) + response = client.get("/test") + + # Verify only blocking middleware executed (FastAPI uses LIFO - Last In, First Out) + # SecondMiddleware executes first, then BlockingMiddleware returns early + assert execution_order == ["second", "blocking"] + assert response.status_code == 403 + assert response.text == "blocked" + + def test_middleware_registry_single_inheritance(self): + """ + Test that MiddlewareRegistry enforces single inheritance. + + This test verifies that: + 1. MiddlewareRegistry implements SingleInheritanceRequired + 2. Multiple inheritance is prevented + """ + # This test assumes SingleInheritanceRequired prevents multiple inheritance + # The actual behavior depends on the implementation of SingleInheritanceRequired + + class TestRegistry(MiddlewareRegistry): + def get_middleware_classes(self) -> list[type[Middleware]]: + return [] + + # Should be able to create a single inheritance registry + registry = TestRegistry() + assert isinstance(registry, MiddlewareRegistry) + + def test_middleware_type_hints(self): + """ + Test that middleware classes have correct type hints. + + This test verifies that: + 1. get_middleware_classes returns the correct type + 2. process_request has correct parameter and return type hints + """ + class TestMiddleware(Middleware): + async def process_request(self, request: Request) -> Response | None: + return None + + class TestRegistry(MiddlewareRegistry): + def get_middleware_classes(self) -> list[type[Middleware]]: + return [TestMiddleware] + + registry = TestRegistry() + middleware_classes = registry.get_middleware_classes() + + # Verify type hints + assert isinstance(middleware_classes, list) + assert all(issubclass(middleware_class, Middleware) for middleware_class in middleware_classes) \ No newline at end of file