Skip to content

Commit 4ea85cd

Browse files
authored
Add Middleware Support to PySpring Framework (#14)
1 parent 4f3d1b3 commit 4ea85cd

File tree

9 files changed

+555
-10
lines changed

9 files changed

+555
-10
lines changed

py_spring_core/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,15 @@
1010
PutMapping,
1111
)
1212
from py_spring_core.core.entities.entity_provider import EntityProvider
13+
from py_spring_core.core.entities.middlewares.middleware import Middleware
14+
from py_spring_core.core.entities.middlewares.middleware_registry import MiddlewareRegistry
1315
from py_spring_core.core.entities.properties.properties import Properties
1416
from py_spring_core.core.interfaces.application_context_required import ApplicationContextRequired
1517
from py_spring_core.event.application_event_publisher import ApplicationEventPublisher
1618
from py_spring_core.event.commons import ApplicationEvent
1719
from py_spring_core.event.application_event_handler_registry import EventListener
1820

19-
__version__ = "0.0.18"
21+
__version__ = "0.0.19"
2022

2123
__all__ = [
2224
"PySpringApplication",
@@ -35,4 +37,6 @@
3537
"ApplicationEventPublisher",
3638
"ApplicationEvent",
3739
"EventListener",
40+
"Middleware",
41+
"MiddlewareRegistry",
3842
]

py_spring_core/core/application/context/application_context.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from inspect import isclass
55
from typing import (
66
Annotated,
7+
Any,
78
Callable,
89
Mapping,
910
Optional,
@@ -25,6 +26,7 @@
2526
from py_spring_core.core.entities.bean_collection import (
2627
BeanCollection,
2728
BeanConflictError,
29+
BeanView,
2830
InvalidBeanError,
2931
)
3032
from py_spring_core.core.entities.component import Component, ComponentScope
@@ -362,7 +364,7 @@ def _inject_bean_collection_dependencies(self, bean_collection_cls: Type[BeanCol
362364
)
363365
self.dependency_injector.inject_dependencies(bean_collection_cls)
364366

365-
def _validate_bean_view(self, view, collection_name: str) -> None:
367+
def _validate_bean_view(self, view: BeanView, collection_name: str) -> None:
366368
"""Validate a bean view before adding it to the container."""
367369
if view.bean_name in self.container_manager.singleton_bean_instance_container:
368370
raise BeanConflictError(
@@ -564,6 +566,10 @@ def init_ioc_container(self) -> None:
564566
# Initialize singleton beans
565567
self.bean_manager.init_singleton_beans()
566568

569+
def inject_dependencies_for_external_object(self, object: Type[Any]) -> None:
570+
"""Inject dependencies for an external object."""
571+
self.dependency_injector.inject_dependencies(object)
572+
567573
def inject_dependencies_for_app_entities(self) -> None:
568574
"""Inject dependencies for all registered app entities."""
569575
containers: list[Mapping[str, Type[AppEntities]]] = [

py_spring_core/core/application/py_spring_application.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from py_spring_core.core.entities.controllers.rest_controller import RestController
2929
from py_spring_core.core.entities.controllers.route_mapping import RouteMapping
3030
from py_spring_core.core.entities.entity_provider import EntityProvider
31+
from py_spring_core.core.entities.middlewares.middleware_registry import MiddlewareRegistry
3132
from py_spring_core.core.entities.properties.properties import Properties
3233
from py_spring_core.core.interfaces.application_context_required import ApplicationContextRequired
3334
from py_spring_core.event.application_event_handler_registry import ApplicationEventHandlerRegistry
@@ -212,8 +213,25 @@ def __init_controllers(self) -> None:
212213
controller._register_decorated_routes(routes)
213214
router = controller.get_router()
214215
self.fastapi.include_router(router)
215-
controller.register_middlewares()
216-
216+
self.__init_middlewares()
217+
logger.debug(f"[CONTROLLER INIT] Controller {name} initialized")
218+
def __init_middlewares(self) -> None:
219+
logger.debug("[MIDDLEWARE INIT] Initialize middlewares...")
220+
self_defined_registry_cls = MiddlewareRegistry.get_subclass()
221+
if self_defined_registry_cls is None:
222+
logger.debug("[MIDDLEWARE INIT] No self defined registry class found")
223+
return
224+
logger.debug(f"[MIDDLEWARE INIT] Self defined registry class: {self_defined_registry_cls.__name__}")
225+
logger.debug(f"[MIDDLEWARE INIT] Inject dependencies for external object: {self_defined_registry_cls.__name__}")
226+
self.app_context.inject_dependencies_for_external_object(self_defined_registry_cls)
227+
registry = self_defined_registry_cls()
228+
229+
middleware_classes = registry.get_middleware_classes()
230+
for middleware_class in middleware_classes:
231+
logger.debug(f"[MIDDLEWARE INIT] Inject dependencies for middleware: {middleware_class.__name__}")
232+
self.app_context.inject_dependencies_for_external_object(middleware_class)
233+
registry.apply_middlewares(self.fastapi)
234+
logger.debug("[MIDDLEWARE INIT] Middlewares initialized")
217235
def __configure_uvicorn_logging(self):
218236
"""Configure Uvicorn to use Loguru instead of default logging."""
219237
# Configure Uvicorn to use Loguru
@@ -239,9 +257,6 @@ def emit(self, record):
239257
logging.basicConfig(handlers=[InterceptHandler()], level=log_level, force=True)
240258

241259
def __run_server(self) -> None:
242-
243-
244-
245260
# Run uvicorn server
246261
uvicorn.run(
247262
self.fastapi,

py_spring_core/core/entities/controllers/rest_controller.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from functools import partial
44

55
from py_spring_core.core.entities.controllers.route_mapping import RouteRegistration
6+
from py_spring_core.core.entities.middlewares.middleware import Middleware
67

78

89
class RestController:
@@ -15,7 +16,7 @@ class RestController:
1516
- Providing access to the FastAPI `APIRouter` and `FastAPI` app instances
1617
- Exposing the controller's configuration, including the URL prefix
1718
18-
Subclasses of `RestController` should override the `register_routes` and `register_middlewares` methods to add their own routes and middleware to the controller.
19+
Subclasses of `RestController` should override the `register_routes` methods to add their own routes and middleware to the controller.
1920
"""
2021

2122
app: FastAPI
@@ -53,8 +54,6 @@ def _register_decorated_routes(self, routes: Iterable[RouteRegistration]) -> Non
5354
name=route.name,
5455
)
5556

56-
def register_middlewares(self) -> None: ...
57-
5857
def get_router(self) -> APIRouter:
5958
return self.router
6059

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
from abc import abstractmethod
2+
from typing import Awaitable, Callable
3+
from fastapi import Request, Response
4+
from starlette.middleware.base import BaseHTTPMiddleware
5+
6+
7+
8+
class Middleware(BaseHTTPMiddleware):
9+
"""
10+
Middleware base class, inherits from FastAPI's BaseHTTPMiddleware
11+
Simpler to use, only need to implement the process_request method
12+
"""
13+
14+
@abstractmethod
15+
async def process_request(self, request: Request) -> Response | None:
16+
"""
17+
Method to process requests
18+
19+
Args:
20+
request: FastAPI request object
21+
22+
Returns:
23+
Response | None: If Response is returned, it will be directly returned to the client
24+
If None is returned, continue to execute the next middleware or route handler
25+
"""
26+
pass
27+
28+
async def dispatch(self, request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response:
29+
"""
30+
Middleware dispatch method, automatically called by FastAPI
31+
"""
32+
# First execute custom request processing logic
33+
response = await self.process_request(request)
34+
35+
# If a response is returned, return it directly
36+
if response is not None:
37+
return response
38+
39+
# Otherwise continue to execute the next middleware or route handler
40+
return await call_next(request)
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
2+
3+
from abc import ABC, abstractmethod
4+
from typing import Type
5+
from fastapi import FastAPI
6+
from py_spring_core.core.entities.middlewares.middleware import Middleware
7+
from py_spring_core.core.interfaces.single_inheritance_required import SingleInheritanceRequired
8+
9+
10+
11+
12+
class MiddlewareRegistry(SingleInheritanceRequired["MiddlewareRegistry"], ABC):
13+
"""
14+
Middleware registry for managing all middlewares
15+
16+
This registry pattern eliminates the need for manual middleware registration.
17+
The framework automatically handles middleware registration and execution order.
18+
19+
Multiple middleware execution order:
20+
When multiple middlewares are registered through this registry, they are automatically
21+
applied to the FastAPI application in the order they are returned by get_middleware_classes().
22+
Each middleware wraps the application, forming a stack. The last middleware added is the outermost,
23+
and the first is the innermost.
24+
25+
On the request path, the outermost middleware runs first.
26+
On the response path, it runs last.
27+
28+
For example, if get_middleware_classes() returns [MiddlewareA, MiddlewareB]
29+
This results in the following execution order:
30+
Request: MiddlewareB → MiddlewareA → route
31+
Response: route → MiddlewareA → MiddlewareB
32+
This stacking behavior ensures that middlewares are executed in a predictable and controllable order.
33+
"""
34+
35+
@abstractmethod
36+
def get_middleware_classes(self) -> list[Type[Middleware]]:
37+
"""
38+
Get all registered middleware classes
39+
40+
Returns:
41+
List[Type[Middleware]]: List of middleware classes
42+
"""
43+
pass
44+
45+
def apply_middlewares(self, app: FastAPI) -> FastAPI:
46+
"""
47+
Apply middlewares to FastAPI application
48+
49+
Args:
50+
app: FastAPI application instance
51+
52+
Returns:
53+
FastAPI: FastAPI instance with applied middlewares
54+
"""
55+
for middleware_class in self.get_middleware_classes():
56+
app.add_middleware(middleware_class)
57+
return app
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
2+
3+
from abc import ABC
4+
from typing import Generic, Optional, Type, TypeVar, cast
5+
6+
T = TypeVar('T')
7+
8+
class SingleInheritanceRequired(Generic[T], ABC):
9+
"""
10+
A singleton component is a component that only allow subclasses to be inherited.
11+
"""
12+
13+
@classmethod
14+
def check_only_one_subclass_allowed(cls) -> None:
15+
"""
16+
Check if the subclass is allowed to be inherited.
17+
"""
18+
class_dict: dict[str, Type[SingleInheritanceRequired[T]]] = {}
19+
for subclass in cls.__subclasses__():
20+
if subclass.__name__ in class_dict:
21+
continue
22+
class_dict[subclass.__name__] = subclass
23+
if len(class_dict) > 1:
24+
raise ValueError(f"Only one subclass is allowed for {cls.__name__}, but {len(class_dict)} subclasses: {[subclass.__name__ for subclass in class_dict.values()]} found")
25+
26+
@classmethod
27+
def get_subclass(cls) -> Optional[Type[T]]:
28+
"""
29+
Get the subclass of the component.
30+
"""
31+
cls.check_only_one_subclass_allowed()
32+
if len(cls.__subclasses__()) == 0:
33+
return
34+
return cast(Type[T], cls.__subclasses__()[0])

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ dev = [
7373
"isort>=5.13.2",
7474
"pytest>=8.3.2",
7575
"pytest-mock>=3.14.0",
76+
"pytest-asyncio>=1.1.0",
7677
"types-PyYAML>=6.0.12.20240917",
7778
"types-cachetools>=5.5.0.20240820",
7879
"mypy>=1.11.2"

0 commit comments

Comments
 (0)