Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion py_spring_core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,15 @@
PutMapping,
)
from py_spring_core.core.entities.entity_provider import EntityProvider
from py_spring_core.core.entities.middlewares.middleware import Middleware
from py_spring_core.core.entities.middlewares.middleware_registry import MiddlewareRegistry
from py_spring_core.core.entities.properties.properties import Properties
from py_spring_core.core.interfaces.application_context_required import ApplicationContextRequired
from py_spring_core.event.application_event_publisher import ApplicationEventPublisher
from py_spring_core.event.commons import ApplicationEvent
from py_spring_core.event.application_event_handler_registry import EventListener

__version__ = "0.0.18"
__version__ = "0.0.19"

__all__ = [
"PySpringApplication",
Expand All @@ -35,4 +37,6 @@
"ApplicationEventPublisher",
"ApplicationEvent",
"EventListener",
"Middleware",
"MiddlewareRegistry",
]
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from inspect import isclass
from typing import (
Annotated,
Any,
Callable,
Mapping,
Optional,
Expand All @@ -25,6 +26,7 @@
from py_spring_core.core.entities.bean_collection import (
BeanCollection,
BeanConflictError,
BeanView,
InvalidBeanError,
)
from py_spring_core.core.entities.component import Component, ComponentScope
Expand Down Expand Up @@ -362,7 +364,7 @@ def _inject_bean_collection_dependencies(self, bean_collection_cls: Type[BeanCol
)
self.dependency_injector.inject_dependencies(bean_collection_cls)

def _validate_bean_view(self, view, collection_name: str) -> None:
def _validate_bean_view(self, view: BeanView, collection_name: str) -> None:
"""Validate a bean view before adding it to the container."""
if view.bean_name in self.container_manager.singleton_bean_instance_container:
raise BeanConflictError(
Expand Down Expand Up @@ -564,6 +566,10 @@ def init_ioc_container(self) -> None:
# Initialize singleton beans
self.bean_manager.init_singleton_beans()

def inject_dependencies_for_external_object(self, object: Type[Any]) -> None:
"""Inject dependencies for an external object."""
self.dependency_injector.inject_dependencies(object)

def inject_dependencies_for_app_entities(self) -> None:
"""Inject dependencies for all registered app entities."""
containers: list[Mapping[str, Type[AppEntities]]] = [
Expand Down
25 changes: 20 additions & 5 deletions py_spring_core/core/application/py_spring_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from py_spring_core.core.entities.controllers.rest_controller import RestController
from py_spring_core.core.entities.controllers.route_mapping import RouteMapping
from py_spring_core.core.entities.entity_provider import EntityProvider
from py_spring_core.core.entities.middlewares.middleware_registry import MiddlewareRegistry
from py_spring_core.core.entities.properties.properties import Properties
from py_spring_core.core.interfaces.application_context_required import ApplicationContextRequired
from py_spring_core.event.application_event_handler_registry import ApplicationEventHandlerRegistry
Expand Down Expand Up @@ -212,8 +213,25 @@ def __init_controllers(self) -> None:
controller._register_decorated_routes(routes)
router = controller.get_router()
self.fastapi.include_router(router)
controller.register_middlewares()

self.__init_middlewares()
logger.debug(f"[CONTROLLER INIT] Controller {name} initialized")
def __init_middlewares(self) -> None:
logger.debug("[MIDDLEWARE INIT] Initialize middlewares...")
self_defined_registry_cls = MiddlewareRegistry.get_subclass()
if self_defined_registry_cls is None:
logger.debug("[MIDDLEWARE INIT] No self defined registry class found")
return
logger.debug(f"[MIDDLEWARE INIT] Self defined registry class: {self_defined_registry_cls.__name__}")
logger.debug(f"[MIDDLEWARE INIT] Inject dependencies for external object: {self_defined_registry_cls.__name__}")
self.app_context.inject_dependencies_for_external_object(self_defined_registry_cls)
registry = self_defined_registry_cls()

middleware_classes = registry.get_middleware_classes()
for middleware_class in middleware_classes:
logger.debug(f"[MIDDLEWARE INIT] Inject dependencies for middleware: {middleware_class.__name__}")
self.app_context.inject_dependencies_for_external_object(middleware_class)
registry.apply_middlewares(self.fastapi)
logger.debug("[MIDDLEWARE INIT] Middlewares initialized")
def __configure_uvicorn_logging(self):
"""Configure Uvicorn to use Loguru instead of default logging."""
# Configure Uvicorn to use Loguru
Expand All @@ -239,9 +257,6 @@ def emit(self, record):
logging.basicConfig(handlers=[InterceptHandler()], level=log_level, force=True)

def __run_server(self) -> None:



# Run uvicorn server
uvicorn.run(
self.fastapi,
Expand Down
5 changes: 2 additions & 3 deletions py_spring_core/core/entities/controllers/rest_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from functools import partial

from py_spring_core.core.entities.controllers.route_mapping import RouteRegistration
from py_spring_core.core.entities.middlewares.middleware import Middleware


class RestController:
Expand All @@ -15,7 +16,7 @@ class RestController:
- Providing access to the FastAPI `APIRouter` and `FastAPI` app instances
- Exposing the controller's configuration, including the URL prefix

Subclasses of `RestController` should override the `register_routes` and `register_middlewares` methods to add their own routes and middleware to the controller.
Subclasses of `RestController` should override the `register_routes` methods to add their own routes and middleware to the controller.
"""

app: FastAPI
Expand Down Expand Up @@ -53,8 +54,6 @@ def _register_decorated_routes(self, routes: Iterable[RouteRegistration]) -> Non
name=route.name,
)

def register_middlewares(self) -> None: ...

def get_router(self) -> APIRouter:
return self.router

Expand Down
40 changes: 40 additions & 0 deletions py_spring_core/core/entities/middlewares/middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from abc import abstractmethod
from typing import Awaitable, Callable
from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware



class Middleware(BaseHTTPMiddleware):
"""
Middleware base class, inherits from FastAPI's BaseHTTPMiddleware
Simpler to use, only need to implement the process_request method
"""

@abstractmethod
async def process_request(self, request: Request) -> Response | None:
"""
Method to process requests

Args:
request: FastAPI request object

Returns:
Response | None: If Response is returned, it will be directly returned to the client
If None is returned, continue to execute the next middleware or route handler
"""
pass

async def dispatch(self, request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response:
"""
Middleware dispatch method, automatically called by FastAPI
"""
# First execute custom request processing logic
response = await self.process_request(request)

# If a response is returned, return it directly
if response is not None:
return response

# Otherwise continue to execute the next middleware or route handler
return await call_next(request)
57 changes: 57 additions & 0 deletions py_spring_core/core/entities/middlewares/middleware_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@


from abc import ABC, abstractmethod
from typing import Type
from fastapi import FastAPI
from py_spring_core.core.entities.middlewares.middleware import Middleware
from py_spring_core.core.interfaces.single_inheritance_required import SingleInheritanceRequired




class MiddlewareRegistry(SingleInheritanceRequired["MiddlewareRegistry"], ABC):
"""
Middleware registry for managing all middlewares

This registry pattern eliminates the need for manual middleware registration.
The framework automatically handles middleware registration and execution order.

Multiple middleware execution order:
When multiple middlewares are registered through this registry, they are automatically
applied to the FastAPI application in the order they are returned by get_middleware_classes().
Each middleware wraps the application, forming a stack. The last middleware added is the outermost,
and the first is the innermost.

On the request path, the outermost middleware runs first.
On the response path, it runs last.

For example, if get_middleware_classes() returns [MiddlewareA, MiddlewareB]
This results in the following execution order:
Request: MiddlewareB β†’ MiddlewareA β†’ route
Response: route β†’ MiddlewareA β†’ MiddlewareB
This stacking behavior ensures that middlewares are executed in a predictable and controllable order.
"""

@abstractmethod
def get_middleware_classes(self) -> list[Type[Middleware]]:
"""
Get all registered middleware classes

Returns:
List[Type[Middleware]]: List of middleware classes
"""
pass

def apply_middlewares(self, app: FastAPI) -> FastAPI:
"""
Apply middlewares to FastAPI application

Args:
app: FastAPI application instance

Returns:
FastAPI: FastAPI instance with applied middlewares
"""
for middleware_class in self.get_middleware_classes():
app.add_middleware(middleware_class)
return app
34 changes: 34 additions & 0 deletions py_spring_core/core/interfaces/single_inheritance_required.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@


from abc import ABC
from typing import Generic, Optional, Type, TypeVar, cast

T = TypeVar('T')

class SingleInheritanceRequired(Generic[T], ABC):
"""
A singleton component is a component that only allow subclasses to be inherited.
"""

@classmethod
def check_only_one_subclass_allowed(cls) -> None:
"""
Check if the subclass is allowed to be inherited.
"""
class_dict: dict[str, Type[SingleInheritanceRequired[T]]] = {}
for subclass in cls.__subclasses__():
if subclass.__name__ in class_dict:
continue
class_dict[subclass.__name__] = subclass
if len(class_dict) > 1:
raise ValueError(f"Only one subclass is allowed for {cls.__name__}, but {len(class_dict)} subclasses: {[subclass.__name__ for subclass in class_dict.values()]} found")

@classmethod
def get_subclass(cls) -> Optional[Type[T]]:
"""
Get the subclass of the component.
"""
cls.check_only_one_subclass_allowed()
if len(cls.__subclasses__()) == 0:
return
return cast(Type[T], cls.__subclasses__()[0])
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ dev = [
"isort>=5.13.2",
"pytest>=8.3.2",
"pytest-mock>=3.14.0",
"pytest-asyncio>=1.1.0",
"types-PyYAML>=6.0.12.20240917",
"types-cachetools>=5.5.0.20240820",
"mypy>=1.11.2"
Expand Down
Loading