diff --git a/py_spring_core/__init__.py b/py_spring_core/__init__.py index ede435a..412ce0a 100644 --- a/py_spring_core/__init__.py +++ b/py_spring_core/__init__.py @@ -12,7 +12,7 @@ from py_spring_core.core.entities.entity_provider.entity_provider import EntityProvider from py_spring_core.core.entities.middlewares.middleware import Middleware from py_spring_core.core.entities.middlewares.middleware_registry import ( - MiddlewareRegistry, + MiddlewareRegistry, MiddlewareConfiguration ) from py_spring_core.core.entities.properties.properties import Properties from py_spring_core.core.interfaces.application_context_required import ( @@ -44,6 +44,7 @@ "EventListener", "Middleware", "MiddlewareRegistry", + "MiddlewareConfiguration", "GracefulShutdownHandler", - "ShutdownType", -] + "ShutdownType" +] \ No newline at end of file diff --git a/py_spring_core/core/application/py_spring_application.py b/py_spring_core/core/application/py_spring_application.py index cc39d1c..cb1dd64 100644 --- a/py_spring_core/core/application/py_spring_application.py +++ b/py_spring_core/core/application/py_spring_application.py @@ -1,6 +1,6 @@ import logging import os -from typing import Any, Callable, Iterable, Optional, Type +from typing import Any, Callable, Iterable, Optional, Type, TypeVar import uvicorn from fastapi import APIRouter, FastAPI @@ -30,6 +30,7 @@ from py_spring_core.core.entities.entity_provider.entity_provider import EntityProvider from py_spring_core.core.entities.middlewares.middleware import Middleware from py_spring_core.core.entities.middlewares.middleware_registry import ( + MiddlewareConfiguration, MiddlewareRegistry, ) from py_spring_core.core.entities.properties.properties import Properties @@ -45,6 +46,8 @@ import py_spring_core.core.utils as framework_utils + +SingleInheritanceRequiredT = TypeVar("SingleInheritanceRequiredT", bound=SingleInheritanceRequired) class PySpringApplication: """ The PySpringApplication class is the main entry point for the PySpring application. @@ -237,7 +240,7 @@ def __init_controllers(self) -> None: self.fastapi.include_router(router) logger.debug(f"[CONTROLLER INIT] Controller {name} initialized") - def _init_external_handler(self, base_class: Type[SingleInheritanceRequired]) -> Type[Any] | None: + def _init_external_handler(self, base_class: Type[SingleInheritanceRequiredT]) -> Type[SingleInheritanceRequiredT] | None: """Initialize an external handler (middleware registry or graceful shutdown handler). Args: @@ -276,12 +279,15 @@ def _init_external_handler(self, base_class: Type[SingleInheritanceRequired]) -> def __init_middlewares(self) -> None: handler_type = MiddlewareRegistry.__name__ logger.debug(f"[{handler_type} INIT] Initialize middlewares...") - registry_cls = self._init_external_handler(MiddlewareRegistry) - if registry_cls is None: + middeware_configuration_cls = self._init_external_handler(MiddlewareConfiguration) + if middeware_configuration_cls is None: return - - registry: MiddlewareRegistry = registry_cls() + registry = MiddlewareRegistry() + logger.info(f"[{handler_type} INIT] Setup middlewares for registry: {registry.__class__.__name__}") + middeware_configuration_cls().configure_middlewares(registry) + logger.info(f"[{handler_type} INIT] Middlewares setup for registry: {registry.__class__.__name__} completed") middleware_classes: list[Type[Middleware]] = registry.get_middleware_classes() + logger.info(f"[{handler_type} INIT] Middleware classes: {', '.join([middleware_class.__name__ for middleware_class in middleware_classes])}") for middleware_class in middleware_classes: logger.debug( f"[{handler_type} INIT] Inject dependencies for middleware: {middleware_class.__name__}" @@ -305,7 +311,7 @@ def __init_graceful_shutdown(self) -> None: self.shutdown_handler = handler_cls( timeout_seconds=shutdown_config.timeout_seconds, timeout_enabled=shutdown_config.enabled - ) + ) # type: ignore logger.debug(f"[{handler_type} INIT] Graceful shutdown initialized") def __configure_uvicorn_logging(self): diff --git a/py_spring_core/core/entities/middlewares/middleware_registry.py b/py_spring_core/core/entities/middlewares/middleware_registry.py index 348ba00..6fd7d93 100644 --- a/py_spring_core/core/entities/middlewares/middleware_registry.py +++ b/py_spring_core/core/entities/middlewares/middleware_registry.py @@ -9,7 +9,7 @@ ) -class MiddlewareRegistry(SingleInheritanceRequired["MiddlewareRegistry"], ABC): +class MiddlewareRegistry: """ Middleware registry for managing all middlewares @@ -31,20 +31,153 @@ class MiddlewareRegistry(SingleInheritanceRequired["MiddlewareRegistry"], ABC): Response: route → MiddlewareA → MiddlewareB This stacking behavior ensures that middlewares are executed in a predictable and controllable order. """ - - @abstractmethod + + def __init__(self): + """ + Initialize the middleware registry. + """ + self._middlewares: list[Type[Middleware]] = [] + + def add_middleware(self, middleware_class: Type[Middleware]) -> None: + """ + Add middleware to the end of the list. + + Args: + middleware_class: The middleware class to add + + Raises: + ValueError: If middleware is already registered + """ + if middleware_class in self._middlewares: + raise ValueError(f"Middleware {middleware_class.__name__} is already registered") + self._middlewares.append(middleware_class) + + def add_at_index(self, index: int, middleware_class: Type[Middleware]) -> None: + """ + Insert middleware at a specific index position. + + Args: + index: The position to insert at (0-based) + middleware_class: The middleware class to add + + Raises: + ValueError: If middleware is already registered or index is invalid + """ + if middleware_class in self._middlewares: + raise ValueError(f"Middleware {middleware_class.__name__} is already registered") + if index < 0 or index > len(self._middlewares): + raise ValueError(f"Index {index} is out of range (0-{len(self._middlewares)})") + self._middlewares.insert(index, middleware_class) + + def add_before(self, target_middleware: Type[Middleware], middleware_class: Type[Middleware]) -> None: + """ + Insert middleware before the target middleware. + + Args: + target_middleware: The middleware to insert before + middleware_class: The middleware class to add + + Raises: + ValueError: If middleware is already registered or target not found + """ + if middleware_class in self._middlewares: + raise ValueError(f"Middleware {middleware_class.__name__} is already registered") + if target_middleware not in self._middlewares: + raise ValueError(f"Target middleware {target_middleware.__name__} not found") + index = self._middlewares.index(target_middleware) + self._middlewares.insert(index, middleware_class) + + def add_after(self, target_middleware: Type[Middleware], middleware_class: Type[Middleware]) -> None: + """ + Insert middleware after the target middleware. + + Args: + target_middleware: The middleware to insert after + middleware_class: The middleware class to add + + Raises: + ValueError: If middleware is already registered or target not found + """ + if middleware_class in self._middlewares: + raise ValueError(f"Middleware {middleware_class.__name__} is already registered") + if target_middleware not in self._middlewares: + raise ValueError(f"Target middleware {target_middleware.__name__} not found") + index = self._middlewares.index(target_middleware) + self._middlewares.insert(index + 1, middleware_class) + + def remove_middleware(self, middleware_class: Type[Middleware]) -> None: + """ + Remove a middleware from the registry. + + Args: + middleware_class: The middleware class to remove + + Raises: + ValueError: If middleware is not found + """ + if middleware_class not in self._middlewares: + raise ValueError(f"Middleware {middleware_class.__name__} not found") + self._middlewares.remove(middleware_class) + + def clear_middlewares(self) -> None: + """Remove all middlewares from the registry.""" + self._middlewares.clear() + + def has_middleware(self, middleware_class: Type[Middleware]) -> bool: + """ + Check if a middleware is registered. + + Args: + middleware_class: The middleware class to check + + Returns: + bool: True if middleware is registered, False otherwise + """ + return middleware_class in self._middlewares + + def get_middleware_count(self) -> int: + """ + Get the number of registered middlewares. + + Returns: + int: Number of registered middlewares + """ + return len(self._middlewares) + + def get_middleware_index(self, middleware_class: Type[Middleware]) -> int: + """ + Get the index of a middleware in the registry. + + Args: + middleware_class: The middleware class to find + + Returns: + int: The index of the middleware + + Raises: + ValueError: If middleware is not found + """ + if middleware_class not in self._middlewares: + raise ValueError(f"Middleware {middleware_class.__name__} not found") + return self._middlewares.index(middleware_class) + + + def get_middleware_classes(self) -> list[Type[Middleware]]: """ - Get all registered middleware classes - + Get all registered middleware classes. + Returns: - List[Type[Middleware]]: List of middleware classes + List[Type[Middleware]]: List of middleware classes in registration order """ - pass - + return self._middlewares.copy() + def apply_middlewares(self, app: FastAPI) -> FastAPI: """ - Apply middlewares to FastAPI application + Apply middlewares to FastAPI application. + + Iterates through all registered middlewares and applies them to the FastAPI + application instance in the order they were registered. Args: app: FastAPI application instance @@ -55,3 +188,16 @@ def apply_middlewares(self, app: FastAPI) -> FastAPI: for middleware_class in self.get_middleware_classes(): app.add_middleware(middleware_class) return app + + + +class MiddlewareConfiguration(SingleInheritanceRequired["MiddlewareConfiguration"]): + """ + Middleware configuration for managing middleware registration and execution order. + """ + + def configure_middlewares(self, registry: MiddlewareRegistry) -> None: + """ + Setup middlewares for the registry. + """ + pass \ No newline at end of file diff --git a/py_spring_core/exception_handler/decorator.py b/py_spring_core/exception_handler/decorator.py new file mode 100644 index 0000000..3d60b4e --- /dev/null +++ b/py_spring_core/exception_handler/decorator.py @@ -0,0 +1,18 @@ + + +from functools import wraps +from typing import Any, Callable, Type + + +from py_spring_core.exception_handler.exception_handler_registry import ExceptionHandlerRegistry + + +def ExceptionHandler(exception_cls: Type[Exception]) -> Callable[[Callable[[Exception], Any]], Callable]: + def decorator(func: Callable[[Exception], Any]) -> Callable: + @wraps(func) + def wrapper(*args, **kwargs) -> Any: + return func(*args, **kwargs) + ExceptionHandlerRegistry.register(exception_cls, wrapper) + return wrapper + + return decorator \ No newline at end of file diff --git a/py_spring_core/exception_handler/exception_handler_registry.py b/py_spring_core/exception_handler/exception_handler_registry.py new file mode 100644 index 0000000..2861975 --- /dev/null +++ b/py_spring_core/exception_handler/exception_handler_registry.py @@ -0,0 +1,25 @@ + + +from typing import Any, Callable, Type, TypeVar + +from loguru import logger + +E = TypeVar('E', bound=Exception) + +class ExceptionHandlerRegistry: + _handlers: dict[str, Callable[[Any], Any]] = {} + + @classmethod + def register(cls, exception_cls: Type[E], handler: Callable[[E], Any]): + key = exception_cls.__name__ + logger.debug(f"Registering exception handler for {key}: {handler.__name__}") + if key in cls._handlers: + error_message = f"Exception handler for {exception_cls} already registered" + logger.error(error_message) + raise RuntimeError(error_message) + + cls._handlers[exception_cls.__name__] = handler + + @classmethod + def get_handler(cls, exception_cls: Type[E]) -> Callable[[E], Any]: + return cls._handlers[exception_cls.__name__] \ No newline at end of file diff --git a/tests/test_middleware.py b/tests/test_middleware.py index e692bcc..979a0d1 100644 --- a/tests/test_middleware.py +++ b/tests/test_middleware.py @@ -8,6 +8,7 @@ from py_spring_core.core.entities.middlewares.middleware import Middleware from py_spring_core.core.entities.middlewares.middleware_registry import ( MiddlewareRegistry, + MiddlewareConfiguration, ) @@ -148,7 +149,11 @@ def test_should_skip_default_returns_false(self, mock_request): 1. The default implementation of should_skip returns False 2. This allows the middleware to process all requests by default """ - middleware = Middleware(app=Mock()) + class TestMiddleware(Middleware): + async def process_request(self, request: Request) -> Response | None: + return None + + middleware = TestMiddleware(app=Mock()) result = middleware.should_skip(mock_request) assert result is False @@ -164,6 +169,9 @@ class SkippingMiddleware(Middleware): def should_skip(self, request: Request) -> bool: return request.method == "GET" + async def process_request(self, request: Request) -> Response | None: + return None + middleware = SkippingMiddleware(app=Mock()) result = middleware.should_skip(mock_request) assert result is True @@ -251,6 +259,9 @@ def should_skip(self, request: Request) -> bool: received_request = request return False + async def process_request(self, request: Request) -> Response | None: + return None + middleware = TestMiddleware(app=Mock()) middleware.should_skip(mock_request) @@ -350,7 +361,11 @@ def test_should_skip_method_signature(self): 2. It takes a Request parameter 3. It returns a boolean value """ - middleware = Middleware(app=Mock()) + class TestMiddleware(Middleware): + async def process_request(self, request: Request) -> Response | None: + return None + + middleware = TestMiddleware(app=Mock()) # Check that should_skip is a method assert hasattr(middleware, 'should_skip') @@ -369,44 +384,237 @@ def test_should_skip_method_signature(self): class TestMiddlewareRegistry: - """Test suite for the MiddlewareRegistry abstract class.""" + """Test suite for the MiddlewareRegistry concrete class.""" @pytest.fixture def fastapi_app(self): """Fixture that provides a fresh FastAPI application instance.""" return FastAPI() - def test_middleware_registry_is_abstract(self): + @pytest.fixture + def registry(self): + """Fixture that provides a fresh MiddlewareRegistry instance.""" + return MiddlewareRegistry() + + @pytest.fixture + def test_middleware_1(self): + """Fixture that provides a test middleware class.""" + class TestMiddleware1(Middleware): + async def process_request(self, request: Request) -> Response | None: + return None + return TestMiddleware1 + + @pytest.fixture + def test_middleware_2(self): + """Fixture that provides another test middleware class.""" + class TestMiddleware2(Middleware): + async def process_request(self, request: Request) -> Response | None: + return None + return TestMiddleware2 + + def test_middleware_registry_instantiation(self): """ - Test that MiddlewareRegistry class is abstract and cannot be instantiated directly. + Test that MiddlewareRegistry can be instantiated directly. This test verifies that: - 1. MiddlewareRegistry is an abstract base class - 2. Attempting to instantiate it directly raises an error + 1. MiddlewareRegistry is a concrete class + 2. It can be instantiated without errors + 3. Initial state is correct """ - # This test verifies that MiddlewareRegistry is abstract - # We can't test direct instantiation because it's abstract - # Instead, we test that it has the abstract method - assert hasattr(MiddlewareRegistry, "get_middleware_classes") - assert MiddlewareRegistry.get_middleware_classes.__isabstractmethod__ + registry = MiddlewareRegistry() + assert isinstance(registry, MiddlewareRegistry) + assert registry.get_middleware_count() == 0 + assert registry.get_middleware_classes() == [] - def test_get_middleware_classes_is_abstract(self): + def test_add_middleware(self, registry, test_middleware_1): """ - Test that get_middleware_classes method is abstract and must be implemented. + Test adding middleware to the registry. This test verifies that: - 1. get_middleware_classes is an abstract method - 2. Subclasses must implement this method + 1. Middleware can be added successfully + 2. Middleware count increases + 3. Middleware appears in the classes list """ + registry.add_middleware(test_middleware_1) + + assert registry.get_middleware_count() == 1 + assert registry.has_middleware(test_middleware_1) + assert test_middleware_1 in registry.get_middleware_classes() - # Create a concrete subclass without implementing get_middleware_classes - class ConcreteRegistry(MiddlewareRegistry): # type: ignore[abstract] - pass + def test_add_duplicate_middleware_raises_error(self, registry, test_middleware_1): + """ + Test that adding duplicate middleware raises an error. + + This test verifies that: + 1. Adding the same middleware twice raises ValueError + 2. The error message is descriptive + 3. The registry state remains unchanged + """ + registry.add_middleware(test_middleware_1) + + with pytest.raises(ValueError, match="Middleware TestMiddleware1 is already registered"): + registry.add_middleware(test_middleware_1) + + # Verify state hasn't changed + assert registry.get_middleware_count() == 1 + + def test_add_at_index(self, registry, test_middleware_1, test_middleware_2): + """ + Test inserting middleware at specific index. - with pytest.raises(TypeError): - ConcreteRegistry() # type: ignore[abstract] + This test verifies that: + 1. Middleware can be inserted at specific positions + 2. Order is maintained correctly + 3. Index bounds are respected + """ + registry.add_middleware(test_middleware_1) + registry.add_at_index(0, test_middleware_2) + + classes = registry.get_middleware_classes() + assert classes[0] == test_middleware_2 + assert classes[1] == test_middleware_1 + + def test_add_at_invalid_index_raises_error(self, registry, test_middleware_1): + """ + Test that adding at invalid index raises an error. + + This test verifies that: + 1. Invalid indices raise ValueError + 2. Error message includes valid range + """ + with pytest.raises(ValueError, match="Index -1 is out of range"): + registry.add_at_index(-1, test_middleware_1) + + with pytest.raises(ValueError, match="Index 1 is out of range"): + registry.add_at_index(1, test_middleware_1) + + def test_add_before(self, registry, test_middleware_1, test_middleware_2): + """ + Test inserting middleware before another middleware. + + This test verifies that: + 1. Middleware can be inserted before target middleware + 2. Order is correct after insertion + """ + registry.add_middleware(test_middleware_1) + registry.add_before(test_middleware_1, test_middleware_2) + + classes = registry.get_middleware_classes() + assert classes[0] == test_middleware_2 + assert classes[1] == test_middleware_1 + + def test_add_before_nonexistent_target_raises_error(self, registry, test_middleware_1, test_middleware_2): + """ + Test that adding before nonexistent target raises error. + + This test verifies that: + 1. Adding before non-registered middleware raises ValueError + 2. Error message is descriptive + """ + with pytest.raises(ValueError, match="Target middleware TestMiddleware1 not found"): + registry.add_before(test_middleware_1, test_middleware_2) + + def test_add_after(self, registry, test_middleware_1, test_middleware_2): + """ + Test inserting middleware after another middleware. + + This test verifies that: + 1. Middleware can be inserted after target middleware + 2. Order is correct after insertion + """ + registry.add_middleware(test_middleware_1) + registry.add_after(test_middleware_1, test_middleware_2) + + classes = registry.get_middleware_classes() + assert classes[0] == test_middleware_1 + assert classes[1] == test_middleware_2 + + def test_remove_middleware(self, registry, test_middleware_1): + """ + Test removing middleware from the registry. + + This test verifies that: + 1. Middleware can be removed successfully + 2. Middleware count decreases + 3. Middleware no longer appears in classes list + """ + registry.add_middleware(test_middleware_1) + registry.remove_middleware(test_middleware_1) + + assert registry.get_middleware_count() == 0 + assert not registry.has_middleware(test_middleware_1) + assert test_middleware_1 not in registry.get_middleware_classes() + + def test_remove_nonexistent_middleware_raises_error(self, registry, test_middleware_1): + """ + Test that removing nonexistent middleware raises error. + + This test verifies that: + 1. Removing non-registered middleware raises ValueError + 2. Error message is descriptive + """ + with pytest.raises(ValueError, match="Middleware TestMiddleware1 not found"): + registry.remove_middleware(test_middleware_1) + + def test_clear_middlewares(self, registry, test_middleware_1, test_middleware_2): + """ + Test clearing all middlewares from the registry. + + This test verifies that: + 1. All middlewares are removed + 2. Registry returns to initial state + """ + registry.add_middleware(test_middleware_1) + registry.add_middleware(test_middleware_2) + + registry.clear_middlewares() + + assert registry.get_middleware_count() == 0 + assert registry.get_middleware_classes() == [] - def test_apply_middlewares_adds_middleware_to_app(self, fastapi_app): + def test_get_middleware_index(self, registry, test_middleware_1, test_middleware_2): + """ + Test getting the index of a middleware. + + This test verifies that: + 1. Index of registered middleware is returned correctly + 2. Index reflects the actual position in the list + """ + registry.add_middleware(test_middleware_1) + registry.add_middleware(test_middleware_2) + + assert registry.get_middleware_index(test_middleware_1) == 0 + assert registry.get_middleware_index(test_middleware_2) == 1 + + def test_get_middleware_index_nonexistent_raises_error(self, registry, test_middleware_1): + """ + Test that getting index of nonexistent middleware raises error. + + This test verifies that: + 1. Getting index of non-registered middleware raises ValueError + 2. Error message is descriptive + """ + with pytest.raises(ValueError, match="Middleware TestMiddleware1 not found"): + registry.get_middleware_index(test_middleware_1) + + def test_get_middleware_classes_returns_copy(self, registry, test_middleware_1): + """ + Test that get_middleware_classes returns a copy. + + This test verifies that: + 1. Modifying returned list doesn't affect internal state + 2. A copy is returned, not the original list + """ + registry.add_middleware(test_middleware_1) + + classes = registry.get_middleware_classes() + classes.clear() + + # Original registry should be unchanged + assert registry.get_middleware_count() == 1 + assert registry.has_middleware(test_middleware_1) + + def test_apply_middlewares_adds_middleware_to_app(self, registry, fastapi_app, test_middleware_1, test_middleware_2): """ Test that apply_middlewares correctly adds middleware classes to FastAPI app. @@ -415,33 +623,22 @@ def test_apply_middlewares_adds_middleware_to_app(self, fastapi_app): 2. The add_middleware method is called for each middleware class 3. The app is returned unchanged """ - - class TestMiddleware1(Middleware): - async def process_request(self, request: Request) -> Response | None: - return None - - class TestMiddleware2(Middleware): - async def process_request(self, request: Request) -> Response | None: - return None - - class TestRegistry(MiddlewareRegistry): - def get_middleware_classes(self) -> list[type[Middleware]]: - return [TestMiddleware1, TestMiddleware2] + registry.add_middleware(test_middleware_1) + registry.add_middleware(test_middleware_2) # Mock the add_middleware method with patch.object(fastapi_app, "add_middleware") as mock_add_middleware: - registry = TestRegistry() result = registry.apply_middlewares(fastapi_app) # Verify add_middleware was called for each middleware class assert mock_add_middleware.call_count == 2 - mock_add_middleware.assert_any_call(TestMiddleware1) - mock_add_middleware.assert_any_call(TestMiddleware2) + mock_add_middleware.assert_any_call(test_middleware_1) + mock_add_middleware.assert_any_call(test_middleware_2) # Verify the app is returned assert result == fastapi_app - def test_apply_middlewares_with_empty_list(self, fastapi_app): + def test_apply_middlewares_with_empty_list(self, registry, fastapi_app): """ Test that apply_middlewares handles empty middleware list correctly. @@ -450,13 +647,7 @@ def test_apply_middlewares_with_empty_list(self, fastapi_app): 2. The app is returned unchanged 3. No errors occur with empty middleware list """ - - class EmptyRegistry(MiddlewareRegistry): - def get_middleware_classes(self) -> list[type[Middleware]]: - return [] - with patch.object(fastapi_app, "add_middleware") as mock_add_middleware: - registry = EmptyRegistry() result = registry.apply_middlewares(fastapi_app) # Verify add_middleware was not called @@ -465,7 +656,7 @@ def get_middleware_classes(self) -> list[type[Middleware]]: # Verify the app is returned assert result == fastapi_app - def test_apply_middlewares_preserves_app_state(self, fastapi_app): + def test_apply_middlewares_preserves_app_state(self, registry, fastapi_app, test_middleware_1): """ Test that apply_middlewares preserves the FastAPI app state. @@ -473,19 +664,11 @@ def test_apply_middlewares_preserves_app_state(self, fastapi_app): 1. The original app object is returned (same reference) 2. No app properties are modified during middleware application """ - - class TestMiddleware(Middleware): - async def process_request(self, request: Request) -> Response | None: - return None - - class TestRegistry(MiddlewareRegistry): - def get_middleware_classes(self) -> list[type[Middleware]]: - return [TestMiddleware] + registry.add_middleware(test_middleware_1) # Store original app state original_app_id = id(fastapi_app) - registry = TestRegistry() result = registry.apply_middlewares(fastapi_app) # Verify same app object is returned @@ -493,6 +676,48 @@ def get_middleware_classes(self) -> list[type[Middleware]]: assert result is fastapi_app +class TestMiddlewareConfiguration: + """Test suite for the MiddlewareConfiguration class.""" + + def test_middleware_configuration_inheritance(self): + """ + Test that MiddlewareConfiguration has proper inheritance. + + This test verifies that: + 1. MiddlewareConfiguration inherits from SingleInheritanceRequired + 2. It can be instantiated + """ + class TestConfig(MiddlewareConfiguration): + pass + + config = TestConfig() + assert isinstance(config, MiddlewareConfiguration) + + def test_setup_middlewares_can_be_overridden(self): + """ + Test that setup_middlewares can be overridden to configure middlewares. + + This test verifies that: + 1. setup_middlewares can be overridden + 2. The registry is properly configured when overridden + """ + class TestMiddleware(Middleware): + async def process_request(self, request: Request) -> Response | None: + return None + + class TestConfig(MiddlewareConfiguration): + def setup_middlewares(self, registry: MiddlewareRegistry) -> None: + registry.add_middleware(TestMiddleware) + + config = TestConfig() + registry = MiddlewareRegistry() + + config.setup_middlewares(registry) + + assert registry.has_middleware(TestMiddleware) + assert registry.get_middleware_count() == 1 + + class TestMiddlewareIntegration: """Integration tests for middleware functionality.""" @@ -523,11 +748,10 @@ async def process_request(self, request: Request) -> Response | None: execution_order.append("second") return None - class TestRegistry(MiddlewareRegistry): - def get_middleware_classes(self) -> list[type[Middleware]]: - return [FirstMiddleware, SecondMiddleware] - - registry = TestRegistry() + registry = MiddlewareRegistry() + registry.add_middleware(FirstMiddleware) + registry.add_middleware(SecondMiddleware) + app = registry.apply_middlewares(fastapi_app) # Create a test client to trigger middleware execution @@ -566,11 +790,10 @@ async def process_request(self, request: Request) -> Response | None: execution_order.append("second") return None - class TestRegistry(MiddlewareRegistry): - def get_middleware_classes(self) -> list[type[Middleware]]: - return [BlockingMiddleware, SecondMiddleware] - - registry = TestRegistry() + registry = MiddlewareRegistry() + registry.add_middleware(BlockingMiddleware) + registry.add_middleware(SecondMiddleware) + app = registry.apply_middlewares(fastapi_app) @app.get("/test") @@ -587,48 +810,77 @@ async def test_endpoint(): assert response.status_code == 403 assert response.text == "blocked" - def test_middleware_registry_single_inheritance(self): + def test_middleware_registry_with_configuration(self): """ - Test that MiddlewareRegistry enforces single inheritance. + Test using MiddlewareRegistry with MiddlewareConfiguration. This test verifies that: - 1. MiddlewareRegistry implements SingleInheritanceRequired - 2. Multiple inheritance is prevented + 1. MiddlewareConfiguration can configure a MiddlewareRegistry + 2. The configuration is applied correctly """ - # This test assumes SingleInheritanceRequired prevents multiple inheritance - # The actual behavior depends on the implementation of SingleInheritanceRequired + class TestMiddleware(Middleware): + async def process_request(self, request: Request) -> Response | None: + return None - class TestRegistry(MiddlewareRegistry): - def get_middleware_classes(self) -> list[type[Middleware]]: - return [] + class TestConfig(MiddlewareConfiguration): + def setup_middlewares(self, registry: MiddlewareRegistry) -> None: + registry.add_middleware(TestMiddleware) - # Should be able to create a single inheritance registry - registry = TestRegistry() - assert isinstance(registry, MiddlewareRegistry) + config = TestConfig() + registry = MiddlewareRegistry() + + config.setup_middlewares(registry) + + assert registry.has_middleware(TestMiddleware) + assert registry.get_middleware_count() == 1 - def test_middleware_type_hints(self): + def test_middleware_execution_order_with_skip_logic(self, fastapi_app): """ - Test that middleware classes have correct type hints. + Test middleware execution order with skip logic. This test verifies that: - 1. get_middleware_classes returns the correct type - 2. process_request has correct parameter and return type hints + 1. Middlewares with skip logic are handled correctly + 2. Order is maintained even when some middlewares skip """ + execution_order = [] - class TestMiddleware(Middleware): + class ConditionalMiddleware(Middleware): + def should_skip(self, request: Request) -> bool: + return "/skip" in str(request.url) + async def process_request(self, request: Request) -> Response | None: + execution_order.append("conditional") return None - class TestRegistry(MiddlewareRegistry): - def get_middleware_classes(self) -> list[type[Middleware]]: - return [TestMiddleware] + class AlwaysRunMiddleware(Middleware): + async def process_request(self, request: Request) -> Response | None: + execution_order.append("always") + return None - registry = TestRegistry() - middleware_classes = registry.get_middleware_classes() + registry = MiddlewareRegistry() + registry.add_middleware(ConditionalMiddleware) + registry.add_middleware(AlwaysRunMiddleware) + + app = registry.apply_middlewares(fastapi_app) - # Verify type hints - assert isinstance(middleware_classes, list) - assert all( - issubclass(middleware_class, Middleware) - for middleware_class in middleware_classes - ) + @app.get("/test") + async def test_endpoint(): + return {"message": "test"} + + @app.get("/skip") + async def skip_endpoint(): + return {"message": "skip"} + + client = TestClient(app) + + # Test normal endpoint + execution_order.clear() + response = client.get("/test") + assert execution_order == ["always", "conditional"] + assert response.status_code == 200 + + # Test skip endpoint + execution_order.clear() + response = client.get("/skip") + assert execution_order == ["always"] # Only AlwaysRunMiddleware should execute + assert response.status_code == 200