From c6945afb4d815fc8b86c98e85840eb1e3891b2fb Mon Sep 17 00:00:00 2001 From: William Chen Date: Thu, 3 Oct 2024 12:14:33 +0800 Subject: [PATCH 01/42] chore: update README.md --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 47bf0a5..34e2c92 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,7 @@ To get started with **PySpring**, follow these steps: ### 1. Install the **PySpring** framework by running: -`pip3 install git+https://github.com/PythonSpring/pyspring-core.git` +`pip3 install py-spring-core` ### 2. Create a new Python project and navigate to its directory @@ -38,10 +38,10 @@ To get started with **PySpring**, follow these steps: - Run the application by calling the `run()` method on the `PySpringApplication` object, as shown in the example code below: ```python -from py_spring import PySpringApplication +from py_spring_core import PySpringApplication def main(): - app = PySpringApplication() + app = PySpringApplication("./app-config.json") app.run() if __name__ == "__main__": From da4f6ecaff8e979e1680edf10490720b1469fe81 Mon Sep 17 00:00:00 2001 From: William Chen Date: Thu, 3 Oct 2024 12:18:19 +0800 Subject: [PATCH 02/42] chore: update patch version --- py_spring_core/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py_spring_core/__init__.py b/py_spring_core/__init__.py index de4cba4..85deab6 100644 --- a/py_spring_core/__init__.py +++ b/py_spring_core/__init__.py @@ -5,4 +5,4 @@ from py_spring_core.core.entities.properties.properties import Properties from py_spring_core.core.entities.entity_provider import EntityProvider -__version__ = "0.0.4" \ No newline at end of file +__version__ = "0.0.4.post1" \ No newline at end of file From 55cf568af321499b0d717473a3176a0908ad0c98 Mon Sep 17 00:00:00 2001 From: William Chen Date: Thu, 3 Oct 2024 12:38:58 +0800 Subject: [PATCH 03/42] update: Update project description --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 45bf6c5..7ece316 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "py_spring_core" dynamic = ["version"] -description = "Default template for PDM package" +description = "PySpring is a Python web framework inspired by Spring Boot, combining FastAPI, SQLModel, and Pydantic for building scalable web applications with auto dependency injection, configuration management, and a web server." authors = [ {name = "William Chen", email = "OW6201231@gmail.com"}, ] From c80e2df983981d79b15f79a7baf01731fcca6a0f Mon Sep 17 00:00:00 2001 From: William Chen Date: Thu, 3 Oct 2024 12:43:05 +0800 Subject: [PATCH 04/42] update: package version --- py_spring_core/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py_spring_core/__init__.py b/py_spring_core/__init__.py index 85deab6..12b2130 100644 --- a/py_spring_core/__init__.py +++ b/py_spring_core/__init__.py @@ -5,4 +5,4 @@ from py_spring_core.core.entities.properties.properties import Properties from py_spring_core.core.entities.entity_provider import EntityProvider -__version__ = "0.0.4.post1" \ No newline at end of file +__version__ = "0.0.4.post2" \ No newline at end of file From c9485c8dda353bc09bcf8e17bfefa6d77e8103ea Mon Sep 17 00:00:00 2001 From: William Chen Date: Fri, 4 Oct 2024 23:56:46 +0800 Subject: [PATCH 05/42] refactor: Refactor type annotations, improve return handling, and update dependencies - Added explicit return of `None` in functions where applicable. - Enhanced type annotations for better clarity and static type checking, particularly in the `JsonConfigRepository`, `ApplicationContext`, and `_PropertiesLoader` classes. - Updated `pyproject.toml`: - Added type stubs for `PyYAML` and `cachetools`. - Added `mypy` for static type checking. - Removed redundant `# type: ignore` comments where not needed. --- py_spring_core/commons/class_scanner.py | 4 ++-- .../templates.py | 4 +++- .../commons/json_config_repository.py | 18 +++++++++--------- .../application/context/application_context.py | 14 +++++++------- .../entities/properties/properties_loader.py | 4 ++-- pyproject.toml | 5 ++++- 6 files changed, 27 insertions(+), 22 deletions(-) diff --git a/py_spring_core/commons/class_scanner.py b/py_spring_core/commons/class_scanner.py index dfd4496..cb13bff 100644 --- a/py_spring_core/commons/class_scanner.py +++ b/py_spring_core/commons/class_scanner.py @@ -56,10 +56,10 @@ def import_class_from_file( ) -> Optional[Type[object]]: spec = importlib.util.spec_from_file_location(class_name, file_path) if spec is None: - return + return None module = importlib.util.module_from_spec(spec) if spec.loader is None: - return + return None spec.loader.exec_module(module) cls = getattr(module, class_name, None) return cls diff --git a/py_spring_core/commons/config_file_template_generator/templates.py b/py_spring_core/commons/config_file_template_generator/templates.py index fdb974f..b09748f 100644 --- a/py_spring_core/commons/config_file_template_generator/templates.py +++ b/py_spring_core/commons/config_file_template_generator/templates.py @@ -1,3 +1,5 @@ +from typing import Any + app_config_template = { "app_src_target_dir": "./src", "server_config": {"host": "0.0.0.0", "port": 8080, "enabled": True}, @@ -5,4 +7,4 @@ "loguru_config": {"log_file_path": "./logs/app.log", "log_level": "DEBUG"}, } -app_properties_template = {} +app_properties_template:dict[str, Any] = {} diff --git a/py_spring_core/commons/json_config_repository.py b/py_spring_core/commons/json_config_repository.py index cf5e6f3..d77824d 100644 --- a/py_spring_core/commons/json_config_repository.py +++ b/py_spring_core/commons/json_config_repository.py @@ -24,14 +24,14 @@ class JsonConfigRepository(Generic[T]): """ def __init__(self, file_path: str, target_key: Optional[str] = None) -> None: - self.base_model_cls: BaseModel = self.__class__._get_model_cls() # type: ignore + self.base_model_cls: Type[T] = self.__class__._get_model_cls() self.file_path = file_path self.target_key = target_key self._config: T = self._load_config() @classmethod def _get_model_cls(cls) -> Type[T]: - return get_args(tp=cls.__mro__[0].__orig_bases__[0])[0] + return get_args(cls.__orig_bases__[0])[0] # type: ignore def get_config(self) -> T: return self._config @@ -41,12 +41,12 @@ def reload_config(self) -> None: def save_config(self) -> None: is_the_same_class = ( - self._config.__class__.__name__ == self.base_model_cls.__name__ # type: ignore - ) # type: ignore + self._config.__class__.__name__ == self.base_model_cls.__name__ + ) if not is_the_same_class: raise TypeError( - f"[BASE MODEL CLASS TYPE MISMATCH] Base model class of current repository: {self.base_model_cls.__name__} mismatch with config class: {self._config.__class__.__name__}" # type: ignore - ) # type: ignore + f"[BASE MODEL CLASS TYPE MISMATCH] Base model class of current repository: {self.base_model_cls.__name__} mismatch with config class: {self._config.__class__.__name__}" + ) with open(self.file_path, "w") as file: file.write(self._config.model_dump_json(indent=4)) @@ -56,16 +56,16 @@ def save_config_to_target_path(self, file_path: str) -> None: def _load_config(self) -> T: with open(self.file_path, "r") as file: - if BaseModel not in self.base_model_cls.__mro__: # type: ignore + if BaseModel not in self.base_model_cls.__mro__: raise TypeError( "[BASE MODEL INHERITANCE REQUIRED] JsonConfigRepository required model class being inherited from pydantic.BaseModel for marshalling JSON into python object." ) if self.target_key is None: - return self.base_model_cls.model_validate_json(file.read()) # type: ignore + return self.base_model_cls.model_validate_json(file.read()) target_py_object = json.loads(file.read()) if self.target_key not in target_py_object: raise ValueError( f"[TARGET KEY NOT FOUND] Target key: {self.target_key} not found" ) - return self.base_model_cls.model_validate(target_py_object[self.target_key]) # type: ignore + return self.base_model_cls.model_validate(target_py_object[self.target_key]) diff --git a/py_spring_core/core/application/context/application_context.py b/py_spring_core/core/application/context/application_context.py index 9ead72e..68863a0 100644 --- a/py_spring_core/core/application/context/application_context.py +++ b/py_spring_core/core/application/context/application_context.py @@ -82,11 +82,11 @@ def as_view(self) -> ApplicationContextView: def get_component(self, component_cls: Type[T]) -> Optional[T]: if not issubclass(component_cls, Component): - return + return None component_cls_name = component_cls.get_name() if component_cls_name not in self.component_cls_container: - return + return None scope = component_cls.get_scope() match scope: @@ -94,11 +94,11 @@ def get_component(self, component_cls: Type[T]) -> Optional[T]: optional_instance = self.singleton_component_instance_container.get( component_cls_name ) - return optional_instance # type: ignore + return cast(T, optional_instance) case ComponentScope.Prototype: prototype_instance = component_cls() - return prototype_instance + return cast(T, prototype_instance) def is_within_context(self, _cls: Type[AppEntities]) -> bool: cls_name = _cls.__name__ @@ -116,15 +116,15 @@ def is_within_context(self, _cls: Type[AppEntities]) -> bool: def get_bean(self, object_cls: Type[T]) -> Optional[T]: bean_name = object_cls.__name__ if bean_name not in self.singleton_bean_instance_container: - return + return None optional_instance = self.singleton_bean_instance_container.get(bean_name) - return optional_instance # type: ignore + return cast(T, optional_instance) def get_properties(self, properties_cls: Type[PT]) -> Optional[PT]: properties_cls_name = properties_cls.get_key() if properties_cls_name not in self.properties_cls_container: - return + return None optional_instance = cast( PT, self.singleton_properties_instance_container.get(properties_cls_name) ) diff --git a/py_spring_core/core/entities/properties/properties_loader.py b/py_spring_core/core/entities/properties/properties_loader.py index 08f40a6..058e0ab 100644 --- a/py_spring_core/core/entities/properties/properties_loader.py +++ b/py_spring_core/core/entities/properties/properties_loader.py @@ -1,5 +1,5 @@ import json -from typing import Optional, Type +from typing import Callable, Optional, Type import cachetools import yaml @@ -32,7 +32,7 @@ def __init__( self.properties_classes = properties_classes self.properties_class_map = self._load_classes_as_map() - self.extension_loader_lookup = { + self.extension_loader_lookup:dict[str, Callable[[str], dict]] = { "json": json.loads, "yaml": yaml.safe_load, "yml": yaml.safe_load, diff --git a/pyproject.toml b/pyproject.toml index 7ece316..5474a4a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,7 +47,7 @@ dependencies = [ "uvloop==0.19.0", "watchfiles==0.23.0", "websockets==12.0", - "cachetools>=5.5.0", + "cachetools>=5.5.0" ] requires-python = ">=3.10" readme = "README.md" @@ -76,4 +76,7 @@ dev = [ "isort>=5.13.2", "pytest>=8.3.2", "pytest-mock>=3.14.0", + "types-PyYAML>=6.0.12.20240917", + "types-cachetools>=5.5.0.20240820", + "mypy>=1.11.2" ] From 1394b83c188422fed8cda9f5732c262abcb0e90d Mon Sep 17 00:00:00 2001 From: William Chen Date: Sat, 5 Oct 2024 08:57:56 +0800 Subject: [PATCH 06/42] feat: integrate strict mypy type checking service and refactor type hint checks - Added `TypeCheckingService` to handle Mypy type checking with configurable modes (`basic` and `strict`). - Removed `ApplicationContextTypeChecker` in favor of a unified type-checking service. - Updated `PySpringApplication` to use `TypeCheckingService` and perform type checking based on the application's config. - Introduced `type_checking_mode` in `ApplicationConfig` to allow flexibility between warning-only and strict error modes. - Updated `pyproject.toml` to include `mypy` as a project dependency. - Deleted old type hint checking utilities and related test cases. --- .../templates.py | 1 + .../commons/type_checking_service.py | 27 +++++ .../core/application/application_config.py | 6 + .../application_context_type_checker.py | 40 ------- .../core/application/py_spring_application.py | 24 ++-- py_spring_core/core/utils.py | 56 +-------- pyproject.toml | 3 +- tests/test_utils.py | 112 ------------------ 8 files changed, 49 insertions(+), 220 deletions(-) create mode 100644 py_spring_core/commons/type_checking_service.py delete mode 100644 py_spring_core/core/application/context/application_context_type_checker.py delete mode 100644 tests/test_utils.py diff --git a/py_spring_core/commons/config_file_template_generator/templates.py b/py_spring_core/commons/config_file_template_generator/templates.py index b09748f..4d49eba 100644 --- a/py_spring_core/commons/config_file_template_generator/templates.py +++ b/py_spring_core/commons/config_file_template_generator/templates.py @@ -5,6 +5,7 @@ "server_config": {"host": "0.0.0.0", "port": 8080, "enabled": True}, "properties_file_path": "./application-properties.json", "loguru_config": {"log_file_path": "./logs/app.log", "log_level": "DEBUG"}, + "type_checking_mode": "strict" } app_properties_template:dict[str, Any] = {} diff --git a/py_spring_core/commons/type_checking_service.py b/py_spring_core/commons/type_checking_service.py new file mode 100644 index 0000000..8cdadcf --- /dev/null +++ b/py_spring_core/commons/type_checking_service.py @@ -0,0 +1,27 @@ +import subprocess +from typing import Optional +from loguru import logger + + +class TypeCheckingErrorr(Exception): ... + +class TypeCheckingService: + def __init__(self, target_folder: str) -> None: + self.target_folder = target_folder + self.checking_command = ['mypy', '--disallow-untyped-defs', self.target_folder] + + def type_checking(self) -> Optional[TypeCheckingErrorr]: + logger.info("[MYPY TYPE CHECKING] Mypy checking types for projects...") + # Run mypy and capture stdout and stderr + result = subprocess.run( + self.checking_command, + capture_output=True, # Captures both stdout and stderr + text=True, # Ensures output is returned as a string + check=False # Avoids raising an exception on non-zero exit code + ) + SUCCESS = 0 + if result.returncode != SUCCESS: + error_message = f"\n{result.stdout}" + return TypeCheckingErrorr(error_message) + logger.success(f"Mypy Type Checking Passed: {result.stdout}".strip()) + return None \ No newline at end of file diff --git a/py_spring_core/core/application/application_config.py b/py_spring_core/core/application/application_config.py index ba82ec1..ec8908a 100644 --- a/py_spring_core/core/application/application_config.py +++ b/py_spring_core/core/application/application_config.py @@ -1,3 +1,4 @@ +from enum import Enum from pydantic import BaseModel, ConfigDict, Field from py_spring_core.commons.json_config_repository import ( @@ -20,6 +21,10 @@ class ServerConfig(BaseModel): port: int enabled: bool = Field(default=True) +class TypeCheckingMode(str, Enum): + """Basic will only warning the user, strict will raise an error""" + Basic = "basic" + Strict = "strict" class ApplicationConfig(BaseModel): """ @@ -39,6 +44,7 @@ class ApplicationConfig(BaseModel): server_config: ServerConfig properties_file_path: str loguru_config: LoguruConfig + type_checking_mode: TypeCheckingMode class ApplicationConfigRepository(JsonConfigRepository[ApplicationConfig]): diff --git a/py_spring_core/core/application/context/application_context_type_checker.py b/py_spring_core/core/application/context/application_context_type_checker.py deleted file mode 100644 index d10500e..0000000 --- a/py_spring_core/core/application/context/application_context_type_checker.py +++ /dev/null @@ -1,40 +0,0 @@ -from typing import Any, Iterable, Mapping, Type - -from loguru import logger -from py_spring_core.core.application.commons import AppEntities -from py_spring_core.core.application.context.application_context import ApplicationContext -from py_spring_core.core.utils import TypeHintError, check_type_hints_for_class - - -class ApplicationContextTypeChecker: - def __init__( - self, app_context: ApplicationContext, - skip_class_attrs: list[str], - target_classes: Iterable[Type[Any]], - skipped_classes: Iterable[Type[Any]] - ) -> None: - self.app_context = app_context - self.skip_class_attrs = skip_class_attrs - self.target_classes = target_classes - self.skipped_classes = skipped_classes - - - def check_type_hints_for_context(self, ctx: ApplicationContext) -> None: - containers: list[Mapping[str, Type[AppEntities]]] = [ - ctx.component_cls_container, - ctx.controller_cls_container, - ctx.bean_collection_cls_container, - ctx.properties_cls_container, - ] - for container in containers: - for _cls in container.values(): - if issubclass(_cls, tuple(self.skipped_classes)): - continue - if _cls not in self.target_classes: - try: - check_type_hints_for_class(_cls, skip_attrs=self.skip_class_attrs) - except TypeHintError as error: - logger.warning(f"Type hint error for class {_cls.__name__}: {error}") - raise error - except NameError as error: - ... \ No newline at end of file diff --git a/py_spring_core/core/application/py_spring_application.py b/py_spring_core/core/application/py_spring_application.py index aa82780..6a1862c 100644 --- a/py_spring_core/core/application/py_spring_application.py +++ b/py_spring_core/core/application/py_spring_application.py @@ -6,17 +6,15 @@ from loguru import logger from pydantic import BaseModel, ConfigDict +from py_spring_core.commons.type_checking_service import TypeCheckingService from py_spring_core.core.application.commons import AppEntities -from py_spring_core.core.application.context.application_context_type_checker import ( - ApplicationContextTypeChecker, -) from py_spring_core.core.entities.entity_provider import EntityProvider from py_spring_core.commons.class_scanner import ClassScanner from py_spring_core.commons.config_file_template_generator.config_file_template_generator import ( ConfigFileTemplateGenerator, ) from py_spring_core.commons.file_path_scanner import FilePathScanner -from py_spring_core.core.application.application_config import ApplicationConfigRepository +from py_spring_core.core.application.application_config import ApplicationConfigRepository, TypeCheckingMode from py_spring_core.core.application.context.application_context import ( ApplicationContext, ) @@ -93,11 +91,7 @@ def __init__( BeanCollection: self._handle_register_bean_collection, Properties: self._handle_register_properties, } - self.skip_class_attrs = ["Config", "model_config"] - self.app_context_typer_checker = ApplicationContextTypeChecker( - self.app_context, self.skip_class_attrs, self.classes_with_handlers.keys(), - [Properties, BaseModel] - ) + self.type_checking_service = TypeCheckingService(self.app_config.app_src_target_dir) def __configure_logging(self): """Applies the logging configuration using Loguru.""" @@ -170,7 +164,7 @@ def __init_app(self) -> None: self._register_all_entities_from_providers() self._register_app_entities(self.scanned_classes) self._register_entity_providers(self.entity_providers) - self._check_type_hints() + self._type_checking() self.app_context.load_properties() self.app_context.init_ioc_container() self.app_context.inject_dependencies_for_app_entities() @@ -181,8 +175,14 @@ def __init_app(self) -> None: self._init_providers(self.entity_providers) self._handle_singleton_components_life_cycle(ComponentLifeCycle.Init) - def _check_type_hints(self) -> None: - self.app_context_typer_checker.check_type_hints_for_context(self.app_context) + def _type_checking(self) -> None: + optional_error = self.type_checking_service.type_checking() + if optional_error is not None: + match (self.app_config.type_checking_mode): + case TypeCheckingMode.Strict: + raise optional_error + case TypeCheckingMode.Basic: + logger.warning(optional_error) def _handle_singleton_components_life_cycle( self, life_cycle: ComponentLifeCycle diff --git a/py_spring_core/core/utils.py b/py_spring_core/core/utils.py index 629cea1..bbc6e8f 100644 --- a/py_spring_core/core/utils.py +++ b/py_spring_core/core/utils.py @@ -1,7 +1,7 @@ import importlib.util import inspect from pathlib import Path -from typing import Any, Callable, Iterable, Type, get_type_hints +from typing import Iterable, Type from loguru import logger @@ -75,57 +75,3 @@ def dynamically_import_modules( returned_target_classes.add(loaded_class) return returned_target_classes - - -class TypeHintError(Exception): ... - - -def check_type_hints_for_callable(func: Callable[..., Any]) -> None: - RETURN_ID = "return" - func_qualname_list = func.__qualname__.split(".") - is_class_callable = True if len(func_qualname_list) == 2 else False - class_name = func_qualname_list[0] if is_class_callable else "" - - func_name = func.__name__ - args_type_hints = get_type_hints(func) - arg_type = inspect.getargs(func.__code__) - code_path = inspect.getsourcefile(func) - if RETURN_ID not in args_type_hints: - raise TypeHintError( - f"Type hints for 'return type' not provided for the function: {class_name}.{func_name}, path: {code_path}" - ) - - # plue one is for return type, return type is not included in co_argcount if it is a simple function, - # for member functions, self is included in co_varnames, but not in type hints, so plus 0 - arguments = arg_type.args - argument_count = len(arguments) - if argument_count == 0: - return - if len(args_type_hints) == 0: - raise TypeHintError( - f"Type hints not provided for the function: {class_name}.{func_name}, arguments: {arguments}, current type hints: {args_type_hints}, path: {code_path}" - ) - - if len(args_type_hints) != argument_count: - missing_type_hints_args = [ - arg for arg in arguments if arg not in args_type_hints and arg != "self" - ] - if len(missing_type_hints_args) == 0: - return - raise TypeHintError( - f"Type hints not fully provided: {class_name}.{func_name}, arguments: {arguments}, current type hints: {args_type_hints}, missingg type hints: {','.join(missing_type_hints_args)}, path: {code_path}" - ) - - -def check_type_hints_for_class(_cls: Type[Any], skip_attrs: list[str] = list()) -> None: - for attr in dir(_cls): - if attr in skip_attrs: - continue - if attr.startswith("__"): - continue - attr_obj = getattr(_cls, attr) - if not callable(attr_obj): - continue - if not hasattr(attr_obj, "__annotations__"): - continue - check_type_hints_for_callable(attr_obj) diff --git a/pyproject.toml b/pyproject.toml index 5474a4a..2e5b78f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,7 +47,8 @@ dependencies = [ "uvloop==0.19.0", "watchfiles==0.23.0", "websockets==12.0", - "cachetools>=5.5.0" + "cachetools>=5.5.0", + "mypy>=1.11.2" ] requires-python = ">=3.10" readme = "README.md" diff --git a/tests/test_utils.py b/tests/test_utils.py deleted file mode 100644 index b11b749..0000000 --- a/tests/test_utils.py +++ /dev/null @@ -1,112 +0,0 @@ -import pytest -from py_spring_core.core.utils import ( - check_type_hints_for_callable, - TypeHintError, - check_type_hints_for_class, -) - - -class TestCheckTypeHintsForCallable: - def test_function_with_argument_type_hints_but_no_return_type(self): - def test_func(a: int, b: str): - pass # No return type hint - - with pytest.raises( - TypeHintError, - match="Type hints for 'return type' not provided for the function", - ): - check_type_hints_for_callable(test_func) - - def test_function_with_no_arguments(self): - def test_func() -> None: - pass - - # No exception should be raised for a function with no arguments - check_type_hints_for_callable(test_func) - - def test_function_with_correct_type_hints(self): - def test_func(a: int, b: str) -> None: - pass - - # No exception should be raised - check_type_hints_for_callable(test_func) - - def test_function_with_no_type_hints(self): - def test_func(a, b) -> None: - pass - - with pytest.raises(TypeHintError, match="Type hints not fully provided"): - check_type_hints_for_callable(test_func) - - def test_function_with_mismatched_type_hints(self): - def test_func(a: int, b, c) -> None: - pass - - # Intentionally using only one type hint - - with pytest.raises(TypeHintError, match="Type hints not fully provided"): - check_type_hints_for_callable(test_func) - - -class TestClassFullyTyped: - def method(self, a: int, b: str) -> bool: - return True - - -class TestClassNotReturnTyped: - def method(self, a: int, b: str): - pass # No return type hint - - -class TestClassWithNoArgs: - def method(self) -> bool: - return True - - -class TestClassNotFullyTyped: - def method(self, a, b: str) -> bool: - return True - - -class TestClassWithNotArsAndReturnTyped: - def method(self): - pass # No type hints at all - - -class TestClassWithVariableInside: - def method(self) -> None: - test = "" - pass # No type hints at all - - -class TestCheckTypeHintsForClass: - def test_class_with_method_and_return_type_hints(self): - # No exception should be raised - check_type_hints_for_class(TestClassFullyTyped) - - def test_class_with_method_missing_return_type_hint(self): - with pytest.raises( - TypeHintError, - match="Type hints for 'return type' not provided for the function", - ): - check_type_hints_for_class(TestClassNotReturnTyped) - - def test_class_with_method_missing_argument_type_hint(self): - with pytest.raises(TypeHintError, match="Type hints not fully provided"): - check_type_hints_for_class(TestClassNotFullyTyped) - - def test_class_with_method_having_no_arguments_but_return_type_hint(self): - # No exception should be raised - check_type_hints_for_class(TestClassWithNoArgs) - - def test_class_with_method_having_no_type_hints(self): - with pytest.raises( - TypeHintError, - match="Type hints for 'return type' not provided for the function", - ): - check_type_hints_for_class(TestClassWithNotArsAndReturnTyped) - - def test_class_with_method_having_variable_inside(self): - check_type_hints_for_class(TestClassWithVariableInside) - - From 94d91e61221e611d46314920b3d7118a668e0141 Mon Sep 17 00:00:00 2001 From: William Chen Date: Sat, 5 Oct 2024 09:03:40 +0800 Subject: [PATCH 07/42] packaging: update package version --- py_spring_core/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py_spring_core/__init__.py b/py_spring_core/__init__.py index 12b2130..4b033cc 100644 --- a/py_spring_core/__init__.py +++ b/py_spring_core/__init__.py @@ -5,4 +5,4 @@ from py_spring_core.core.entities.properties.properties import Properties from py_spring_core.core.entities.entity_provider import EntityProvider -__version__ = "0.0.4.post2" \ No newline at end of file +__version__ = "0.0.5" \ No newline at end of file From 7b39f87de29edd485c84ec3de9c96f8dd54553e0 Mon Sep 17 00:00:00 2001 From: William Chen Date: Thu, 10 Oct 2024 15:24:39 +0800 Subject: [PATCH 08/42] refactor: Refactor TypeCheckingService to enhance error logging and add targeted type checking - Introduced `MypyTypeCheckingError` enum for more structured and manageable mypy error types - Updated `TypeCheckingService` to check and log specific mypy typing errors using `target_typing_errors` - Refined the error handling logic to accumulate and return detailed messages for relevant mypy errors - Improved logging of both stderr and stdout during mypy execution for better debugging insights --- .../commons/type_checking_service.py | 26 ++++++++++++++----- 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/py_spring_core/commons/type_checking_service.py b/py_spring_core/commons/type_checking_service.py index 8cdadcf..5156df7 100644 --- a/py_spring_core/commons/type_checking_service.py +++ b/py_spring_core/commons/type_checking_service.py @@ -1,14 +1,21 @@ +from enum import Enum import subprocess from typing import Optional from loguru import logger +class MypyTypeCheckingError(str, Enum): + NoUntypedDefs = "no-untyped-def" + class TypeCheckingErrorr(Exception): ... class TypeCheckingService: def __init__(self, target_folder: str) -> None: self.target_folder = target_folder self.checking_command = ['mypy', '--disallow-untyped-defs', self.target_folder] + self.target_typing_errors: list[MypyTypeCheckingError] = [ + MypyTypeCheckingError.NoUntypedDefs + ] def type_checking(self) -> Optional[TypeCheckingErrorr]: logger.info("[MYPY TYPE CHECKING] Mypy checking types for projects...") @@ -19,9 +26,16 @@ def type_checking(self) -> Optional[TypeCheckingErrorr]: text=True, # Ensures output is returned as a string check=False # Avoids raising an exception on non-zero exit code ) - SUCCESS = 0 - if result.returncode != SUCCESS: - error_message = f"\n{result.stdout}" - return TypeCheckingErrorr(error_message) - logger.success(f"Mypy Type Checking Passed: {result.stdout}".strip()) - return None \ No newline at end of file + std_err_lines = result.stderr.split("\n") + std_out_lines = result.stdout.split("\n") + message: str = "" + for line in [*std_err_lines, *std_out_lines]: + if len(line) != 0: + logger.debug(f"[MYPY TYPE CHECKING] {line}") + for error in self.target_typing_errors: + if error in line: + error_message = f"\n{line}" + message += error_message + if len(message) == 0: + return None + return TypeCheckingErrorr(message) \ No newline at end of file From 3d927eea41a40de84056ea2ef5bf97cc889d4542 Mon Sep 17 00:00:00 2001 From: William Chen Date: Thu, 10 Oct 2024 15:25:47 +0800 Subject: [PATCH 09/42] update: Update framwork version --- py_spring_core/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py_spring_core/__init__.py b/py_spring_core/__init__.py index 4b033cc..69cfd7b 100644 --- a/py_spring_core/__init__.py +++ b/py_spring_core/__init__.py @@ -5,4 +5,4 @@ from py_spring_core.core.entities.properties.properties import Properties from py_spring_core.core.entities.entity_provider import EntityProvider -__version__ = "0.0.5" \ No newline at end of file +__version__ = "0.0.6" \ No newline at end of file From 8cd302af98379e5026e598d199f71577ebba9db5 Mon Sep 17 00:00:00 2001 From: William Chen <86595028+NFUChen@users.noreply.github.com> Date: Mon, 6 Jan 2025 19:07:30 +0800 Subject: [PATCH 10/42] Update README.md --- README.md | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/README.md b/README.md index 34e2c92..a5bd0e2 100644 --- a/README.md +++ b/README.md @@ -5,8 +5,6 @@ # Key Features - Application Initialization: **PySpringApplication** class serves as the main entry point for the **PySpring** application. It initializes the application from a configuration file, scans the application source directory for Python files, and groups them into class files and model files -- **Model Import and Table Creation**: **PySpring** dynamically imports model modules and creates SQLModel tables based on the imported models. It supports SQLAlchemy for database operations. - - **Application Context Management**: **PySpring** manages the application context and dependency injection. It registers application entities such as components, controllers, bean collections, and properties. It also initializes the application context and injects dependencies. - **REST Controllers**: **PySpring** supports RESTful API development using the RestController class. It allows you to define routes, handle HTTP requests, and register middlewares easily. @@ -51,4 +49,4 @@ if __name__ == "__main__": # Contributing -Contributions to **PySpring** are welcome! If you find any issues or have suggestions for improvements, please submit a pull request or open an issue on GitHub. \ No newline at end of file +Contributions to **PySpring** are welcome! If you find any issues or have suggestions for improvements, please submit a pull request or open an issue on GitHub. From cc2370fa983bb04c1fc871fd7cb7ddf3546bfc6a Mon Sep 17 00:00:00 2001 From: William Chen <86595028+NFUChen@users.noreply.github.com> Date: Mon, 6 Jan 2025 22:51:28 +0800 Subject: [PATCH 11/42] Create LICENSE --- LICENSE | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) create mode 100644 LICENSE diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..c3bc92f --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025 PythonSpring + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. From 1cc1aa80fcb93f59fe3a3c31b4ff555627f8b1c2 Mon Sep 17 00:00:00 2001 From: William Chen <86595028+NFUChen@users.noreply.github.com> Date: Sun, 15 Jun 2025 01:57:08 +0800 Subject: [PATCH 12/42] feat: Add Support for Qualifiers (#1) --- README.md | 93 +++++++-- py_spring_core/__init__.py | 2 +- .../context/application_context.py | 87 ++++++-- .../core/application/py_spring_application.py | 18 -- py_spring_core/core/entities/component.py | 3 + pyproject.toml | 2 +- tests/test_application_context.py | 4 +- tests/test_component_features.py | 189 ++++++++++++++++++ 8 files changed, 344 insertions(+), 54 deletions(-) create mode 100644 tests/test_component_features.py diff --git a/README.md b/README.md index a5bd0e2..aa1cbfd 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,9 @@ # **PySpring** Framework -#### **PySpring** is a Python web framework inspired by Spring Boot. It combines FastAPI for the web layer, SQLModel for ORM, and Pydantic for data validation. PySpring provides a structured approach to building scalable web applications with `auto dependency injection`, `auto configuration management` and a `web server` for hosting your application. +#### **PySpring** is a Python web framework inspired by Spring Boot. It combines FastAPI for the web layer, and Pydantic for data validation. PySpring provides a structured approach to building scalable web applications with `auto dependency injection`, `auto configuration management` and a `web server` for hosting your application. -# Key Features -- Application Initialization: **PySpringApplication** class serves as the main entry point for the **PySpring** application. It initializes the application from a configuration file, scans the application source directory for Python files, and groups them into class files and model files +## Key Features +- **Application Initialization**: `PySpringApplication` class serves as the main entry point for the **PySpring** application. It initializes the application from a configuration file, scans the application source directory for Python files, and groups them into class files and model files. - **Application Context Management**: **PySpring** manages the application context and dependency injection. It registers application entities such as components, controllers, bean collections, and properties. It also initializes the application context and injects dependencies. @@ -17,24 +17,40 @@ - **Builtin FastAPI Integration**: **PySpring** integrates with `FastAPI`, a modern, fast (high-performance), web framework for building APIs with Python. It leverages FastAPI's features for routing, request handling, and server configuration. -# Getting Started -To get started with **PySpring**, follow these steps: - -### 1. Install the **PySpring** framework by running: - -`pip3 install py-spring-core` - +## Project Structure +``` +PySpring/ +├── src/ # Source code directory +├── tests/ # Test files +├── logs/ # Application logs +├── py_spring_core/ # Core framework package +├── app-config.json # Application configuration +├── application-properties.json # Application properties +├── main.py # Application entry point +├── pyproject.toml # Project metadata and dependencies +└── README.md # Project documentation +``` -### 2. Create a new Python project and navigate to its directory +## Getting Started -- Implement your application properties, components, controllers, using **PySpring** conventions inside declared source code folder (whcih can be modified the key `app_src_target_dir` inside app-config.json), this controls what folder will be scanned by the framework. +### Prerequisites +- Python 3.10 or higher +- pip (Python package installer) -- Instantiate a `PySpringApplication` object in your main script, passing the path to your application configuration file. +### Installation +1. Install the **PySpring** framework: +```bash +pip install py-spring-core +``` -- Optionally, define and enable any framework modules you want to use. +2. Create a new Python project and navigate to its directory -- Run the application by calling the `run()` method on the `PySpringApplication` object, as shown in the example code below: +3. Set up your application: + - Implement your application properties, components, and controllers using **PySpring** conventions inside the declared source code folder (which can be modified via the `app_src_target_dir` key in `app-config.json`) + - Create an `app-config.json` file for your application configuration + - Create an `application-properties.json` file for your application properties +4. Create your main application script: ```python from py_spring_core import PySpringApplication @@ -45,8 +61,49 @@ def main(): if __name__ == "__main__": main() ``` -- For example project, please refer to this [github repo](https://github.com/NFUChen/PySpring-Example-Project). -# Contributing +5. Run your application: +```bash +python main.py +``` + +For a complete example project, please refer to the [PySpring Example Project](https://github.com/NFUChen/PySpring-Example-Project). + +## Development Setup +1. Clone the repository: +```bash +git clone https://github.com/NFUChen/PySpring.git +cd PySpring +``` + +2. Install development dependencies: +```bash +pip install -e ".[dev]" +``` + +3. Run tests: +```bash +pytest +``` + +## Dependencies +PySpring relies on several key dependencies: +- FastAPI (0.112.0) +- Pydantic (2.8.2) +- Uvicorn (0.30.5) +- Loguru (0.7.2) +- And other supporting packages + +For a complete list of dependencies, see `pyproject.toml`. + +## Contributing + +Contributions to **PySpring** are welcome! If you find any issues or have suggestions for improvements, please: +1. Fork the repository +2. Create a feature branch +3. Commit your changes +4. Push to the branch +5. Create a Pull Request -Contributions to **PySpring** are welcome! If you find any issues or have suggestions for improvements, please submit a pull request or open an issue on GitHub. +## License +This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details. diff --git a/py_spring_core/__init__.py b/py_spring_core/__init__.py index 69cfd7b..f026e4c 100644 --- a/py_spring_core/__init__.py +++ b/py_spring_core/__init__.py @@ -5,4 +5,4 @@ from py_spring_core.core.entities.properties.properties import Properties from py_spring_core.core.entities.entity_provider import EntityProvider -__version__ = "0.0.6" \ No newline at end of file +__version__ = "0.0.7" \ No newline at end of file diff --git a/py_spring_core/core/application/context/application_context.py b/py_spring_core/core/application/context/application_context.py index 68863a0..cbd9b5c 100644 --- a/py_spring_core/core/application/context/application_context.py +++ b/py_spring_core/core/application/context/application_context.py @@ -1,5 +1,6 @@ +from abc import ABC from inspect import isclass -from typing import Callable, Mapping, Optional, Type, TypeVar, cast +from typing import Annotated, Callable, Mapping, Optional, Type, TypeVar, cast, get_origin, get_args from loguru import logger from pydantic import BaseModel @@ -79,20 +80,54 @@ def as_view(self) -> ApplicationContextView: self.singleton_component_instance_container.keys() ), ) + + def _determine_target_cls_name(self, component_cls: Type[T], qualifier: Optional[str]) -> str: + """ + Determine the target class name for a given component class. + This method handles the following cases: + 1. If a qualifier is provided, return it directly. + 2. If the component class is not an ABC, return its name directly. + 3. If the component class is an ABC but has implementations, return its name directly. + 4. If the component class is an ABC and has no implementations, return the name of the first subclass. + 5. If the component class is an ABC and has multiple implementations, raise an error. + """ - def get_component(self, component_cls: Type[T]) -> Optional[T]: - if not issubclass(component_cls, Component): + if qualifier is not None: + return qualifier + + # If it's not an ABC, return its name directly + if not issubclass(component_cls, ABC): + return component_cls.get_name() + + # If it's an ABC but has implementations, return its name directly + if not component_cls.__abstractmethods__: + return component_cls.get_name() + + # For abstract classes that need implementations + subclasses = component_cls.__subclasses__() + if len(subclasses) == 0: + raise ValueError( + f"[ABSTRACT CLASS ERROR] Abstract class {component_cls.__name__} has no subclasses" + ) + + + # Fall back to first subclass if no primary component exists + return subclasses[0].get_name() + + def get_component(self, component_cls: Type[T], qualifier: Optional[str]) -> Optional[T]: + if not issubclass(component_cls, (Component, ABC)): return None - component_cls_name = component_cls.get_name() - if component_cls_name not in self.component_cls_container: + target_cls_name: str = self._determine_target_cls_name(component_cls, qualifier) + + if target_cls_name not in self.component_cls_container: return None scope = component_cls.get_scope() match scope: case ComponentScope.Singleton: optional_instance = self.singleton_component_instance_container.get( - component_cls_name + target_cls_name ) return cast(T, optional_instance) @@ -113,7 +148,7 @@ def is_within_context(self, _cls: Type[AppEntities]) -> bool: or is_within_properties ) - def get_bean(self, object_cls: Type[T]) -> Optional[T]: + def get_bean(self, object_cls: Type[T], qualifier: Optional[str]) -> Optional[T]: bean_name = object_cls.__name__ if bean_name not in self.singleton_bean_instance_container: return None @@ -135,8 +170,9 @@ def register_component(self, component_cls: Type[Component]) -> None: raise TypeError( f"[COMPONENT REGISTRATION ERROR] Component: {component_cls} is not a subclass of Component" ) - component_cls_name = component_cls.get_name() + if component_cls_name in self.component_cls_container: + raise ValueError(f"[COMPONENT REGISTRATION ERROR] Component: {component_cls_name} already registered") self.component_cls_container[component_cls_name] = component_cls def register_controller(self, controller_cls: Type[RestController]) -> None: @@ -210,6 +246,20 @@ def load_properties(self) -> None: self.singleton_properties_instance_container ) + def init_singleton_component(self, component_cls: Type[Component], component_cls_name: str) -> Optional[Component]: + instance: Optional[Component] = None + try: + instance = component_cls() + except Exception as error: + unable_to_init_component_error_prefix = "Can't instantiate abstract class" + if unable_to_init_component_error_prefix in str(error): + logger.warning(f"[INITIALIZING SINGLETON COMPONENT ERROR] Skip initializing singleton component: {component_cls_name} because it is an abstract class") + return + logger.error(f"[INITIALIZING SINGLETON COMPONENT ERROR] Error initializing singleton component: {component_cls_name} with error: {error}") + raise error + + return instance + def init_ioc_container(self) -> None: """ Initializes the IoC (Inversion of Control) container by creating singleton instances of all registered components. @@ -225,7 +275,9 @@ def init_ioc_container(self) -> None: logger.debug( f"[INITIALIZING SINGLETON COMPONENT] Init singleton component: {component_cls_name}" ) - instance = component_cls() + instance = self.init_singleton_component(component_cls, component_cls_name) + if instance is None: + continue self.singleton_component_instance_container[component_cls_name] = instance # for Bean @@ -254,15 +306,21 @@ def init_ioc_container(self) -> None: def _inject_entity_dependencies(self, entity: Type[AppEntities]) -> None: for attr_name, annotated_entity_cls in entity.__annotations__.items(): is_injected: bool = False + # Handle Annotated types + qualifier: Optional[str] = None + if get_origin(annotated_entity_cls) is Annotated: + annotated_entity_cls, qualifier_found = get_args(annotated_entity_cls) + if qualifier_found: + qualifier = qualifier_found if annotated_entity_cls in self.primitive_types: logger.warning( f"[DEPENDENCY INJECTION SKIPPED] Skip inject dependency for attribute: {attr_name} with dependency: {annotated_entity_cls.__name__} because it is primitive type" ) continue - if not isclass(annotated_entity_cls): continue + if issubclass(annotated_entity_cls, Properties): optional_properties = self.get_properties(annotated_entity_cls) if optional_properties is None: @@ -272,10 +330,12 @@ def _inject_entity_dependencies(self, entity: Type[AppEntities]) -> None: setattr(entity, attr_name, optional_properties) continue - entity_getters: list[Callable] = [self.get_component, self.get_bean] + entity_getters: list[Callable[[Type[AppEntities], Optional[str]], Optional[AppEntities]]] = [ + self.get_component, self.get_bean + ] for getter in entity_getters: - optional_entity = getter(annotated_entity_cls) + optional_entity = getter(annotated_entity_cls, qualifier) if optional_entity is not None: setattr(entity, attr_name, optional_entity) is_injected = True @@ -286,8 +346,7 @@ def _inject_entity_dependencies(self, entity: Type[AppEntities]) -> None: f"[DEPENDENCY INJECTION SUCCESS FROM COMPONENT CONTAINER] Inject dependency for {annotated_entity_cls.__name__} in attribute: {attr_name} with dependency: {annotated_entity_cls.__name__} singleton instance" ) continue - - error_message = f"[DEPENDENCY INJECTION FAILED] Fail to inject dependency for attribute: {attr_name} with dependency: {annotated_entity_cls.__name__}, consider register such depency with Compoent decorator" + error_message = f"[DEPENDENCY INJECTION FAILED] Fail to inject dependency for attribute: {attr_name} with dependency: {annotated_entity_cls.__name__} with qualifier: {qualifier}, consider register such depency with Compoent decorator" logger.critical(error_message) raise ValueError(error_message) diff --git a/py_spring_core/core/application/py_spring_application.py b/py_spring_core/core/application/py_spring_application.py index 6a1862c..5ecceee 100644 --- a/py_spring_core/core/application/py_spring_application.py +++ b/py_spring_core/core/application/py_spring_application.py @@ -27,12 +27,6 @@ from py_spring_core.core.entities.properties.properties import Properties -class ApplicationFileGroups(BaseModel): - model_config = ConfigDict(protected_namespaces=()) - class_files: set[str] - model_files: set[str] - - class PySpringApplication: """ The PySpringApplication class is the main entry point for the PySpring application. @@ -40,8 +34,6 @@ class PySpringApplication: The class performs the following key tasks: - Initializes the application from a configuration file path - - Scans the application source directory for Python files and groups them into class files and model files - - Dynamically imports the model modules and creates SQLModel tables - Registers application entities (components, controllers, bean collections, properties) with the application context - Initializes the application context and injects dependencies - Handles the lifecycle of singleton components @@ -164,7 +156,6 @@ def __init_app(self) -> None: self._register_all_entities_from_providers() self._register_app_entities(self.scanned_classes) self._register_entity_providers(self.entity_providers) - self._type_checking() self.app_context.load_properties() self.app_context.init_ioc_container() self.app_context.inject_dependencies_for_app_entities() @@ -175,15 +166,6 @@ def __init_app(self) -> None: self._init_providers(self.entity_providers) self._handle_singleton_components_life_cycle(ComponentLifeCycle.Init) - def _type_checking(self) -> None: - optional_error = self.type_checking_service.type_checking() - if optional_error is not None: - match (self.app_config.type_checking_mode): - case TypeCheckingMode.Strict: - raise optional_error - case TypeCheckingMode.Basic: - logger.warning(optional_error) - def _handle_singleton_components_life_cycle( self, life_cycle: ComponentLifeCycle ) -> None: diff --git a/py_spring_core/core/entities/component.py b/py_spring_core/core/entities/component.py index d9b7e8f..3a3aba0 100644 --- a/py_spring_core/core/entities/component.py +++ b/py_spring_core/core/entities/component.py @@ -41,10 +41,13 @@ class Component: """ class Config: + name: str = "" scope: ComponentScope = ComponentScope.Singleton @classmethod def get_name(cls) -> str: + if hasattr(cls.Config, "name") and cls.Config.name: + return cls.Config.name return cls.__name__ @classmethod diff --git a/pyproject.toml b/pyproject.toml index 2e5b78f..32216c2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "py_spring_core" dynamic = ["version"] -description = "PySpring is a Python web framework inspired by Spring Boot, combining FastAPI, SQLModel, and Pydantic for building scalable web applications with auto dependency injection, configuration management, and a web server." +description = "PySpring is a Python web framework inspired by Spring Boot, combining FastAPI, and Pydantic for building scalable web applications with auto dependency injection, configuration management, and a web server." authors = [ {name = "William Chen", email = "OW6201231@gmail.com"}, ] diff --git a/tests/test_application_context.py b/tests/test_application_context.py index 1442199..6f4b71b 100644 --- a/tests/test_application_context.py +++ b/tests/test_application_context.py @@ -126,7 +126,7 @@ class TestProperties(Properties): app_context.singleton_component_instance_container["TestComponent"] = ( component_instance ) - retrieved_component = app_context.get_component(TestComponent) + retrieved_component = app_context.get_component(TestComponent, None) assert retrieved_component is component_instance # Test retrieving singleton beans @@ -134,7 +134,7 @@ class TestProperties(Properties): app_context.singleton_bean_instance_container["TestBeanCollection"] = ( bean_instance ) - retrieved_bean = app_context.get_bean(TestBeanCollection) + retrieved_bean = app_context.get_bean(TestBeanCollection, None) assert retrieved_bean is bean_instance # Test retrieving singleton properties diff --git a/tests/test_component_features.py b/tests/test_component_features.py new file mode 100644 index 0000000..116a98a --- /dev/null +++ b/tests/test_component_features.py @@ -0,0 +1,189 @@ +from abc import ABC +import pytest +from typing import Annotated + +from py_spring_core.core.application.context.application_context import ( + ApplicationContext, + ApplicationContextConfig +) +from py_spring_core.core.entities.component import Component +from py_spring_core.core.entities.component import ComponentScope + + +class TestComponentFeatures: + """Test suite for component features including primary components, qualifiers, and registration validation.""" + + @pytest.fixture + def app_context(self): + """Fixture that provides a fresh ApplicationContext instance for each test.""" + config = ApplicationContextConfig(properties_path="") + return ApplicationContext(config) + + def test_qualifier_based_injection(self, app_context: ApplicationContext): + """ + Test the qualifier-based dependency injection mechanism. + + This test verifies that: + 1. Multiple implementations of an abstract component can coexist + 2. Specific implementations can be injected using qualifiers + 3. Both primary and non-primary components can be injected using qualifiers + 4. The correct implementation is injected for each qualifier + + The test creates an abstract service with two implementations and verifies + that each can be injected into a consumer using appropriate qualifiers. + """ + # Define abstract base class + class AbstractService(Component): + class Config: + is_primary = True + scope = ComponentScope.Singleton + + def process(self) -> str: + raise NotImplementedError() + + # Define implementations + class ServiceA(AbstractService): + class Config: + is_primary = True + name = "ServiceA" + scope = ComponentScope.Singleton + + def process(self) -> str: + return "Service A processing" + + class ServiceB(AbstractService): + class Config: + is_primary = False + name = "ServiceB" + scope = ComponentScope.Singleton + + def process(self) -> str: + return "Service B processing" + + # Register implementations + app_context.register_component(ServiceA) + app_context.register_component(ServiceB) + app_context.init_ioc_container() + + # Test qualifier-based injection + class ServiceConsumer(Component): + service_a: Annotated[AbstractService, "ServiceA"] + service_b: Annotated[AbstractService, "ServiceB"] + + def post_construct(self) -> None: + assert isinstance(self.service_a, ServiceA) + assert isinstance(self.service_b, ServiceB) + assert self.service_a.process() == "Service A processing" + assert self.service_b.process() == "Service B processing" + + app_context.register_component(ServiceConsumer) + app_context.init_ioc_container() # Initialize the consumer component + app_context.inject_dependencies_for_app_entities() + + def test_duplicate_component_registration(self, app_context: ApplicationContext): + """ + Test the prevention of duplicate component registration. + + This test verifies that: + 1. A component can only be registered once + 2. Attempting to register the same component again raises an error + 3. The error message clearly indicates the duplicate registration + + The test attempts to register the same component twice and verifies + that an appropriate error is raised. + """ + # Define a component + class TestService(Component): + class Config: + name = "TestService" + scope = ComponentScope.Singleton + + def process(self) -> str: + return "Test service processing" + + # Register component first time + app_context.register_component(TestService) + app_context.init_ioc_container() + + # Attempt to register same component again should raise error + with pytest.raises(ValueError, match="\\[COMPONENT REGISTRATION ERROR\\] Component: TestService already registered"): + app_context.register_component(TestService) + app_context.init_ioc_container() + + def test_component_name_override(self, app_context: ApplicationContext): + """ + Test the ability to override component names during registration. + + This test verifies that: + 1. Components can be registered with custom names + 2. The custom name is correctly stored in the component container + 3. The component can be retrieved using the custom name + + The test registers a component with a custom name and verifies + that it is correctly stored in the container. + """ + # Define component with custom name + class TestService(Component): + class Config: + name = "CustomServiceName" + scope = ComponentScope.Singleton + + def process(self) -> str: + return "Test service processing" + + # Register component + app_context.register_component(TestService) + app_context.init_ioc_container() + + # Verify component is registered with custom name + assert "CustomServiceName" in app_context.component_cls_container + assert app_context.component_cls_container["CustomServiceName"] == TestService + + def test_qualifier_with_invalid_component(self, app_context: ApplicationContext): + """ + Test error handling for invalid qualifier usage. + + This test verifies that: + 1. Attempting to inject a component with an invalid qualifier raises an error + 2. The error message clearly indicates the invalid qualifier + 3. The error occurs during dependency injection + + The test attempts to inject a component using a non-existent qualifier + and verifies that an appropriate error is raised. + """ + # Define abstract base class + class AbstractService(Component): + class Config: + is_primary = True + scope = ComponentScope.Singleton + + def process(self) -> str: + raise NotImplementedError() + + # Define implementation + class TestService(AbstractService): + class Config: + is_primary = True + name = "TestService" + scope = ComponentScope.Singleton + + def process(self) -> str: + return "Test service processing" + + # Register implementation + app_context.register_component(TestService) + app_context.init_ioc_container() + + # Test injection with invalid qualifier + class ServiceConsumer(Component): + service: Annotated[AbstractService, "NonExistentService"] + + def post_construct(self) -> None: + pass + + app_context.register_component(ServiceConsumer) + app_context.init_ioc_container() # Initialize the consumer component + + # Attempting to inject with invalid qualifier should raise error + with pytest.raises(ValueError, match="\\[DEPENDENCY INJECTION FAILED\\] Fail to inject dependency for attribute: service with dependency: AbstractService with qualifier: NonExistentService"): + app_context.inject_dependencies_for_app_entities() \ No newline at end of file From 5419a557735b9b8e3669ba4e7eff19d80cb7efd2 Mon Sep 17 00:00:00 2001 From: William Chen <86595028+NFUChen@users.noreply.github.com> Date: Sun, 15 Jun 2025 11:28:32 +0800 Subject: [PATCH 13/42] Dependency Updates and Configuration Cleanup (#2) --- py_spring_core/__init__.py | 2 +- py_spring_core/core/application/application_config.py | 7 ------- .../core/application/py_spring_application.py | 2 +- .../core/entities/properties/properties_loader.py | 4 ++-- pyproject.toml | 11 ++++------- 5 files changed, 8 insertions(+), 18 deletions(-) diff --git a/py_spring_core/__init__.py b/py_spring_core/__init__.py index f026e4c..2c3ead3 100644 --- a/py_spring_core/__init__.py +++ b/py_spring_core/__init__.py @@ -5,4 +5,4 @@ from py_spring_core.core.entities.properties.properties import Properties from py_spring_core.core.entities.entity_provider import EntityProvider -__version__ = "0.0.7" \ No newline at end of file +__version__ = "0.0.8" \ No newline at end of file diff --git a/py_spring_core/core/application/application_config.py b/py_spring_core/core/application/application_config.py index ec8908a..ac9e2c9 100644 --- a/py_spring_core/core/application/application_config.py +++ b/py_spring_core/core/application/application_config.py @@ -21,11 +21,6 @@ class ServerConfig(BaseModel): port: int enabled: bool = Field(default=True) -class TypeCheckingMode(str, Enum): - """Basic will only warning the user, strict will raise an error""" - Basic = "basic" - Strict = "strict" - class ApplicationConfig(BaseModel): """ Represents the configuration for the application. @@ -44,8 +39,6 @@ class ApplicationConfig(BaseModel): server_config: ServerConfig properties_file_path: str loguru_config: LoguruConfig - type_checking_mode: TypeCheckingMode - class ApplicationConfigRepository(JsonConfigRepository[ApplicationConfig]): """ diff --git a/py_spring_core/core/application/py_spring_application.py b/py_spring_core/core/application/py_spring_application.py index 5ecceee..cf956ce 100644 --- a/py_spring_core/core/application/py_spring_application.py +++ b/py_spring_core/core/application/py_spring_application.py @@ -14,7 +14,7 @@ ConfigFileTemplateGenerator, ) from py_spring_core.commons.file_path_scanner import FilePathScanner -from py_spring_core.core.application.application_config import ApplicationConfigRepository, TypeCheckingMode +from py_spring_core.core.application.application_config import ApplicationConfigRepository from py_spring_core.core.application.context.application_context import ( ApplicationContext, ) diff --git a/py_spring_core/core/entities/properties/properties_loader.py b/py_spring_core/core/entities/properties/properties_loader.py index 058e0ab..525ae95 100644 --- a/py_spring_core/core/entities/properties/properties_loader.py +++ b/py_spring_core/core/entities/properties/properties_loader.py @@ -1,5 +1,5 @@ import json -from typing import Callable, Optional, Type +from typing import Any, Callable, Optional, Type import cachetools import yaml @@ -32,7 +32,7 @@ def __init__( self.properties_classes = properties_classes self.properties_class_map = self._load_classes_as_map() - self.extension_loader_lookup:dict[str, Callable[[str], dict]] = { + self.extension_loader_lookup:dict[str, Callable[[str], dict[str, Any]]] = { "json": json.loads, "yaml": yaml.safe_load, "yml": yaml.safe_load, diff --git a/pyproject.toml b/pyproject.toml index 32216c2..401e0ee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,7 @@ dependencies = [ "dnspython==2.6.1", "email-validator==2.2.0", "exceptiongroup==1.2.2", - "fastapi==0.112.0", + "fastapi==0.115.12", "fastapi-cli==0.0.5", "greenlet==3.0.3", "h11==0.14.0", @@ -28,10 +28,6 @@ dependencies = [ "MarkupSafe==2.1.5", "mdurl==0.1.2", "orjson==3.10.7", - "pydantic==2.8.2", - "pydantic-extra-types==2.9.0", - "pydantic-settings==2.4.0", - "pydantic-core==2.20.1", "Pygments==2.18.0", "python-dotenv==1.0.1", "python-multipart==0.0.9", @@ -39,7 +35,7 @@ dependencies = [ "rich==13.7.1", "shellingham==1.5.4", "sniffio==1.3.1", - "starlette==0.37.2", + "starlette<0.47.0,>=0.40.0", "typer>=0.12.5", "typing-extensions==4.12.2", "ujson==5.10.0", @@ -48,7 +44,8 @@ dependencies = [ "watchfiles==0.23.0", "websockets==12.0", "cachetools>=5.5.0", - "mypy>=1.11.2" + "mypy>=1.11.2", + "pydantic>=2.11.7" ] requires-python = ">=3.10" readme = "README.md" From 09e622aafb3d69dbeeb31b48fdcfd29773cd1613 Mon Sep 17 00:00:00 2001 From: William Chen <86595028+NFUChen@users.noreply.github.com> Date: Mon, 16 Jun 2025 11:42:22 +0800 Subject: [PATCH 14/42] Security Updates: Dependency Vulnerabilities Fix (#4) --- py_spring_core/__init__.py | 2 +- pyproject.toml | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/py_spring_core/__init__.py b/py_spring_core/__init__.py index 2c3ead3..b7c5e11 100644 --- a/py_spring_core/__init__.py +++ b/py_spring_core/__init__.py @@ -5,4 +5,4 @@ from py_spring_core.core.entities.properties.properties import Properties from py_spring_core.core.entities.entity_provider import EntityProvider -__version__ = "0.0.8" \ No newline at end of file +__version__ = "0.0.9" \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 401e0ee..1e48809 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,13 +16,13 @@ dependencies = [ "fastapi==0.115.12", "fastapi-cli==0.0.5", "greenlet==3.0.3", - "h11==0.14.0", - "httpcore==1.0.5", + "h11==0.16.0", + "httpcore==1.0.9", "httptools==0.6.1", "httpx==0.27.0", "idna==3.7", "itsdangerous==2.2.0", - "Jinja2==3.1.4", + "Jinja2==3.1.6", "loguru==0.7.2", "markdown-it-py==3.0.0", "MarkupSafe==2.1.5", @@ -30,7 +30,7 @@ dependencies = [ "orjson==3.10.7", "Pygments==2.18.0", "python-dotenv==1.0.1", - "python-multipart==0.0.9", + "python-multipart==0.0.10", "PyYAML==6.0.2", "rich==13.7.1", "shellingham==1.5.4", From 8e4d09a3657b0f81166df9a198dde7b133cfb45e Mon Sep 17 00:00:00 2001 From: William Chen <86595028+NFUChen@users.noreply.github.com> Date: Mon, 16 Jun 2025 15:13:51 +0800 Subject: [PATCH 15/42] Route Mapping Decorators Implementation (#3) --- py_spring_core/__init__.py | 11 +- .../templates.py | 4 +- .../commons/json_config_repository.py | 12 +- .../commons/type_checking_service.py | 13 +- .../core/application/application_config.py | 7 +- py_spring_core/core/application/commons.py | 1 - .../context/application_context.py | 61 ++++++--- .../core/application/py_spring_application.py | 22 ++- .../entities/controllers/rest_controller.py | 30 ++++- .../entities/controllers/route_mapping.py | 126 ++++++++++++++++++ .../core/entities/entity_provider.py | 2 +- .../entities/properties/properties_loader.py | 2 +- pyproject.toml | 75 +++++------ tests/test_component_features.py | 24 +++- 14 files changed, 294 insertions(+), 96 deletions(-) create mode 100644 py_spring_core/core/entities/controllers/route_mapping.py diff --git a/py_spring_core/__init__.py b/py_spring_core/__init__.py index b7c5e11..ca6de56 100644 --- a/py_spring_core/__init__.py +++ b/py_spring_core/__init__.py @@ -2,7 +2,14 @@ from py_spring_core.core.entities.bean_collection import BeanCollection from py_spring_core.core.entities.component import Component, ComponentScope from py_spring_core.core.entities.controllers.rest_controller import RestController -from py_spring_core.core.entities.properties.properties import Properties +from py_spring_core.core.entities.controllers.route_mapping import ( + DeleteMapping, + GetMapping, + PatchMapping, + PostMapping, + PutMapping, +) from py_spring_core.core.entities.entity_provider import EntityProvider +from py_spring_core.core.entities.properties.properties import Properties -__version__ = "0.0.9" \ No newline at end of file +__version__ = "0.0.10" diff --git a/py_spring_core/commons/config_file_template_generator/templates.py b/py_spring_core/commons/config_file_template_generator/templates.py index 4d49eba..61a81be 100644 --- a/py_spring_core/commons/config_file_template_generator/templates.py +++ b/py_spring_core/commons/config_file_template_generator/templates.py @@ -5,7 +5,7 @@ "server_config": {"host": "0.0.0.0", "port": 8080, "enabled": True}, "properties_file_path": "./application-properties.json", "loguru_config": {"log_file_path": "./logs/app.log", "log_level": "DEBUG"}, - "type_checking_mode": "strict" + "type_checking_mode": "strict", } -app_properties_template:dict[str, Any] = {} +app_properties_template: dict[str, Any] = {} diff --git a/py_spring_core/commons/json_config_repository.py b/py_spring_core/commons/json_config_repository.py index d77824d..65ef43c 100644 --- a/py_spring_core/commons/json_config_repository.py +++ b/py_spring_core/commons/json_config_repository.py @@ -24,14 +24,14 @@ class JsonConfigRepository(Generic[T]): """ def __init__(self, file_path: str, target_key: Optional[str] = None) -> None: - self.base_model_cls: Type[T] = self.__class__._get_model_cls() + self.base_model_cls: Type[T] = self.__class__._get_model_cls() self.file_path = file_path self.target_key = target_key self._config: T = self._load_config() @classmethod def _get_model_cls(cls) -> Type[T]: - return get_args(cls.__orig_bases__[0])[0] # type: ignore + return get_args(cls.__orig_bases__[0])[0] # type: ignore def get_config(self) -> T: return self._config @@ -41,12 +41,12 @@ def reload_config(self) -> None: def save_config(self) -> None: is_the_same_class = ( - self._config.__class__.__name__ == self.base_model_cls.__name__ - ) + self._config.__class__.__name__ == self.base_model_cls.__name__ + ) if not is_the_same_class: raise TypeError( - f"[BASE MODEL CLASS TYPE MISMATCH] Base model class of current repository: {self.base_model_cls.__name__} mismatch with config class: {self._config.__class__.__name__}" - ) + f"[BASE MODEL CLASS TYPE MISMATCH] Base model class of current repository: {self.base_model_cls.__name__} mismatch with config class: {self._config.__class__.__name__}" + ) with open(self.file_path, "w") as file: file.write(self._config.model_dump_json(indent=4)) diff --git a/py_spring_core/commons/type_checking_service.py b/py_spring_core/commons/type_checking_service.py index 5156df7..f9955cb 100644 --- a/py_spring_core/commons/type_checking_service.py +++ b/py_spring_core/commons/type_checking_service.py @@ -1,18 +1,21 @@ -from enum import Enum import subprocess +from enum import Enum from typing import Optional + from loguru import logger class MypyTypeCheckingError(str, Enum): NoUntypedDefs = "no-untyped-def" + class TypeCheckingErrorr(Exception): ... + class TypeCheckingService: def __init__(self, target_folder: str) -> None: self.target_folder = target_folder - self.checking_command = ['mypy', '--disallow-untyped-defs', self.target_folder] + self.checking_command = ["mypy", "--disallow-untyped-defs", self.target_folder] self.target_typing_errors: list[MypyTypeCheckingError] = [ MypyTypeCheckingError.NoUntypedDefs ] @@ -23,8 +26,8 @@ def type_checking(self) -> Optional[TypeCheckingErrorr]: result = subprocess.run( self.checking_command, capture_output=True, # Captures both stdout and stderr - text=True, # Ensures output is returned as a string - check=False # Avoids raising an exception on non-zero exit code + text=True, # Ensures output is returned as a string + check=False, # Avoids raising an exception on non-zero exit code ) std_err_lines = result.stderr.split("\n") std_out_lines = result.stdout.split("\n") @@ -38,4 +41,4 @@ def type_checking(self) -> Optional[TypeCheckingErrorr]: message += error_message if len(message) == 0: return None - return TypeCheckingErrorr(message) \ No newline at end of file + return TypeCheckingErrorr(message) diff --git a/py_spring_core/core/application/application_config.py b/py_spring_core/core/application/application_config.py index ac9e2c9..0cd55cd 100644 --- a/py_spring_core/core/application/application_config.py +++ b/py_spring_core/core/application/application_config.py @@ -1,9 +1,8 @@ from enum import Enum + from pydantic import BaseModel, ConfigDict, Field -from py_spring_core.commons.json_config_repository import ( - JsonConfigRepository, -) +from py_spring_core.commons.json_config_repository import JsonConfigRepository from py_spring_core.core.application.loguru_config import LoguruConfig @@ -21,6 +20,7 @@ class ServerConfig(BaseModel): port: int enabled: bool = Field(default=True) + class ApplicationConfig(BaseModel): """ Represents the configuration for the application. @@ -40,6 +40,7 @@ class ApplicationConfig(BaseModel): properties_file_path: str loguru_config: LoguruConfig + class ApplicationConfigRepository(JsonConfigRepository[ApplicationConfig]): """ Represents a repository for managing the application configuration, which is stored in a JSON file. diff --git a/py_spring_core/core/application/commons.py b/py_spring_core/core/application/commons.py index 60175df..eb0b11b 100644 --- a/py_spring_core/core/application/commons.py +++ b/py_spring_core/core/application/commons.py @@ -3,5 +3,4 @@ from py_spring_core.core.entities.controllers.rest_controller import RestController from py_spring_core.core.entities.properties.properties import Properties - AppEntities = Component | RestController | BeanCollection | Properties diff --git a/py_spring_core/core/application/context/application_context.py b/py_spring_core/core/application/context/application_context.py index cbd9b5c..cc16f6d 100644 --- a/py_spring_core/core/application/context/application_context.py +++ b/py_spring_core/core/application/context/application_context.py @@ -1,6 +1,16 @@ from abc import ABC from inspect import isclass -from typing import Annotated, Callable, Mapping, Optional, Type, TypeVar, cast, get_origin, get_args +from typing import ( + Annotated, + Callable, + Mapping, + Optional, + Type, + TypeVar, + cast, + get_args, + get_origin, +) from loguru import logger from pydantic import BaseModel @@ -20,7 +30,6 @@ from py_spring_core.core.entities.properties.properties import Properties from py_spring_core.core.entities.properties.properties_loader import _PropertiesLoader - T = TypeVar("T", bound=AppEntities) PT = TypeVar("PT", bound=Properties) @@ -80,8 +89,10 @@ def as_view(self) -> ApplicationContextView: self.singleton_component_instance_container.keys() ), ) - - def _determine_target_cls_name(self, component_cls: Type[T], qualifier: Optional[str]) -> str: + + def _determine_target_cls_name( + self, component_cls: Type[T], qualifier: Optional[str] + ) -> str: """ Determine the target class name for a given component class. This method handles the following cases: @@ -94,27 +105,28 @@ def _determine_target_cls_name(self, component_cls: Type[T], qualifier: Optional if qualifier is not None: return qualifier - + # If it's not an ABC, return its name directly if not issubclass(component_cls, ABC): return component_cls.get_name() - + # If it's an ABC but has implementations, return its name directly if not component_cls.__abstractmethods__: return component_cls.get_name() - + # For abstract classes that need implementations subclasses = component_cls.__subclasses__() if len(subclasses) == 0: raise ValueError( f"[ABSTRACT CLASS ERROR] Abstract class {component_cls.__name__} has no subclasses" ) - - + # Fall back to first subclass if no primary component exists return subclasses[0].get_name() - def get_component(self, component_cls: Type[T], qualifier: Optional[str]) -> Optional[T]: + def get_component( + self, component_cls: Type[T], qualifier: Optional[str] + ) -> Optional[T]: if not issubclass(component_cls, (Component, ABC)): return None @@ -129,7 +141,7 @@ def get_component(self, component_cls: Type[T], qualifier: Optional[str]) -> Opt optional_instance = self.singleton_component_instance_container.get( target_cls_name ) - return cast(T, optional_instance) + return cast(T, optional_instance) case ComponentScope.Prototype: prototype_instance = component_cls() @@ -172,7 +184,9 @@ def register_component(self, component_cls: Type[Component]) -> None: ) component_cls_name = component_cls.get_name() if component_cls_name in self.component_cls_container: - raise ValueError(f"[COMPONENT REGISTRATION ERROR] Component: {component_cls_name} already registered") + raise ValueError( + f"[COMPONENT REGISTRATION ERROR] Component: {component_cls_name} already registered" + ) self.component_cls_container[component_cls_name] = component_cls def register_controller(self, controller_cls: Type[RestController]) -> None: @@ -205,8 +219,6 @@ def register_entity_provider(self, provider: EntityProvider) -> None: elif issubclass(entity_cls, Properties): self.register_properties(entity_cls) - - def register_properties(self, properties_cls: Type[Properties]) -> None: if not issubclass(properties_cls, Properties): raise TypeError( @@ -246,18 +258,24 @@ def load_properties(self) -> None: self.singleton_properties_instance_container ) - def init_singleton_component(self, component_cls: Type[Component], component_cls_name: str) -> Optional[Component]: + def init_singleton_component( + self, component_cls: Type[Component], component_cls_name: str + ) -> Optional[Component]: instance: Optional[Component] = None try: instance = component_cls() except Exception as error: unable_to_init_component_error_prefix = "Can't instantiate abstract class" if unable_to_init_component_error_prefix in str(error): - logger.warning(f"[INITIALIZING SINGLETON COMPONENT ERROR] Skip initializing singleton component: {component_cls_name} because it is an abstract class") + logger.warning( + f"[INITIALIZING SINGLETON COMPONENT ERROR] Skip initializing singleton component: {component_cls_name} because it is an abstract class" + ) return - logger.error(f"[INITIALIZING SINGLETON COMPONENT ERROR] Error initializing singleton component: {component_cls_name} with error: {error}") + logger.error( + f"[INITIALIZING SINGLETON COMPONENT ERROR] Error initializing singleton component: {component_cls_name} with error: {error}" + ) raise error - + return instance def init_ioc_container(self) -> None: @@ -320,7 +338,6 @@ def _inject_entity_dependencies(self, entity: Type[AppEntities]) -> None: if not isclass(annotated_entity_cls): continue - if issubclass(annotated_entity_cls, Properties): optional_properties = self.get_properties(annotated_entity_cls) if optional_properties is None: @@ -330,9 +347,9 @@ def _inject_entity_dependencies(self, entity: Type[AppEntities]) -> None: setattr(entity, attr_name, optional_properties) continue - entity_getters: list[Callable[[Type[AppEntities], Optional[str]], Optional[AppEntities]]] = [ - self.get_component, self.get_bean - ] + entity_getters: list[ + Callable[[Type[AppEntities], Optional[str]], Optional[AppEntities]] + ] = [self.get_component, self.get_bean] for getter in entity_getters: optional_entity = getter(annotated_entity_cls, qualifier) diff --git a/py_spring_core/core/application/py_spring_application.py b/py_spring_core/core/application/py_spring_application.py index cf956ce..ea7d935 100644 --- a/py_spring_core/core/application/py_spring_application.py +++ b/py_spring_core/core/application/py_spring_application.py @@ -4,17 +4,17 @@ import uvicorn from fastapi import APIRouter, FastAPI from loguru import logger -from pydantic import BaseModel, ConfigDict -from py_spring_core.commons.type_checking_service import TypeCheckingService -from py_spring_core.core.application.commons import AppEntities -from py_spring_core.core.entities.entity_provider import EntityProvider from py_spring_core.commons.class_scanner import ClassScanner from py_spring_core.commons.config_file_template_generator.config_file_template_generator import ( ConfigFileTemplateGenerator, ) from py_spring_core.commons.file_path_scanner import FilePathScanner -from py_spring_core.core.application.application_config import ApplicationConfigRepository +from py_spring_core.commons.type_checking_service import TypeCheckingService +from py_spring_core.core.application.application_config import ( + ApplicationConfigRepository, +) +from py_spring_core.core.application.commons import AppEntities from py_spring_core.core.application.context.application_context import ( ApplicationContext, ) @@ -24,6 +24,8 @@ from py_spring_core.core.entities.bean_collection import BeanCollection from py_spring_core.core.entities.component import Component, ComponentLifeCycle from py_spring_core.core.entities.controllers.rest_controller import RestController +from py_spring_core.core.entities.controllers.route_mapping import RouteMapping +from py_spring_core.core.entities.entity_provider import EntityProvider from py_spring_core.core.entities.properties.properties import Properties @@ -83,7 +85,9 @@ def __init__( BeanCollection: self._handle_register_bean_collection, Properties: self._handle_register_properties, } - self.type_checking_service = TypeCheckingService(self.app_config.app_src_target_dir) + self.type_checking_service = TypeCheckingService( + self.app_config.app_src_target_dir + ) def __configure_logging(self): """Applies the logging configuration using Loguru.""" @@ -180,7 +184,11 @@ def _handle_singleton_components_life_cycle( def __init_controllers(self) -> None: controllers = self.app_context.get_controller_instances() for controller in controllers: - controller.register_routes() + name = controller.__class__.__name__ + routes = RouteMapping.routes.get(name, None) + if routes is None: + continue + controller._register_routes(routes) router = controller.get_router() self.fastapi.include_router(router) controller.register_middlewares() diff --git a/py_spring_core/core/entities/controllers/rest_controller.py b/py_spring_core/core/entities/controllers/rest_controller.py index 20d81f8..98026b2 100644 --- a/py_spring_core/core/entities/controllers/rest_controller.py +++ b/py_spring_core/core/entities/controllers/rest_controller.py @@ -1,4 +1,7 @@ from fastapi import APIRouter, FastAPI +from functools import partial + +from py_spring_core.core.entities.controllers.route_mapping import RouteRegistration class RestController: @@ -20,7 +23,32 @@ class RestController: class Config: prefix: str = "" - def register_routes(self) -> None: ... + def _register_routes(self, routes: list[RouteRegistration]) -> None: + for route in routes: + bound_method = partial(route.func, self) + self.router.add_api_route( + path=route.path, + endpoint=bound_method, + methods=[route.method.value], + response_model=route.response_model, + status_code=route.status_code, + tags=route.tags, + dependencies=route.dependencies, + summary=route.summary, + description=route.description, + response_description=route.response_description, + responses=route.responses, + deprecated=route.deprecated, + operation_id=route.operation_id, + response_model_include=route.response_model_include, + response_model_exclude=route.response_model_exclude, + response_model_by_alias=route.response_model_by_alias, + response_model_exclude_unset=route.response_model_exclude_unset, + response_model_exclude_defaults=route.response_model_exclude_defaults, + response_model_exclude_none=route.response_model_exclude_none, + include_in_schema=route.include_in_schema, + name=route.name, + ) def register_middlewares(self) -> None: ... diff --git a/py_spring_core/core/entities/controllers/route_mapping.py b/py_spring_core/core/entities/controllers/route_mapping.py new file mode 100644 index 0000000..c49eca1 --- /dev/null +++ b/py_spring_core/core/entities/controllers/route_mapping.py @@ -0,0 +1,126 @@ +from enum import Enum +from functools import wraps +from typing import Any, Callable, Dict, List, Optional, Set, Union + +from pydantic import BaseModel + + +class HTTPMethod(str, Enum): + GET = "GET" + POST = "POST" + PUT = "PUT" + DELETE = "DELETE" + PATCH = "PATCH" + + +class RouteRegistration(BaseModel): + class_name: str + method: HTTPMethod + path: str + func: Callable + response_model: Any = None + status_code: Optional[int] = None + tags: Optional[List[Union[str, Enum]]] = None + dependencies: Optional[List[Any]] = None + summary: Optional[str] = None + description: Optional[str] = None + response_description: str = "Successful Response" + responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None + deprecated: Optional[bool] = None + operation_id: Optional[str] = None + response_model_include: Optional[Set[str]] = None + response_model_exclude: Optional[Set[str]] = None + response_model_by_alias: bool = True + response_model_exclude_unset: bool = False + response_model_exclude_defaults: bool = False + response_model_exclude_none: bool = False + include_in_schema: bool = True + name: Optional[str] = None + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, RouteRegistration): + return False + return self.method == other.method and self.path == other.path + + def __hash__(self) -> int: + return hash((self.method, self.path)) + + +class RouteMapping: + routes: dict[str, set[RouteRegistration]] = {} + + @classmethod + def register_route(cls, route_registration: RouteRegistration) -> None: + optional_routes = cls.routes.get(route_registration.class_name, None) + if optional_routes is None: + cls.routes[route_registration.class_name] = set() + cls.routes[route_registration.class_name].add(route_registration) + + +def _create_route_decorator(method: HTTPMethod): + def decorator_factory( + path: str, + *, + response_model: Any = None, + status_code: Optional[int] = None, + tags: Optional[List[Union[str, Enum]]] = None, + dependencies: Optional[List[Any]] = None, + summary: Optional[str] = None, + description: Optional[str] = None, + response_description: str = "Successful Response", + responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None, + deprecated: Optional[bool] = None, + operation_id: Optional[str] = None, + response_model_include: Optional[Set[str]] = None, + response_model_exclude: Optional[Set[str]] = None, + response_model_by_alias: bool = True, + response_model_exclude_unset: bool = False, + response_model_exclude_defaults: bool = False, + response_model_exclude_none: bool = False, + include_in_schema: bool = True, + name: Optional[str] = None, + ): + def decorator(func: Callable): + class_name = func.__qualname__.split(".")[0] + route_registration = RouteRegistration( + class_name=class_name, + method=method, + path=path, + func=func, + response_model=response_model, + status_code=status_code, + tags=tags, + dependencies=dependencies, + summary=summary or func.__name__, + description=description or func.__doc__, + response_description=response_description, + responses=responses, + deprecated=deprecated, + operation_id=operation_id, + response_model_include=response_model_include, + response_model_exclude=response_model_exclude, + response_model_by_alias=response_model_by_alias, + response_model_exclude_unset=response_model_exclude_unset, + response_model_exclude_defaults=response_model_exclude_defaults, + response_model_exclude_none=response_model_exclude_none, + include_in_schema=include_in_schema, + name=name, + ) + RouteMapping.register_route(route_registration) + + @wraps(func) + def wrapper(*args: Any, **kwargs: Any): + return func(*args, **kwargs) + + return wrapper + + return decorator + + return decorator_factory + + +GetMapping = _create_route_decorator(HTTPMethod.GET) +PostMapping = _create_route_decorator(HTTPMethod.POST) +PutMapping = _create_route_decorator(HTTPMethod.PUT) +DeleteMapping = _create_route_decorator(HTTPMethod.DELETE) +PatchMapping = _create_route_decorator(HTTPMethod.PATCH) diff --git a/py_spring_core/core/entities/entity_provider.py b/py_spring_core/core/entities/entity_provider.py index 5481add..aeb47e5 100644 --- a/py_spring_core/core/entities/entity_provider.py +++ b/py_spring_core/core/entities/entity_provider.py @@ -1,4 +1,4 @@ -from dataclasses import field, dataclass +from dataclasses import dataclass, field from typing import Any, Optional, Type from py_spring_core.core.application.commons import AppEntities diff --git a/py_spring_core/core/entities/properties/properties_loader.py b/py_spring_core/core/entities/properties/properties_loader.py index 525ae95..983d803 100644 --- a/py_spring_core/core/entities/properties/properties_loader.py +++ b/py_spring_core/core/entities/properties/properties_loader.py @@ -32,7 +32,7 @@ def __init__( self.properties_classes = properties_classes self.properties_class_map = self._load_classes_as_map() - self.extension_loader_lookup:dict[str, Callable[[str], dict[str, Any]]] = { + self.extension_loader_lookup: dict[str, Callable[[str], dict[str, Any]]] = { "json": json.loads, "yaml": yaml.safe_load, "yml": yaml.safe_load, diff --git a/pyproject.toml b/pyproject.toml index 1e48809..6506f02 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,48 +6,47 @@ authors = [ {name = "William Chen", email = "OW6201231@gmail.com"}, ] dependencies = [ - "annotated-types==0.7.0", - "anyio==4.4.0", - "certifi==2024.7.4", - "click==8.1.7", - "dnspython==2.6.1", - "email-validator==2.2.0", - "exceptiongroup==1.2.2", - "fastapi==0.115.12", - "fastapi-cli==0.0.5", - "greenlet==3.0.3", - "h11==0.16.0", - "httpcore==1.0.9", - "httptools==0.6.1", - "httpx==0.27.0", - "idna==3.7", - "itsdangerous==2.2.0", - "Jinja2==3.1.6", - "loguru==0.7.2", - "markdown-it-py==3.0.0", - "MarkupSafe==2.1.5", - "mdurl==0.1.2", - "orjson==3.10.7", - "Pygments==2.18.0", - "python-dotenv==1.0.1", - "python-multipart==0.0.10", - "PyYAML==6.0.2", - "rich==13.7.1", - "shellingham==1.5.4", - "sniffio==1.3.1", - "starlette<0.47.0,>=0.40.0", + "annotated-types>=0.7.0", + "anyio>=4.4.0", + "certifi>=2023.7.22", + "click>=8.1.7", + "dnspython>=2.6.1", + "email-validator>=2.2.0", + "exceptiongroup>=1.2.2", + "fastapi>=0.115.12", + "fastapi-cli>=0.0.5", + "greenlet>=3.0.3", + "h11>=0.16.0", + "httpcore>=1.0.9", + "httptools>=0.6.1", + "httpx>=0.27.0", + "idna>=3.7", + "itsdangerous>=2.2.0", + "Jinja2>=3.1.6", + "loguru>=0.7.2", + "markdown-it-py>=3.0.0", + "MarkupSafe>=2.1.5", + "mdurl>=0.1.2", + "orjson>=3.10.7", + "Pygments>=2.18.0", + "python-dotenv>=1.0.1", + "python-multipart>=0.0.18", + "PyYAML>=6.0.2", + "rich>=13.7.1", + "shellingham>=1.5.4", + "sniffio>=1.3.1", + "starlette>=0.40.0,<0.47.0", "typer>=0.12.5", - "typing-extensions==4.12.2", - "ujson==5.10.0", - "uvicorn==0.30.5", - "uvloop==0.19.0", - "watchfiles==0.23.0", - "websockets==12.0", + "typing-extensions>=4.12.2", + "ujson>=5.10.0", + "uvicorn>=0.30.5", + "uvloop>=0.19.0", + "watchfiles>=0.23.0", + "websockets>=12.0", "cachetools>=5.5.0", - "mypy>=1.11.2", "pydantic>=2.11.7" ] -requires-python = ">=3.10" +requires-python = ">=3.10,<3.13" readme = "README.md" license = {text = "MIT"} diff --git a/tests/test_component_features.py b/tests/test_component_features.py index 116a98a..b41b217 100644 --- a/tests/test_component_features.py +++ b/tests/test_component_features.py @@ -1,13 +1,13 @@ from abc import ABC -import pytest from typing import Annotated +import pytest + from py_spring_core.core.application.context.application_context import ( ApplicationContext, - ApplicationContextConfig + ApplicationContextConfig, ) -from py_spring_core.core.entities.component import Component -from py_spring_core.core.entities.component import ComponentScope +from py_spring_core.core.entities.component import Component, ComponentScope class TestComponentFeatures: @@ -32,6 +32,7 @@ def test_qualifier_based_injection(self, app_context: ApplicationContext): The test creates an abstract service with two implementations and verifies that each can be injected into a consumer using appropriate qualifiers. """ + # Define abstract base class class AbstractService(Component): class Config: @@ -92,6 +93,7 @@ def test_duplicate_component_registration(self, app_context: ApplicationContext) The test attempts to register the same component twice and verifies that an appropriate error is raised. """ + # Define a component class TestService(Component): class Config: @@ -106,7 +108,10 @@ def process(self) -> str: app_context.init_ioc_container() # Attempt to register same component again should raise error - with pytest.raises(ValueError, match="\\[COMPONENT REGISTRATION ERROR\\] Component: TestService already registered"): + with pytest.raises( + ValueError, + match="\\[COMPONENT REGISTRATION ERROR\\] Component: TestService already registered", + ): app_context.register_component(TestService) app_context.init_ioc_container() @@ -122,6 +127,7 @@ def test_component_name_override(self, app_context: ApplicationContext): The test registers a component with a custom name and verifies that it is correctly stored in the container. """ + # Define component with custom name class TestService(Component): class Config: @@ -151,6 +157,7 @@ def test_qualifier_with_invalid_component(self, app_context: ApplicationContext) The test attempts to inject a component using a non-existent qualifier and verifies that an appropriate error is raised. """ + # Define abstract base class class AbstractService(Component): class Config: @@ -185,5 +192,8 @@ def post_construct(self) -> None: app_context.init_ioc_container() # Initialize the consumer component # Attempting to inject with invalid qualifier should raise error - with pytest.raises(ValueError, match="\\[DEPENDENCY INJECTION FAILED\\] Fail to inject dependency for attribute: service with dependency: AbstractService with qualifier: NonExistentService"): - app_context.inject_dependencies_for_app_entities() \ No newline at end of file + with pytest.raises( + ValueError, + match="\\[DEPENDENCY INJECTION FAILED\\] Fail to inject dependency for attribute: service with dependency: AbstractService with qualifier: NonExistentService", + ): + app_context.inject_dependencies_for_app_entities() From 2e67669e2518b550bf9a31e9c7f0c20121a5b688 Mon Sep 17 00:00:00 2001 From: William Chen <86595028+NFUChen@users.noreply.github.com> Date: Wed, 18 Jun 2025 21:40:43 +0800 Subject: [PATCH 16/42] Event System Implementation (#5) --- .gitignore | 8 +- py_spring_core/__init__.py | 5 +- .../core/application/py_spring_application.py | 27 ++++- .../entities/controllers/rest_controller.py | 3 +- .../application_context_required.py | 34 ++++++ .../application_event_handler_registry.py | 107 ++++++++++++++++++ .../event/application_event_publisher.py | 30 +++++ py_spring_core/event/commons.py | 7 ++ tests/test_rest_controller.py | 29 ----- 9 files changed, 213 insertions(+), 37 deletions(-) create mode 100644 py_spring_core/core/interfaces/application_context_required.py create mode 100644 py_spring_core/event/application_event_handler_registry.py create mode 100644 py_spring_core/event/application_event_publisher.py create mode 100644 py_spring_core/event/commons.py delete mode 100644 tests/test_rest_controller.py diff --git a/.gitignore b/.gitignore index efa407c..280fb4c 100644 --- a/.gitignore +++ b/.gitignore @@ -159,4 +159,10 @@ cython_debug/ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. -#.idea/ \ No newline at end of file +#.idea/ + +src/ +app-config.json +application-properties.json +main.py +pdm.lock \ No newline at end of file diff --git a/py_spring_core/__init__.py b/py_spring_core/__init__.py index ca6de56..b3bbf5d 100644 --- a/py_spring_core/__init__.py +++ b/py_spring_core/__init__.py @@ -11,5 +11,8 @@ ) from py_spring_core.core.entities.entity_provider import EntityProvider from py_spring_core.core.entities.properties.properties import Properties +from py_spring_core.event.application_event_publisher import ApplicationEventPublisher +from py_spring_core.event.commons import ApplicationEvent +from py_spring_core.event.application_event_handler_registry import EventListener -__version__ = "0.0.10" +__version__ = "0.0.11" diff --git a/py_spring_core/core/application/py_spring_application.py b/py_spring_core/core/application/py_spring_application.py index ea7d935..4037f96 100644 --- a/py_spring_core/core/application/py_spring_application.py +++ b/py_spring_core/core/application/py_spring_application.py @@ -27,6 +27,9 @@ from py_spring_core.core.entities.controllers.route_mapping import RouteMapping from py_spring_core.core.entities.entity_provider import EntityProvider from py_spring_core.core.entities.properties.properties import Properties +from py_spring_core.core.interfaces.application_context_required import ApplicationContextRequired +from py_spring_core.event.application_event_handler_registry import ApplicationEventHandlerRegistry +from py_spring_core.event.application_event_publisher import ApplicationEventPublisher class PySpringApplication: @@ -103,9 +106,15 @@ def __configure_logging(self): retention=config.log_retention, ) - def _scan_classes_for_project(self) -> None: + def _get_system_managed_classes(self) -> Iterable[Type[Component]]: + return [ + ApplicationEventPublisher, + ApplicationEventHandlerRegistry + ] + + def _scan_classes_for_project(self) -> Iterable[Type[object]]: self.app_class_scanner.scan_classes_for_file_paths() - self.scanned_classes = self.app_class_scanner.get_classes() + return self.app_class_scanner.get_classes() def _register_all_entities_from_providers(self) -> None: for provider in self.entity_providers: @@ -155,10 +164,19 @@ def _init_providers(self, providers: Iterable[EntityProvider]) -> None: for provider in providers: provider.provider_init() + def _inject_application_context_to_context_required(self, classes: Iterable[Type[object]]) -> None: + for cls in classes: + if not issubclass(cls, ApplicationContextRequired): + continue + cls.set_application_context(self.app_context) + def __init_app(self) -> None: - self._scan_classes_for_project() + scanned_classes = self._scan_classes_for_project() + system_managed_classes = self._get_system_managed_classes() + classes_to_inject = [*scanned_classes, *system_managed_classes] + self._inject_application_context_to_context_required(classes_to_inject) self._register_all_entities_from_providers() - self._register_app_entities(self.scanned_classes) + self._register_app_entities(classes_to_inject) self._register_entity_providers(self.entity_providers) self.app_context.load_properties() self.app_context.init_ioc_container() @@ -166,7 +184,6 @@ def __init_app(self) -> None: self.app_context.set_all_file_paths(self.target_dir_absolute_file_paths) self.app_context.validate_entity_providers() # after injecting all deps, lifecycle (init) can be called - self._init_providers(self.entity_providers) self._handle_singleton_components_life_cycle(ComponentLifeCycle.Init) diff --git a/py_spring_core/core/entities/controllers/rest_controller.py b/py_spring_core/core/entities/controllers/rest_controller.py index 98026b2..8fa9aaa 100644 --- a/py_spring_core/core/entities/controllers/rest_controller.py +++ b/py_spring_core/core/entities/controllers/rest_controller.py @@ -1,3 +1,4 @@ +from typing import Iterable from fastapi import APIRouter, FastAPI from functools import partial @@ -23,7 +24,7 @@ class RestController: class Config: prefix: str = "" - def _register_routes(self, routes: list[RouteRegistration]) -> None: + def _register_routes(self, routes: Iterable[RouteRegistration]) -> None: for route in routes: bound_method = partial(route.func, self) self.router.add_api_route( diff --git a/py_spring_core/core/interfaces/application_context_required.py b/py_spring_core/core/interfaces/application_context_required.py new file mode 100644 index 0000000..ea91abd --- /dev/null +++ b/py_spring_core/core/interfaces/application_context_required.py @@ -0,0 +1,34 @@ +from typing import Optional + +from py_spring_core.core.application.context.application_context import ApplicationContext + + +class ApplicationContextRequired: + """ + A mixin class that provides access to the ApplicationContext for classes that need it. + + This class serves as a base for components that require access to the ApplicationContext. + It provides class-level methods to set and retrieve the ApplicationContext instance. + + Usage: + class MyComponent(ApplicationContextRequired): + def some_method(self): + context = self.get_application_context() + # Use the context... + + Note: + The ApplicationContext must be set before attempting to retrieve it, + otherwise a RuntimeError will be raised. + """ + _app_context: Optional[ApplicationContext] = None + + + @classmethod + def set_application_context(cls, application_context: ApplicationContext) -> None: + cls._app_context = application_context + + @classmethod + def get_application_context(cls) -> ApplicationContext: + if cls._app_context is None: + raise RuntimeError("ApplicationContext is not set") + return cls._app_context \ No newline at end of file diff --git a/py_spring_core/event/application_event_handler_registry.py b/py_spring_core/event/application_event_handler_registry.py new file mode 100644 index 0000000..0a1fb9e --- /dev/null +++ b/py_spring_core/event/application_event_handler_registry.py @@ -0,0 +1,107 @@ + +from threading import Thread +from typing import Callable, Type + +from loguru import logger +from pydantic import BaseModel + +from py_spring_core.core.entities.component import Component +from py_spring_core.core.interfaces.application_context_required import ApplicationContextRequired +from py_spring_core.event.commons import ApplicationEvent, EventQueue + +EventHandlerT = Callable[[Component, ApplicationEvent], None] + +def EventListener(event_type: Type[ApplicationEvent]) -> Callable: + """ + The EventListener decorator is used to register an event handler for an application event. + It is responsible for binding an event handler to a component and a function. + """ + def decorator(func: EventHandlerT) -> None: + if not issubclass(event_type, ApplicationEvent): + raise ValueError(f"Event type must be a subclass of ApplicationEvent") + + ApplicationEventHandlerRegistry.register_event_handler(event_type, func) + return decorator + + +class EventHandler(BaseModel): + """ + The EventHandler class is a model that represents an event handler for an application event. + It is responsible for binding an event handler to a component and a function. + """ + class_name: str + func_name: str + event_type: Type[ApplicationEvent] + func: EventHandlerT + + def __eq__(self, other: object) -> bool: + if not isinstance(other, EventHandler): + return False + return self.class_name == other.class_name and self.func_name == other.func_name + + def __hash__(self) -> int: + return hash((self.class_name, self.func_name)) + + +class ApplicationEventHandlerRegistry(Component, ApplicationContextRequired): + """ + The ApplicationEventHandlerRegistry is a component that registers event handlers for application events. + It is responsible for binding event handlers to their corresponding components and handling event messages. + + The class performs the following key tasks: + - Registers event handlers for application events + - Binds event handlers to their corresponding components + """ + _class_event_handlers: dict[str, list[EventHandler]] = {} + def __init__(self) -> None: + self._event_handlers: dict[str, list[EventHandler]] = {} + self._event_message_queue = EventQueue.queue + + def post_construct(self) -> None: + logger.info("Initializing event handlers...") + self._init_event_handlers() + logger.info("Starting event message handler thread...") + Thread(target= self._handle_messages).start() + + def _init_event_handlers(self) -> None: + app_context = self.get_application_context() + # get_name might be different from the class name, so we use the class name for function binding + self.component_instance_map = { + component.__class__.__name__: component + for component in app_context.get_singleton_component_instances() + } + self._event_handlers = self._class_event_handlers + + + @classmethod + def register_event_handler(cls, event_type: Type[ApplicationEvent], handler: EventHandlerT): + event_name = event_type.__name__ + func_name_parts = handler.__qualname__.split(".") + if len(func_name_parts) != 2: + raise ValueError(f"Handler must be a member function of a class") + class_name, func_name = func_name_parts + if event_name not in cls._class_event_handlers: + cls._class_event_handlers[event_name] = [] + event_handler = EventHandler(class_name=class_name, func_name=func_name, event_type=event_type, func=handler) + if event_handler not in cls._class_event_handlers[event_name]: + cls._class_event_handlers[event_name].append(event_handler) + + def get_event_handlers(self, event_type: Type[ApplicationEvent]) -> list[EventHandler]: + event_name = event_type.__name__ + handlers = self._event_handlers.get(event_name, []) + return handlers + + def _handle_messages(self) -> None: + logger.info("Event message handler thread started...") + while True: + message = self._event_message_queue.get() + for handler in self.get_event_handlers(message.__class__): + try: + optional_instance = self.component_instance_map.get(handler.class_name, None) + if optional_instance is None: + logger.error(f"Component instance not found for handler: {handler.class_name}") + continue + handler.func(optional_instance, message) + except Exception as error: + logger.error(f"Error handling event: {error}") + \ No newline at end of file diff --git a/py_spring_core/event/application_event_publisher.py b/py_spring_core/event/application_event_publisher.py new file mode 100644 index 0000000..efc3115 --- /dev/null +++ b/py_spring_core/event/application_event_publisher.py @@ -0,0 +1,30 @@ +from typing import TypeVar + + +from py_spring_core.core.entities.component import Component +from py_spring_core.event.application_event_handler_registry import ApplicationEvent, ApplicationEventHandlerRegistry +from py_spring_core.event.commons import EventQueue + +T = TypeVar("T", bound=ApplicationEvent) + + + + +class ApplicationEventPublisher(Component): + """ + The ApplicationEventPublisher is a component that publishes application events. + It is responsible for publishing application events to the event message queue. + + The class performs the following key tasks: + - Publishes application events to the event message queue + """ + def __init__(self): + self.event_message_queue = EventQueue.queue + + def publish(self, event: ApplicationEvent) -> None: + self.event_message_queue.put(event) + + + + + \ No newline at end of file diff --git a/py_spring_core/event/commons.py b/py_spring_core/event/commons.py new file mode 100644 index 0000000..bd76bc4 --- /dev/null +++ b/py_spring_core/event/commons.py @@ -0,0 +1,7 @@ +from queue import Queue + +from pydantic import BaseModel + +class ApplicationEvent(BaseModel): ... +class EventQueue: + queue: Queue[ApplicationEvent] = Queue() \ No newline at end of file diff --git a/tests/test_rest_controller.py b/tests/test_rest_controller.py deleted file mode 100644 index c2c2669..0000000 --- a/tests/test_rest_controller.py +++ /dev/null @@ -1,29 +0,0 @@ -import pytest -from fastapi import APIRouter, FastAPI - -from py_spring_core.core.entities.controllers.rest_controller import RestController - - -class TestRestController: - @pytest.fixture - def app(self) -> FastAPI: - return FastAPI() - - @pytest.fixture - def router(self) -> APIRouter: - return APIRouter() - - @pytest.fixture - def test_controller(self, app: FastAPI, router: APIRouter) -> RestController: - class TestController(RestController): - def register_routes(self): - self.router.add_api_route("/test", lambda: "test") - - TestController.app = app - TestController.router = router - - return TestController() - - def test_register_routes_successfully(self, test_controller: RestController): - test_controller.register_routes() - assert len(test_controller.router.routes) == 1 From bca6e20553177c2b1ab2749ea9c8465572ca7526 Mon Sep 17 00:00:00 2001 From: William Chen <86595028+NFUChen@users.noreply.github.com> Date: Thu, 19 Jun 2025 00:03:33 +0800 Subject: [PATCH 17/42] chore: Update version to 0.0.12 and refactor application event handler registry (#6) --- py_spring_core/__init__.py | 2 +- py_spring_core/core/application/py_spring_application.py | 6 ------ py_spring_core/event/application_event_handler_registry.py | 4 ++-- 3 files changed, 3 insertions(+), 9 deletions(-) diff --git a/py_spring_core/__init__.py b/py_spring_core/__init__.py index b3bbf5d..43422dd 100644 --- a/py_spring_core/__init__.py +++ b/py_spring_core/__init__.py @@ -15,4 +15,4 @@ from py_spring_core.event.commons import ApplicationEvent from py_spring_core.event.application_event_handler_registry import EventListener -__version__ = "0.0.11" +__version__ = "0.0.12" diff --git a/py_spring_core/core/application/py_spring_application.py b/py_spring_core/core/application/py_spring_application.py index 4037f96..1ecf62d 100644 --- a/py_spring_core/core/application/py_spring_application.py +++ b/py_spring_core/core/application/py_spring_application.py @@ -116,11 +116,6 @@ def _scan_classes_for_project(self) -> Iterable[Type[object]]: self.app_class_scanner.scan_classes_for_file_paths() return self.app_class_scanner.get_classes() - def _register_all_entities_from_providers(self) -> None: - for provider in self.entity_providers: - entities = provider.get_entities() - self._register_app_entities(entities) - def _register_app_entities(self, classes: Iterable[Type[object]]) -> None: for _cls in classes: for _target_cls, handler in self.classes_with_handlers.items(): @@ -175,7 +170,6 @@ def __init_app(self) -> None: system_managed_classes = self._get_system_managed_classes() classes_to_inject = [*scanned_classes, *system_managed_classes] self._inject_application_context_to_context_required(classes_to_inject) - self._register_all_entities_from_providers() self._register_app_entities(classes_to_inject) self._register_entity_providers(self.entity_providers) self.app_context.load_properties() diff --git a/py_spring_core/event/application_event_handler_registry.py b/py_spring_core/event/application_event_handler_registry.py index 0a1fb9e..149c75f 100644 --- a/py_spring_core/event/application_event_handler_registry.py +++ b/py_spring_core/event/application_event_handler_registry.py @@ -1,6 +1,6 @@ from threading import Thread -from typing import Callable, Type +from typing import Callable, ClassVar, Type from loguru import logger from pydantic import BaseModel @@ -52,7 +52,7 @@ class ApplicationEventHandlerRegistry(Component, ApplicationContextRequired): - Registers event handlers for application events - Binds event handlers to their corresponding components """ - _class_event_handlers: dict[str, list[EventHandler]] = {} + _class_event_handlers: ClassVar[dict[str, list[EventHandler]]] = {} def __init__(self) -> None: self._event_handlers: dict[str, list[EventHandler]] = {} self._event_message_queue = EventQueue.queue From 1a49fd3a75c5133ffdec784c086b710a377d72df Mon Sep 17 00:00:00 2001 From: William Chen <86595028+NFUChen@users.noreply.github.com> Date: Thu, 19 Jun 2025 09:08:07 +0800 Subject: [PATCH 18/42] chore: Update version to 0.0.13 and refactor entity provider handling (#7) --- py_spring_core/__init__.py | 3 +- .../context/application_context.py | 12 -------- .../core/application/py_spring_application.py | 29 +++++++++++-------- .../entities/controllers/rest_controller.py | 4 ++- .../core/entities/entity_provider.py | 2 +- tests/test_entity_provider.py | 2 +- 6 files changed, 24 insertions(+), 28 deletions(-) diff --git a/py_spring_core/__init__.py b/py_spring_core/__init__.py index 43422dd..33bf03f 100644 --- a/py_spring_core/__init__.py +++ b/py_spring_core/__init__.py @@ -11,8 +11,9 @@ ) from py_spring_core.core.entities.entity_provider import EntityProvider from py_spring_core.core.entities.properties.properties import Properties +from py_spring_core.core.interfaces.application_context_required import ApplicationContextRequired from py_spring_core.event.application_event_publisher import ApplicationEventPublisher from py_spring_core.event.commons import ApplicationEvent from py_spring_core.event.application_event_handler_registry import EventListener -__version__ = "0.0.12" +__version__ = "0.0.13" diff --git a/py_spring_core/core/application/context/application_context.py b/py_spring_core/core/application/context/application_context.py index cc16f6d..3e5262a 100644 --- a/py_spring_core/core/application/context/application_context.py +++ b/py_spring_core/core/application/context/application_context.py @@ -207,18 +207,6 @@ def register_bean_collection(self, bean_cls: Type[BeanCollection]) -> None: bean_name = bean_cls.get_name() self.bean_collection_cls_container[bean_name] = bean_cls - def register_entity_provider(self, provider: EntityProvider) -> None: - self.providers.append(provider) - for entity_cls in provider.get_entities(): - if issubclass(entity_cls, Component): - self.register_component(entity_cls) - elif issubclass(entity_cls, RestController): - self.register_controller(entity_cls) - elif issubclass(entity_cls, BeanCollection): - self.register_bean_collection(entity_cls) - elif issubclass(entity_cls, Properties): - self.register_properties(entity_cls) - def register_properties(self, properties_cls: Type[Properties]) -> None: if not issubclass(properties_cls, Properties): raise TypeError( diff --git a/py_spring_core/core/application/py_spring_application.py b/py_spring_core/core/application/py_spring_application.py index 1ecf62d..375c9d3 100644 --- a/py_spring_core/core/application/py_spring_application.py +++ b/py_spring_core/core/application/py_spring_application.py @@ -123,12 +123,12 @@ def _register_app_entities(self, classes: Iterable[Type[object]]) -> None: continue handler(_cls) - def _register_entity_providers( - self, entity_providers: Iterable[EntityProvider] - ) -> None: + def _get_all_entities_from_entity_providers(self, entity_providers: Iterable[EntityProvider]) -> Iterable[Type[AppEntities]]: + entities: list[Type[AppEntities]] = [] for provider in entity_providers: - self.app_context.register_entity_provider(provider) - provider.set_context(self.app_context) + entities.extend(provider.get_entities()) + + return entities def _handle_register_component(self, _cls: Type[Component]) -> None: self.app_context.register_component(_cls) @@ -165,13 +165,19 @@ def _inject_application_context_to_context_required(self, classes: Iterable[Type continue cls.set_application_context(self.app_context) - def __init_app(self) -> None: + def _prepare_injected_classes(self) -> Iterable[Type[object]]: scanned_classes = self._scan_classes_for_project() system_managed_classes = self._get_system_managed_classes() - classes_to_inject = [*scanned_classes, *system_managed_classes] + provider_entities = self._get_all_entities_from_entity_providers(self.entity_providers) + provider_classes = [provider.__class__ for provider in self.entity_providers] + # providers typically requires app context, so add to classess to inject + classes_to_inject = [*scanned_classes, *system_managed_classes, *provider_entities, *provider_classes] + return classes_to_inject + + def __init_app(self) -> None: + classes_to_inject = self._prepare_injected_classes() self._inject_application_context_to_context_required(classes_to_inject) self._register_app_entities(classes_to_inject) - self._register_entity_providers(self.entity_providers) self.app_context.load_properties() self.app_context.init_ioc_container() self.app_context.inject_dependencies_for_app_entities() @@ -196,10 +202,9 @@ def __init_controllers(self) -> None: controllers = self.app_context.get_controller_instances() for controller in controllers: name = controller.__class__.__name__ - routes = RouteMapping.routes.get(name, None) - if routes is None: - continue - controller._register_routes(routes) + routes = RouteMapping.routes.get(name, set()) + controller.post_construct() + controller._register_decorated_routes(routes) router = controller.get_router() self.fastapi.include_router(router) controller.register_middlewares() diff --git a/py_spring_core/core/entities/controllers/rest_controller.py b/py_spring_core/core/entities/controllers/rest_controller.py index 8fa9aaa..52ce6a5 100644 --- a/py_spring_core/core/entities/controllers/rest_controller.py +++ b/py_spring_core/core/entities/controllers/rest_controller.py @@ -24,7 +24,9 @@ class RestController: class Config: prefix: str = "" - def _register_routes(self, routes: Iterable[RouteRegistration]) -> None: + def post_construct(self) -> None: ... + + def _register_decorated_routes(self, routes: Iterable[RouteRegistration]) -> None: for route in routes: bound_method = partial(route.func, self) self.router.add_api_route( diff --git a/py_spring_core/core/entities/entity_provider.py b/py_spring_core/core/entities/entity_provider.py index aeb47e5..79c21a8 100644 --- a/py_spring_core/core/entities/entity_provider.py +++ b/py_spring_core/core/entities/entity_provider.py @@ -25,7 +25,7 @@ class EntityProvider: extneral_dependencies: list[Any] = field(default_factory=list) app_context: Optional["ApplicationContext"] = None - def get_entities(self) -> list[Type[object]]: + def get_entities(self) -> list[Type[AppEntities]]: return [ *self.component_classes, *self.bean_collection_classes, diff --git a/tests/test_entity_provider.py b/tests/test_entity_provider.py index fad6f24..d150797 100644 --- a/tests/test_entity_provider.py +++ b/tests/test_entity_provider.py @@ -24,7 +24,7 @@ def test_app_context( self, test_entity_provider: EntityProvider ) -> ApplicationContext: app_context = ApplicationContext(ApplicationContextConfig(properties_path="")) - app_context.register_entity_provider(test_entity_provider) + app_context.providers.append(test_entity_provider) return app_context def test_did_raise_error_when_no_depends_on_is_provided( From 1508ee944c0420ce677e6775354f8a190c2d5e64 Mon Sep 17 00:00:00 2001 From: William Chen <86595028+NFUChen@users.noreply.github.com> Date: Sat, 5 Jul 2025 21:17:53 +0800 Subject: [PATCH 19/42] chore: Add __all__ export for public API in __init__.py (#8) --- py_spring_core/__init__.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/py_spring_core/__init__.py b/py_spring_core/__init__.py index 33bf03f..828b57f 100644 --- a/py_spring_core/__init__.py +++ b/py_spring_core/__init__.py @@ -16,4 +16,23 @@ from py_spring_core.event.commons import ApplicationEvent from py_spring_core.event.application_event_handler_registry import EventListener -__version__ = "0.0.13" +__version__ = "0.0.14" + +__all__ = [ + "PySpringApplication", + "BeanCollection", + "Component", + "ComponentScope", + "RestController", + "DeleteMapping", + "GetMapping", + "PatchMapping", + "PostMapping", + "PutMapping", + "EntityProvider", + "Properties", + "ApplicationContextRequired", + "ApplicationEventPublisher", + "ApplicationEvent", + "EventListener", +] \ No newline at end of file From 08143281a8605e773244325a455a43a9e798ef63 Mon Sep 17 00:00:00 2001 From: William Chen <86595028+NFUChen@users.noreply.github.com> Date: Thu, 10 Jul 2025 15:49:13 +0800 Subject: [PATCH 20/42] Enhanced Logging Configuration and Uvicorn Integration (#9) --- py_spring_core/__init__.py | 2 +- .../core/application/loguru_config.py | 5 +++ .../core/application/py_spring_application.py | 37 ++++++++++++++++++- .../application_event_handler_registry.py | 2 +- 4 files changed, 42 insertions(+), 4 deletions(-) diff --git a/py_spring_core/__init__.py b/py_spring_core/__init__.py index 828b57f..fa00b16 100644 --- a/py_spring_core/__init__.py +++ b/py_spring_core/__init__.py @@ -16,7 +16,7 @@ from py_spring_core.event.commons import ApplicationEvent from py_spring_core.event.application_event_handler_registry import EventListener -__version__ = "0.0.14" +__version__ = "0.0.15" __all__ = [ "PySpringApplication", diff --git a/py_spring_core/core/application/loguru_config.py b/py_spring_core/core/application/loguru_config.py index f291ff2..d4b16db 100644 --- a/py_spring_core/core/application/loguru_config.py +++ b/py_spring_core/core/application/loguru_config.py @@ -14,6 +14,10 @@ class LogLevel(str, Enum): CRITICAL = "CRITICAL" +class LogFormat(str, Enum): + TEXT = "text" + JSON = "json" + class LoguruConfig(BaseModel): log_format: str = ( "{time:YYYY-MM-DD HH:mm:ss.SSS} | " @@ -27,3 +31,4 @@ class LoguruConfig(BaseModel): log_file_path: Optional[str] = "./logs/app.log" enable_backtrace: bool = True enable_diagnose: bool = True + format: LogFormat = LogFormat.TEXT diff --git a/py_spring_core/core/application/py_spring_application.py b/py_spring_core/core/application/py_spring_application.py index 375c9d3..7582f06 100644 --- a/py_spring_core/core/application/py_spring_application.py +++ b/py_spring_core/core/application/py_spring_application.py @@ -1,3 +1,4 @@ +import logging import os from typing import Any, Callable, Iterable, Type @@ -21,6 +22,7 @@ from py_spring_core.core.application.context.application_context_config import ( ApplicationContextConfig, ) +from py_spring_core.core.application.loguru_config import LogFormat from py_spring_core.core.entities.bean_collection import BeanCollection from py_spring_core.core.entities.component import Component, ComponentLifeCycle from py_spring_core.core.entities.controllers.rest_controller import RestController @@ -97,14 +99,16 @@ def __configure_logging(self): config = self.app_config.loguru_config if not config.log_file_path: return - + + # Use the format field from config which contains the actual format string logger.add( config.log_file_path, - format=config.log_format, level=config.log_level, rotation=config.log_rotation, retention=config.log_retention, + serialize=config.format == LogFormat.JSON, ) + self.__configure_uvicorn_logging() def _get_system_managed_classes(self) -> Iterable[Type[Component]]: return [ @@ -209,11 +213,40 @@ def __init_controllers(self) -> None: self.fastapi.include_router(router) controller.register_middlewares() + def __configure_uvicorn_logging(self): + """Configure Uvicorn to use Loguru instead of default logging.""" + + + # Intercept standard logging and redirect to loguru + class InterceptHandler(logging.Handler): + def emit(self, record): + # Get corresponding Loguru level if it exists + try: + level = logger.level(record.levelname).name + except ValueError: + level = record.levelno + + # Find caller from where originated the logged message + frame, depth = logging.currentframe(), 2 + while frame and frame.f_code.co_filename == logging.__file__: + frame = frame.f_back + depth += 1 + + logger.opt(depth=depth, exception=record.exc_info).log(level, record.getMessage()) + + # Remove default uvicorn logger and add intercept handler + logging.basicConfig(handlers=[InterceptHandler()], level=0, force=True) + def __run_server(self) -> None: + # Configure Uvicorn to use Loguru + + + # Run uvicorn server uvicorn.run( self.fastapi, host=self.app_config.server_config.host, port=self.app_config.server_config.port, + log_config=None, # Disable uvicorn's default logging ) def run(self) -> None: diff --git a/py_spring_core/event/application_event_handler_registry.py b/py_spring_core/event/application_event_handler_registry.py index 149c75f..d2c2863 100644 --- a/py_spring_core/event/application_event_handler_registry.py +++ b/py_spring_core/event/application_event_handler_registry.py @@ -61,7 +61,7 @@ def post_construct(self) -> None: logger.info("Initializing event handlers...") self._init_event_handlers() logger.info("Starting event message handler thread...") - Thread(target= self._handle_messages).start() + Thread(target= self._handle_messages, daemon=True).start() def _init_event_handlers(self) -> None: app_context = self.get_application_context() From 53d295c69de5fbfc7d30386c8d47db4eedecdb9a Mon Sep 17 00:00:00 2001 From: William Chen <86595028+NFUChen@users.noreply.github.com> Date: Fri, 11 Jul 2025 17:27:14 +0800 Subject: [PATCH 21/42] chore:add server to context (#10) --- py_spring_core/__init__.py | 2 +- .../core/application/context/application_context.py | 4 +++- py_spring_core/core/application/py_spring_application.py | 3 ++- tests/test_application_context.py | 9 +++++++-- tests/test_component_features.py | 9 +++++++-- tests/test_entity_provider.py | 9 +++++++-- 6 files changed, 27 insertions(+), 9 deletions(-) diff --git a/py_spring_core/__init__.py b/py_spring_core/__init__.py index fa00b16..ade5654 100644 --- a/py_spring_core/__init__.py +++ b/py_spring_core/__init__.py @@ -16,7 +16,7 @@ from py_spring_core.event.commons import ApplicationEvent from py_spring_core.event.application_event_handler_registry import EventListener -__version__ = "0.0.15" +__version__ = "0.0.16" __all__ = [ "PySpringApplication", diff --git a/py_spring_core/core/application/context/application_context.py b/py_spring_core/core/application/context/application_context.py index 3e5262a..6ddb02b 100644 --- a/py_spring_core/core/application/context/application_context.py +++ b/py_spring_core/core/application/context/application_context.py @@ -12,6 +12,7 @@ get_origin, ) +from fastapi import FastAPI from loguru import logger from pydantic import BaseModel @@ -57,7 +58,8 @@ class ApplicationContext: The `ApplicationContext` class is designed to follow the Singleton design pattern, ensuring that there is a single instance of the application context throughout the application's lifetime. """ - def __init__(self, config: ApplicationContextConfig) -> None: + def __init__(self, config: ApplicationContextConfig, server: FastAPI) -> None: + self.server = server self.all_file_paths: set[str] = set() self.primitive_types = (bool, str, int, float, type(None)) diff --git a/py_spring_core/core/application/py_spring_application.py b/py_spring_core/core/application/py_spring_application.py index 7582f06..77e3cac 100644 --- a/py_spring_core/core/application/py_spring_application.py +++ b/py_spring_core/core/application/py_spring_application.py @@ -79,8 +79,9 @@ def __init__( self.app_context_config = ApplicationContextConfig( properties_path=self.app_config.properties_file_path ) - self.app_context = ApplicationContext(config=self.app_context_config) self.fastapi = FastAPI() + self.app_context = ApplicationContext(config=self.app_context_config, server=self.fastapi) + self.classes_with_handlers: dict[ Type[AppEntities], Callable[[Type[Any]], None] diff --git a/tests/test_application_context.py b/tests/test_application_context.py index 6f4b71b..c9e2e19 100644 --- a/tests/test_application_context.py +++ b/tests/test_application_context.py @@ -1,3 +1,4 @@ +from fastapi import FastAPI import pytest from py_spring_core.core.application.context.application_context import ( @@ -12,9 +13,13 @@ class TestApplicationContext: @pytest.fixture - def app_context(self): + def server(self) -> FastAPI: + return FastAPI() + + @pytest.fixture + def app_context(self, server: FastAPI): config = ApplicationContextConfig(properties_path="") - return ApplicationContext(config) + return ApplicationContext(config, server=server) def test_register_entities_correctly(self, app_context: ApplicationContext): class TestComponent(Component): ... diff --git a/tests/test_component_features.py b/tests/test_component_features.py index b41b217..3aa76b2 100644 --- a/tests/test_component_features.py +++ b/tests/test_component_features.py @@ -1,6 +1,7 @@ from abc import ABC from typing import Annotated +from fastapi import FastAPI import pytest from py_spring_core.core.application.context.application_context import ( @@ -14,10 +15,14 @@ class TestComponentFeatures: """Test suite for component features including primary components, qualifiers, and registration validation.""" @pytest.fixture - def app_context(self): + def server(self) -> FastAPI: + return FastAPI() + + @pytest.fixture + def app_context(self, server: FastAPI): """Fixture that provides a fresh ApplicationContext instance for each test.""" config = ApplicationContextConfig(properties_path="") - return ApplicationContext(config) + return ApplicationContext(config, server=server) def test_qualifier_based_injection(self, app_context: ApplicationContext): """ diff --git a/tests/test_entity_provider.py b/tests/test_entity_provider.py index d150797..e4a2a58 100644 --- a/tests/test_entity_provider.py +++ b/tests/test_entity_provider.py @@ -1,3 +1,4 @@ +from fastapi import FastAPI import pytest from py_spring_core.core.application.context.application_context import ( @@ -18,12 +19,16 @@ class TestEntityProvider: @pytest.fixture def test_entity_provider(self): return EntityProvider(depends_on=[TestComponent]) + + @pytest.fixture + def server(self) -> FastAPI: + return FastAPI() @pytest.fixture def test_app_context( - self, test_entity_provider: EntityProvider + self, test_entity_provider: EntityProvider, server: FastAPI ) -> ApplicationContext: - app_context = ApplicationContext(ApplicationContextConfig(properties_path="")) + app_context = ApplicationContext(ApplicationContextConfig(properties_path=""), server=server) app_context.providers.append(test_entity_provider) return app_context From f6ee731a79ca0b3c23e2fad885e428e468cdef5a Mon Sep 17 00:00:00 2001 From: William Chen <86595028+NFUChen@users.noreply.github.com> Date: Tue, 15 Jul 2025 13:12:01 +0800 Subject: [PATCH 22/42] Fix/correct log level (#11) --- py_spring_core/__init__.py | 2 +- py_spring_core/core/application/py_spring_application.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/py_spring_core/__init__.py b/py_spring_core/__init__.py index ade5654..cb3e8ed 100644 --- a/py_spring_core/__init__.py +++ b/py_spring_core/__init__.py @@ -16,7 +16,7 @@ from py_spring_core.event.commons import ApplicationEvent from py_spring_core.event.application_event_handler_registry import EventListener -__version__ = "0.0.16" +__version__ = "0.0.17" __all__ = [ "PySpringApplication", diff --git a/py_spring_core/core/application/py_spring_application.py b/py_spring_core/core/application/py_spring_application.py index 77e3cac..3ceda7c 100644 --- a/py_spring_core/core/application/py_spring_application.py +++ b/py_spring_core/core/application/py_spring_application.py @@ -216,8 +216,7 @@ def __init_controllers(self) -> None: def __configure_uvicorn_logging(self): """Configure Uvicorn to use Loguru instead of default logging.""" - - + # Configure Uvicorn to use Loguru # Intercept standard logging and redirect to loguru class InterceptHandler(logging.Handler): def emit(self, record): @@ -236,10 +235,11 @@ def emit(self, record): logger.opt(depth=depth, exception=record.exc_info).log(level, record.getMessage()) # Remove default uvicorn logger and add intercept handler - logging.basicConfig(handlers=[InterceptHandler()], level=0, force=True) + log_level = self.app_config.loguru_config.log_level.value + logging.basicConfig(handlers=[InterceptHandler()], level=log_level, force=True) def __run_server(self) -> None: - # Configure Uvicorn to use Loguru + # Run uvicorn server From 4e4968937c37e4223dfe3eecf8cc1e84809ce32d Mon Sep 17 00:00:00 2001 From: William Chen <86595028+NFUChen@users.noreply.github.com> Date: Fri, 18 Jul 2025 01:44:34 +0800 Subject: [PATCH 23/42] Enhanced Abstract Class Component Initialization Support (#12) --- py_spring_core/__init__.py | 2 +- .../context/application_context.py | 50 ++++++++++++++----- py_spring_core/core/utils.py | 31 +++++++++++- 3 files changed, 69 insertions(+), 14 deletions(-) diff --git a/py_spring_core/__init__.py b/py_spring_core/__init__.py index cb3e8ed..2340789 100644 --- a/py_spring_core/__init__.py +++ b/py_spring_core/__init__.py @@ -16,7 +16,7 @@ from py_spring_core.event.commons import ApplicationEvent from py_spring_core.event.application_event_handler_registry import EventListener -__version__ = "0.0.17" +__version__ = "0.0.18" __all__ = [ "PySpringApplication", diff --git a/py_spring_core/core/application/context/application_context.py b/py_spring_core/core/application/context/application_context.py index 6ddb02b..d7f4d21 100644 --- a/py_spring_core/core/application/context/application_context.py +++ b/py_spring_core/core/application/context/application_context.py @@ -1,3 +1,5 @@ +import py_spring_core.core.utils as framework_utils + from abc import ABC from inspect import isclass from typing import ( @@ -104,7 +106,6 @@ def _determine_target_cls_name( 4. If the component class is an ABC and has no implementations, return the name of the first subclass. 5. If the component class is an ABC and has multiple implementations, raise an error. """ - if qualifier is not None: return qualifier @@ -140,9 +141,7 @@ def get_component( scope = component_cls.get_scope() match scope: case ComponentScope.Singleton: - optional_instance = self.singleton_component_instance_container.get( - target_cls_name - ) + optional_instance = self.singleton_component_instance_container.get(target_cls_name) return cast(T, optional_instance) case ComponentScope.Prototype: @@ -185,10 +184,13 @@ def register_component(self, component_cls: Type[Component]) -> None: f"[COMPONENT REGISTRATION ERROR] Component: {component_cls} is not a subclass of Component" ) component_cls_name = component_cls.get_name() - if component_cls_name in self.component_cls_container: - raise ValueError( - f"[COMPONENT REGISTRATION ERROR] Component: {component_cls_name} already registered" - ) + is_same_component = ( + component_cls_name in self.component_cls_container and self.component_cls_container[component_cls_name].__name__ == component_cls.__name__ + and self.component_cls_container[component_cls_name] == component_cls + ) + if is_same_component: + return + self.component_cls_container[component_cls_name] = component_cls def register_controller(self, controller_cls: Type[RestController]) -> None: @@ -268,6 +270,12 @@ def init_singleton_component( return instance + def get_abstract_class_component_subclasses(self, component_cls: Type[ABC]) -> list[Type[Component]]: + return [ + subclass for subclass in component_cls.__subclasses__() + if issubclass(subclass, Component) + ] + def init_ioc_container(self) -> None: """ Initializes the IoC (Inversion of Control) container by creating singleton instances of all registered components. @@ -283,10 +291,28 @@ def init_ioc_container(self) -> None: logger.debug( f"[INITIALIZING SINGLETON COMPONENT] Init singleton component: {component_cls_name}" ) - instance = self.init_singleton_component(component_cls, component_cls_name) - if instance is None: - continue - self.singleton_component_instance_container[component_cls_name] = instance + if issubclass(component_cls, ABC): + component_classes = self.get_abstract_class_component_subclasses(component_cls) + for subclass_component_cls in component_classes: + self.register_component(subclass_component_cls) + unimplemented_abstract_methods = framework_utils.get_unimplemented_abstract_methods(subclass_component_cls) + if len(unimplemented_abstract_methods) > 0: + unimplemented_abstract_methods_str = ", ".join(unimplemented_abstract_methods) + message = f"[ABSTRACT CLASS COMPONENT INITIALIZING SINGLETON COMPONENT] Unable to initialize singleton component: {subclass_component_cls.get_name()} because it has unimplemented abstract methods: {unimplemented_abstract_methods_str}" + logger.error(message) + raise ValueError(message) + logger.debug( + f"[ABSTRACT CLASS COMPONENT INITIALIZING SINGLETON COMPONENT] Init singleton component: {subclass_component_cls.get_name()}" + ) + instance = self.init_singleton_component(subclass_component_cls, subclass_component_cls.get_name()) + if instance is None: + continue + self.singleton_component_instance_container[subclass_component_cls.get_name()] = instance + else: + instance = self.init_singleton_component(component_cls, component_cls.get_name()) + if instance is None: + continue + self.singleton_component_instance_container[component_cls_name] = instance # for Bean for ( diff --git a/py_spring_core/core/utils.py b/py_spring_core/core/utils.py index bbc6e8f..6a26683 100644 --- a/py_spring_core/core/utils.py +++ b/py_spring_core/core/utils.py @@ -1,7 +1,8 @@ +from abc import ABC import importlib.util import inspect from pathlib import Path -from typing import Iterable, Type +from typing import Any, Iterable, Type from loguru import logger @@ -75,3 +76,31 @@ def dynamically_import_modules( returned_target_classes.add(loaded_class) return returned_target_classes + + +def get_unimplemented_abstract_methods(cls: Type[Any]) -> set[str]: + """ + Returns a set of abstract method names not implemented in the given class. + Args: + cls (Type[Any]): A subclass of abc.ABC + + Returns: + set[str]: A set of method names that are abstract but not yet implemented + """ + if not isinstance(cls, type): + raise TypeError("Expected a class type.") + + if not issubclass(cls, ABC): + raise TypeError("Expected a subclass of abc.ABC.") + + abstract_methods: set[str] = set() + for base in cls.__mro__: + base_abstracts = getattr(base, '__abstractmethods__', set()) + abstract_methods = abstract_methods.union(base_abstracts) + + implemented_methods: set[str] = { + attr for attr in dir(cls) + if callable(getattr(cls, attr)) and not getattr(getattr(cls, attr), '__isabstractmethod__', False) + } + + return abstract_methods.difference(implemented_methods) From 4f3d1b304c0d6a17459a88a1edb9730ff0e90bde Mon Sep 17 00:00:00 2001 From: William Chen <86595028+NFUChen@users.noreply.github.com> Date: Fri, 18 Jul 2025 13:44:44 +0800 Subject: [PATCH 24/42] Refactor Application Context and Dependency Management (#13) --- .../context/application_context.py | 723 +++++++++++------- tests/test_application_context.py | 30 +- tests/test_component_features.py | 28 +- 3 files changed, 479 insertions(+), 302 deletions(-) diff --git a/py_spring_core/core/application/context/application_context.py b/py_spring_core/core/application/context/application_context.py index d7f4d21..729c528 100644 --- a/py_spring_core/core/application/context/application_context.py +++ b/py_spring_core/core/application/context/application_context.py @@ -1,4 +1,4 @@ -import py_spring_core.core.utils as framework_utils +import py_spring_core.core.utils as framework_utils from abc import ABC from inspect import isclass @@ -37,381 +37,558 @@ PT = TypeVar("PT", bound=Properties) -class ComponentNotFoundError(Exception): ... +class ComponentNotFoundError(Exception): + """Raised when a component is not found in the application context.""" + pass -class InvalidDependencyError(Exception): ... +class InvalidDependencyError(Exception): + """Raised when a dependency is invalid or not found in the application context.""" + pass class ApplicationContextView(BaseModel): + """View model for application context state.""" config: ApplicationContextConfig component_cls_container: list[str] singleton_component_instance_container: list[str] -class ApplicationContext: - """ - The `ApplicationContext` class is the main entry point for the application's context management. - It is responsible for: - 1. Registering and managing the lifecycle of components, controllers, bean collections, and properties. - 2. Providing methods to retrieve instances of registered components, beans, and properties. - 3. Initializing the Inversion of Control (IoC) container by creating singleton instances of registered components. - 4. Injecting dependencies for registered components and controllers. - The `ApplicationContext` class is designed to follow the Singleton design pattern, ensuring that there is a single instance of the application context throughout the application's lifetime. - """ - - def __init__(self, config: ApplicationContextConfig, server: FastAPI) -> None: - self.server = server - self.all_file_paths: set[str] = set() - self.primitive_types = (bool, str, int, float, type(None)) - - self.config = config +class ContainerManager: + """Manages containers for different types of entities in the application context.""" + + def __init__(self): self.component_cls_container: dict[str, Type[Component]] = {} self.controller_cls_container: dict[str, Type[RestController]] = {} self.singleton_component_instance_container: dict[str, Component] = {} - + self.bean_collection_cls_container: dict[str, Type[BeanCollection]] = {} self.singleton_bean_instance_container: dict[str, object] = {} - + self.properties_cls_container: dict[str, Type[Properties]] = {} self.singleton_properties_instance_container: dict[str, Properties] = {} - self.providers: list[EntityProvider] = [] + + def is_entity_in_container(self, entity_cls: Type[AppEntities]) -> bool: + """Check if an entity class is registered in any container.""" + cls_name = entity_cls.__name__ + return ( + cls_name in self.component_cls_container + or cls_name in self.controller_cls_container + or cls_name in self.bean_collection_cls_container + or cls_name in self.properties_cls_container + ) - def set_all_file_paths(self, all_file_paths: set[str]) -> None: - self.all_file_paths = all_file_paths - def _create_properties_loader(self) -> _PropertiesLoader: - return _PropertiesLoader( - self.config.properties_path, list(self.properties_cls_container.values()) - ) +class DependencyInjector: + """Handles dependency injection for entities in the application context.""" + + def __init__(self, container_manager: ContainerManager): + self.container_manager = container_manager + self.primitive_types = (bool, str, int, float, type(None)) + self._app_context: Optional['ApplicationContext'] = None + + def _extract_qualifier_from_annotation(self, annotated_type: Type) -> tuple[Type, Optional[str]]: + """Extract the actual type and qualifier from an Annotated type.""" + qualifier = None + if get_origin(annotated_type) is Annotated: + type_args = get_args(annotated_type) + annotated_type = type_args[0] + if len(type_args) > 1: + qualifier = type_args[1] + return annotated_type, qualifier + + def _inject_properties_dependency( + self, + entity: Type[AppEntities], + attr_name: str, + properties_cls: Type[Properties] + ) -> bool: + """Inject a properties dependency into an entity.""" + if self._app_context is None: + return False + + optional_properties = self._app_context.get_properties(properties_cls) + if optional_properties is None: + raise TypeError( + f"[PROPERTIES INJECTION ERROR] Properties: {properties_cls.get_name()} " + f"is not found in properties file for class: {properties_cls.get_name()} " + f"with key: {properties_cls.get_key()}" + ) + setattr(entity, attr_name, optional_properties) + return True + + def _try_inject_entity_dependency( + self, + entity: Type[AppEntities], + attr_name: str, + entity_cls: Type[AppEntities], + qualifier: Optional[str] + ) -> bool: + """Try to inject an entity dependency using available getters.""" + if self._app_context is None: + return False + + entity_getters: list[Callable[[Type[AppEntities], Optional[str]], Optional[AppEntities]]] = [ + self._app_context.get_component, + self._app_context.get_bean + ] + + for getter in entity_getters: + optional_entity = getter(entity_cls, qualifier) + if optional_entity is not None: + setattr(entity, attr_name, optional_entity) + logger.success( + f"[DEPENDENCY INJECTION SUCCESS] Inject dependency for {entity_cls.__name__} " + f"in attribute: {attr_name}" + ) + return True + return False + + def inject_dependencies(self, entity: Type[AppEntities]) -> None: + """Inject dependencies for a given entity based on its annotations.""" + for attr_name, annotated_type in entity.__annotations__.items(): + entity_cls, qualifier = self._extract_qualifier_from_annotation(annotated_type) + + # Skip primitive types + if entity_cls in self.primitive_types: + logger.warning( + f"[DEPENDENCY INJECTION SKIPPED] Skip inject dependency for attribute: {attr_name} " + f"with dependency: {entity_cls.__name__} because it is primitive type" + ) + continue + + # Skip non-class types + if not isclass(entity_cls): + continue + + # Handle Properties injection + if issubclass(entity_cls, Properties): + if self._inject_properties_dependency(entity, attr_name, entity_cls): + continue + + # Try to inject entity dependency + if self._try_inject_entity_dependency(entity, attr_name, entity_cls, qualifier): + continue + + # If we get here, injection failed + error_message = ( + f"[DEPENDENCY INJECTION FAILED] Fail to inject dependency for attribute: {attr_name} " + f"with dependency: {entity_cls.__name__} with qualifier: {qualifier}, " + f"consider register such dependency with Component decorator" + ) + logger.critical(error_message) + raise ValueError(error_message) - def as_view(self) -> ApplicationContextView: - return ApplicationContextView( - config=self.config, - component_cls_container=list(self.component_cls_container.keys()), - singleton_component_instance_container=list( - self.singleton_component_instance_container.keys() - ), - ) +class ComponentManager: + """Manages component registration and instantiation.""" + + def __init__(self, container_manager: ContainerManager): + self.container_manager = container_manager + def _determine_target_cls_name( self, component_cls: Type[T], qualifier: Optional[str] ) -> str: """ Determine the target class name for a given component class. - This method handles the following cases: - 1. If a qualifier is provided, return it directly. - 2. If the component class is not an ABC, return its name directly. - 3. If the component class is an ABC but has implementations, return its name directly. - 4. If the component class is an ABC and has no implementations, return the name of the first subclass. - 5. If the component class is an ABC and has multiple implementations, raise an error. + + Args: + component_cls: The component class to determine the name for + qualifier: Optional qualifier to use directly + + Returns: + The target class name + + Raises: + ValueError: If abstract class has no subclasses """ if qualifier is not None: return qualifier - + # If it's not an ABC, return its name directly if not issubclass(component_cls, ABC): return component_cls.get_name() - + # If it's an ABC but has implementations, return its name directly if not component_cls.__abstractmethods__: return component_cls.get_name() - + # For abstract classes that need implementations subclasses = component_cls.__subclasses__() if len(subclasses) == 0: raise ValueError( f"[ABSTRACT CLASS ERROR] Abstract class {component_cls.__name__} has no subclasses" ) - + # Fall back to first subclass if no primary component exists return subclasses[0].get_name() - - def get_component( - self, component_cls: Type[T], qualifier: Optional[str] - ) -> Optional[T]: + + def register_component(self, component_cls: Type[Component]) -> None: + """Register a component class in the container.""" + if not issubclass(component_cls, Component): + raise TypeError( + f"[COMPONENT REGISTRATION ERROR] Component: {component_cls} " + f"is not a subclass of Component" + ) + + component_cls_name = component_cls.get_name() + existing_component = self.container_manager.component_cls_container.get(component_cls_name) + + # Check if it's the same component to avoid duplicate registration + if (existing_component and + existing_component.__name__ == component_cls.__name__ and + existing_component == component_cls): + return + + self.container_manager.component_cls_container[component_cls_name] = component_cls + + def get_component(self, component_cls: Type[T], qualifier: Optional[str]) -> Optional[T]: + """Get a component instance by class and optional qualifier.""" if not issubclass(component_cls, (Component, ABC)): return None - - target_cls_name: str = self._determine_target_cls_name(component_cls, qualifier) - - if target_cls_name not in self.component_cls_container: + + target_cls_name = self._determine_target_cls_name(component_cls, qualifier) + + if target_cls_name not in self.container_manager.component_cls_container: return None - + scope = component_cls.get_scope() match scope: case ComponentScope.Singleton: - optional_instance = self.singleton_component_instance_container.get(target_cls_name) - return cast(T, optional_instance) - + return cast(T, self.container_manager.singleton_component_instance_container.get(target_cls_name)) case ComponentScope.Prototype: - prototype_instance = component_cls() - return cast(T, prototype_instance) - - def is_within_context(self, _cls: Type[AppEntities]) -> bool: - cls_name = _cls.__name__ - is_within_component = cls_name in self.component_cls_container - is_within_controller = cls_name in self.controller_cls_container - is_within_bean_collection = cls_name in self.bean_collection_cls_container - is_within_properties = cls_name in self.properties_cls_container - return ( - is_within_component - or is_within_controller - or is_within_bean_collection - or is_within_properties - ) - - def get_bean(self, object_cls: Type[T], qualifier: Optional[str]) -> Optional[T]: - bean_name = object_cls.__name__ - if bean_name not in self.singleton_bean_instance_container: - return None - - optional_instance = self.singleton_bean_instance_container.get(bean_name) - return cast(T, optional_instance) - - def get_properties(self, properties_cls: Type[PT]) -> Optional[PT]: - properties_cls_name = properties_cls.get_key() - if properties_cls_name not in self.properties_cls_container: - return None - optional_instance = cast( - PT, self.singleton_properties_instance_container.get(properties_cls_name) - ) - return optional_instance - - def register_component(self, component_cls: Type[Component]) -> None: - if not issubclass(component_cls, Component): - raise TypeError( - f"[COMPONENT REGISTRATION ERROR] Component: {component_cls} is not a subclass of Component" + return cast(T, component_cls()) + + def _init_singleton_component( + self, component_cls: Type[Component], component_cls_name: str + ) -> Optional[Component]: + """Initialize a singleton component instance.""" + try: + return component_cls() + except Exception as error: + if "Can't instantiate abstract class" in str(error): + logger.warning( + f"[INITIALIZING SINGLETON COMPONENT ERROR] Skip initializing singleton component: " + f"{component_cls_name} because it is an abstract class" + ) + return None + logger.error( + f"[INITIALIZING SINGLETON COMPONENT ERROR] Error initializing singleton component: " + f"{component_cls_name} with error: {error}" ) - component_cls_name = component_cls.get_name() - is_same_component = ( - component_cls_name in self.component_cls_container and self.component_cls_container[component_cls_name].__name__ == component_cls.__name__ - and self.component_cls_container[component_cls_name] == component_cls - ) - if is_same_component: - return + raise error + + def _get_abstract_class_component_subclasses(self, component_cls: Type[ABC]) -> list[Type[Component]]: + """Get all Component subclasses of an abstract class.""" + return [ + subclass for subclass in component_cls.__subclasses__() + if issubclass(subclass, Component) + ] + + def _init_abstract_component_subclasses(self, component_cls: Type[ABC]) -> None: + """Initialize singleton instances for abstract component subclasses.""" + component_classes = self._get_abstract_class_component_subclasses(component_cls) - self.component_cls_container[component_cls_name] = component_cls - - def register_controller(self, controller_cls: Type[RestController]) -> None: - if not issubclass(controller_cls, RestController): - raise TypeError( - f"[CONTROLLER REGISTRATION ERROR] Controller: {controller_cls} is not a subclass of RestController" + for subclass_component_cls in component_classes: + self.register_component(subclass_component_cls) + + # Check for unimplemented abstract methods + unimplemented_methods = framework_utils.get_unimplemented_abstract_methods(subclass_component_cls) + if unimplemented_methods: + methods_str = ", ".join(unimplemented_methods) + message = ( + f"[ABSTRACT CLASS COMPONENT INITIALIZING SINGLETON COMPONENT] " + f"Unable to initialize singleton component: {subclass_component_cls.get_name()} " + f"because it has unimplemented abstract methods: {methods_str}" + ) + logger.error(message) + raise ValueError(message) + + logger.debug( + f"[ABSTRACT CLASS COMPONENT INITIALIZING SINGLETON COMPONENT] " + f"Init singleton component: {subclass_component_cls.get_name()}" ) - - controller_cls_name = controller_cls.get_name() - self.controller_cls_container[controller_cls_name] = controller_cls - + + instance = self._init_singleton_component(subclass_component_cls, subclass_component_cls.get_name()) + if instance is not None: + self.container_manager.singleton_component_instance_container[subclass_component_cls.get_name()] = instance + + def init_singleton_components(self) -> None: + """Initialize all singleton components in the container.""" + for component_cls_name, component_cls in self.container_manager.component_cls_container.items(): + if component_cls.get_scope() != ComponentScope.Singleton: + continue + + logger.debug(f"[INITIALIZING SINGLETON COMPONENT] Init singleton component: {component_cls_name}") + + if issubclass(component_cls, ABC): + self._init_abstract_component_subclasses(component_cls) + else: + instance = self._init_singleton_component(component_cls, component_cls_name) + if instance is not None: + self.container_manager.singleton_component_instance_container[component_cls_name] = instance + + +class BeanManager: + """Manages bean collection registration and instantiation.""" + + def __init__(self, container_manager: ContainerManager, dependency_injector: DependencyInjector): + self.container_manager = container_manager + self.dependency_injector = dependency_injector + def register_bean_collection(self, bean_cls: Type[BeanCollection]) -> None: + """Register a bean collection class in the container.""" if not issubclass(bean_cls, BeanCollection): raise TypeError( - f"[BEAN COLLECTION REGISTRATION ERROR] BeanCollection: {bean_cls} is not a subclass of BeanCollection" + f"[BEAN COLLECTION REGISTRATION ERROR] BeanCollection: {bean_cls} " + f"is not a subclass of BeanCollection" ) - + bean_name = bean_cls.get_name() - self.bean_collection_cls_container[bean_name] = bean_cls + self.container_manager.bean_collection_cls_container[bean_name] = bean_cls + + def get_bean(self, object_cls: Type[T], qualifier: Optional[str] = None) -> Optional[T]: + """Get a bean instance by class and optional qualifier.""" + bean_name = object_cls.__name__ + if bean_name not in self.container_manager.singleton_bean_instance_container: + return None + + return cast(T, self.container_manager.singleton_bean_instance_container.get(bean_name)) + + def _inject_bean_collection_dependencies(self, bean_collection_cls: Type[BeanCollection]) -> None: + """Inject dependencies for a bean collection.""" + logger.info( + f"[BEAN COLLECTION DEPENDENCY INJECTION] Injecting dependencies for {bean_collection_cls.get_name()}" + ) + self.dependency_injector.inject_dependencies(bean_collection_cls) + + def _validate_bean_view(self, view, collection_name: str) -> None: + """Validate a bean view before adding it to the container.""" + if view.bean_name in self.container_manager.singleton_bean_instance_container: + raise BeanConflictError( + f"[BEAN CONFLICTS] Bean: {view.bean_name} already exists under collection: {collection_name}" + ) + + if not view.is_valid_bean(): + raise InvalidBeanError( + f"[INVALID BEAN] Bean name from bean creation func return type: {view.bean_name} " + f"does not match the bean object class name: {view.bean.__class__.__name__}" + ) + + def init_singleton_beans(self) -> None: + """Initialize all singleton beans from registered bean collections.""" + for bean_collection_cls_name, bean_collection_cls in self.container_manager.bean_collection_cls_container.items(): + logger.debug(f"[INITIALIZING SINGLETON BEAN] Init singleton bean: {bean_collection_cls_name}") + + collection = bean_collection_cls() + self._inject_bean_collection_dependencies(bean_collection_cls) + + bean_views = collection.scan_beans() + for view in bean_views: + self._validate_bean_view(view, collection.get_name()) + self.container_manager.singleton_bean_instance_container[view.bean_name] = view.bean + +class PropertiesManager: + """Manages properties registration and loading.""" + + def __init__(self, container_manager: ContainerManager, config: ApplicationContextConfig): + self.container_manager = container_manager + self.config = config + def register_properties(self, properties_cls: Type[Properties]) -> None: + """Register a properties class in the container.""" if not issubclass(properties_cls, Properties): raise TypeError( - f"[PROPERTIES REGISTRATION ERROR] Properties: {properties_cls} is not a subclass of Properties" + f"[PROPERTIES REGISTRATION ERROR] Properties: {properties_cls} " + f"is not a subclass of Properties" ) + properties_name = properties_cls.get_key() - self.properties_cls_container[properties_name] = properties_cls - - def get_controller_instances(self) -> list[RestController]: - return [_cls() for _cls in self.controller_cls_container.values()] - - def get_singleton_component_instances(self) -> list[Component]: - return [_cls for _cls in self.singleton_component_instance_container.values()] - - def get_singleton_bean_instances(self) -> list[object]: - return [_cls for _cls in self.singleton_bean_instance_container.values()] - + self.container_manager.properties_cls_container[properties_name] = properties_cls + + def get_properties(self, properties_cls: Type[PT]) -> Optional[PT]: + """Get a properties instance by class.""" + properties_cls_name = properties_cls.get_key() + if properties_cls_name not in self.container_manager.properties_cls_container: + return None + + return cast(PT, self.container_manager.singleton_properties_instance_container.get(properties_cls_name)) + + def _create_properties_loader(self) -> _PropertiesLoader: + """Create a properties loader instance.""" + return _PropertiesLoader( + self.config.properties_path, + list(self.container_manager.properties_cls_container.values()) + ) + def load_properties(self) -> None: + """Load all registered properties from configuration files.""" properties_loader = self._create_properties_loader() properties_instance_dict = properties_loader.load_properties() - for properties_key, properties_cls in self.properties_cls_container.items(): - if properties_key in self.singleton_properties_instance_container: + + for properties_key, properties_cls in self.container_manager.properties_cls_container.items(): + if properties_key in self.container_manager.singleton_properties_instance_container: continue - - logger.debug( - f"[INITIALIZING SINGLETON PROPERTIES] Init singleton properties: {properties_key}" - ) + + logger.debug(f"[INITIALIZING SINGLETON PROPERTIES] Init singleton properties: {properties_key}") + optional_properties = properties_instance_dict.get(properties_key) if optional_properties is None: raise TypeError( - f"[PROPERTIES INITIALIZATION ERROR] Properties: {properties_key} is not found in properties file for class: {properties_cls.get_name()} with key: {properties_cls.get_key()}" + f"[PROPERTIES INITIALIZATION ERROR] Properties: {properties_key} " + f"is not found in properties file for class: {properties_cls.get_name()} " + f"with key: {properties_cls.get_key()}" ) - self.singleton_properties_instance_container[properties_key] = ( - optional_properties - ) - _PropertiesLoader.optional_loaded_properties = ( - self.singleton_properties_instance_container + + self.container_manager.singleton_properties_instance_container[properties_key] = optional_properties + + # Update the global properties loader reference + _PropertiesLoader.optional_loaded_properties = self.container_manager.singleton_properties_instance_container + + +class ApplicationContext: + """ + The main entry point for the application's context management. + + This class is responsible for: + 1. Registering and managing the lifecycle of components, controllers, bean collections, and properties. + 2. Providing methods to retrieve instances of registered components, beans, and properties. + 3. Initializing the Inversion of Control (IoC) container by creating singleton instances of registered components. + 4. Injecting dependencies for registered components and controllers. + + The ApplicationContext class is designed to follow the Singleton design pattern, ensuring that there is + a single instance of the application context throughout the application's lifetime. + """ + + def __init__(self, config: ApplicationContextConfig, server: FastAPI) -> None: + self.server = server + self.config = config + self.all_file_paths: set[str] = set() + self.providers: list[EntityProvider] = [] + + # Initialize managers + self.container_manager = ContainerManager() + self.dependency_injector = DependencyInjector(self.container_manager) + self.component_manager = ComponentManager(self.container_manager) + self.bean_manager = BeanManager(self.container_manager, self.dependency_injector) + self.properties_manager = PropertiesManager(self.container_manager, config) + + # Set app context reference for dependency injection + self.dependency_injector._app_context = self + + def set_all_file_paths(self, all_file_paths: set[str]) -> None: + """Set the collection of all file paths in the application.""" + self.all_file_paths = all_file_paths + + def as_view(self) -> ApplicationContextView: + """Create a view model of the application context state.""" + return ApplicationContextView( + config=self.config, + component_cls_container=list(self.container_manager.component_cls_container.keys()), + singleton_component_instance_container=list( + self.container_manager.singleton_component_instance_container.keys() + ), ) - def init_singleton_component( - self, component_cls: Type[Component], component_cls_name: str - ) -> Optional[Component]: - instance: Optional[Component] = None - try: - instance = component_cls() - except Exception as error: - unable_to_init_component_error_prefix = "Can't instantiate abstract class" - if unable_to_init_component_error_prefix in str(error): - logger.warning( - f"[INITIALIZING SINGLETON COMPONENT ERROR] Skip initializing singleton component: {component_cls_name} because it is an abstract class" - ) - return - logger.error( - f"[INITIALIZING SINGLETON COMPONENT ERROR] Error initializing singleton component: {component_cls_name} with error: {error}" - ) - raise error + # Component management methods + def get_component(self, component_cls: Type[T], qualifier: Optional[str] = None) -> Optional[T]: + """Get a component instance by class and optional qualifier.""" + return self.component_manager.get_component(component_cls, qualifier) - return instance + def register_component(self, component_cls: Type[Component]) -> None: + """Register a component class in the application context.""" + self.component_manager.register_component(component_cls) - def get_abstract_class_component_subclasses(self, component_cls: Type[ABC]) -> list[Type[Component]]: - return [ - subclass for subclass in component_cls.__subclasses__() - if issubclass(subclass, Component) - ] + # Bean management methods + def get_bean(self, object_cls: Type[T], qualifier: Optional[str] = None) -> Optional[T]: + """Get a bean instance by class and optional qualifier.""" + return self.bean_manager.get_bean(object_cls, qualifier) - def init_ioc_container(self) -> None: - """ - Initializes the IoC (Inversion of Control) container by creating singleton instances of all registered components. - This method iterates through the registered component classes in the `component_cls_container` dictionary. - For each component class with a `Singleton` scope, it creates an instance of the component and stores it in the `singleton_component_instance_container` dictionary. - This ensures that subsequent calls to `get_component()` for singleton components will return the same instance, as required by the Singleton design pattern. - """ + def register_bean_collection(self, bean_cls: Type[BeanCollection]) -> None: + """Register a bean collection class in the application context.""" + self.bean_manager.register_bean_collection(bean_cls) - # for Components - for component_cls_name, component_cls in self.component_cls_container.items(): - if component_cls.get_scope() != ComponentScope.Singleton: - continue - logger.debug( - f"[INITIALIZING SINGLETON COMPONENT] Init singleton component: {component_cls_name}" - ) - if issubclass(component_cls, ABC): - component_classes = self.get_abstract_class_component_subclasses(component_cls) - for subclass_component_cls in component_classes: - self.register_component(subclass_component_cls) - unimplemented_abstract_methods = framework_utils.get_unimplemented_abstract_methods(subclass_component_cls) - if len(unimplemented_abstract_methods) > 0: - unimplemented_abstract_methods_str = ", ".join(unimplemented_abstract_methods) - message = f"[ABSTRACT CLASS COMPONENT INITIALIZING SINGLETON COMPONENT] Unable to initialize singleton component: {subclass_component_cls.get_name()} because it has unimplemented abstract methods: {unimplemented_abstract_methods_str}" - logger.error(message) - raise ValueError(message) - logger.debug( - f"[ABSTRACT CLASS COMPONENT INITIALIZING SINGLETON COMPONENT] Init singleton component: {subclass_component_cls.get_name()}" - ) - instance = self.init_singleton_component(subclass_component_cls, subclass_component_cls.get_name()) - if instance is None: - continue - self.singleton_component_instance_container[subclass_component_cls.get_name()] = instance - else: - instance = self.init_singleton_component(component_cls, component_cls.get_name()) - if instance is None: - continue - self.singleton_component_instance_container[component_cls_name] = instance + # Properties management methods + def get_properties(self, properties_cls: Type[PT]) -> Optional[PT]: + """Get a properties instance by class.""" + return self.properties_manager.get_properties(properties_cls) - # for Bean - for ( - bean_collection_cls_name, - bean_collection_cls, - ) in self.bean_collection_cls_container.items(): - logger.debug( - f"[INITIALIZING SINGLETON BEAN] Init singleton bean: {bean_collection_cls_name}" + def register_properties(self, properties_cls: Type[Properties]) -> None: + """Register a properties class in the application context.""" + self.properties_manager.register_properties(properties_cls) + + def load_properties(self) -> None: + """Load all registered properties from configuration files.""" + self.properties_manager.load_properties() + + # Controller management methods + def register_controller(self, controller_cls: Type[RestController]) -> None: + """Register a controller class in the application context.""" + if not issubclass(controller_cls, RestController): + raise TypeError( + f"[CONTROLLER REGISTRATION ERROR] Controller: {controller_cls} " + f"is not a subclass of RestController" ) - collection = bean_collection_cls() - # before injecting_bean_collection deps and scanning beans, make sure properties is loaded in _PropertiesLoader by calling load_properties inside Application class - self._inject_dependencies_for_bean_collection(bean_collection_cls) - bean_views = collection.scan_beans() - for view in bean_views: - if view.bean_name in self.singleton_bean_instance_container: - raise BeanConflictError( - f"[BEAN CONFLICTS] Bean: {view.bean_name} already exists under collection: {collection.get_name()}" - ) - if not view.is_valid_bean(): - raise InvalidBeanError( - f"[INVALID BEAN] Bean name from bean creation func return type: {view.bean_name} does not match the bean object class name: {view.bean.__class__.__name__}" - ) - self.singleton_bean_instance_container[view.bean_name] = view.bean - - def _inject_entity_dependencies(self, entity: Type[AppEntities]) -> None: - for attr_name, annotated_entity_cls in entity.__annotations__.items(): - is_injected: bool = False - # Handle Annotated types - qualifier: Optional[str] = None - if get_origin(annotated_entity_cls) is Annotated: - annotated_entity_cls, qualifier_found = get_args(annotated_entity_cls) - if qualifier_found: - qualifier = qualifier_found - if annotated_entity_cls in self.primitive_types: - logger.warning( - f"[DEPENDENCY INJECTION SKIPPED] Skip inject dependency for attribute: {attr_name} with dependency: {annotated_entity_cls.__name__} because it is primitive type" - ) - continue - if not isclass(annotated_entity_cls): - continue + + controller_cls_name = controller_cls.get_name() + self.container_manager.controller_cls_container[controller_cls_name] = controller_cls - if issubclass(annotated_entity_cls, Properties): - optional_properties = self.get_properties(annotated_entity_cls) - if optional_properties is None: - raise TypeError( - f"[PROPERTIES INJECTION ERROR] Properties: {annotated_entity_cls.get_name()} is not found in properties file for class: {annotated_entity_cls.get_name()} with key: {annotated_entity_cls.get_key()}" - ) - setattr(entity, attr_name, optional_properties) - continue + def get_controller_instances(self) -> list[RestController]: + """Get all controller instances.""" + return [cls() for cls in self.container_manager.controller_cls_container.values()] - entity_getters: list[ - Callable[[Type[AppEntities], Optional[str]], Optional[AppEntities]] - ] = [self.get_component, self.get_bean] + def get_singleton_component_instances(self) -> list[Component]: + """Get all singleton component instances.""" + return list(self.container_manager.singleton_component_instance_container.values()) - for getter in entity_getters: - optional_entity = getter(annotated_entity_cls, qualifier) - if optional_entity is not None: - setattr(entity, attr_name, optional_entity) - is_injected = True - break + def get_singleton_bean_instances(self) -> list[object]: + """Get all singleton bean instances.""" + return list(self.container_manager.singleton_bean_instance_container.values()) - if is_injected: - logger.success( - f"[DEPENDENCY INJECTION SUCCESS FROM COMPONENT CONTAINER] Inject dependency for {annotated_entity_cls.__name__} in attribute: {attr_name} with dependency: {annotated_entity_cls.__name__} singleton instance" - ) - continue - error_message = f"[DEPENDENCY INJECTION FAILED] Fail to inject dependency for attribute: {attr_name} with dependency: {annotated_entity_cls.__name__} with qualifier: {qualifier}, consider register such depency with Compoent decorator" - logger.critical(error_message) - raise ValueError(error_message) + def is_within_context(self, entity_cls: Type[AppEntities]) -> bool: + """Check if an entity class is registered in the application context.""" + return self.container_manager.is_entity_in_container(entity_cls) - def _inject_dependencies_for_bean_collection( - self, bean_collection_cls: Type[BeanCollection] - ) -> None: - logger.info( - f"[BEAN COLLECTION DEPENDENCY INJECTION] Injecting dependencies for {bean_collection_cls.get_name()}" - ) - self._inject_entity_dependencies(bean_collection_cls) + def init_ioc_container(self) -> None: + """ + Initialize the IoC (Inversion of Control) container. + + This method creates singleton instances of all registered components and beans, + ensuring that subsequent calls to get_component() for singleton components + will return the same instance, as required by the Singleton design pattern. + """ + # Initialize singleton components + self.component_manager.init_singleton_components() + + # Initialize singleton beans + self.bean_manager.init_singleton_beans() def inject_dependencies_for_app_entities(self) -> None: + """Inject dependencies for all registered app entities.""" containers: list[Mapping[str, Type[AppEntities]]] = [ - self.component_cls_container, - self.controller_cls_container, + self.container_manager.component_cls_container, + self.container_manager.controller_cls_container, ] for container in containers: - for _cls_name, _cls in container.items(): - self._inject_entity_dependencies(_cls) + for cls_name, cls in container.items(): + self.dependency_injector.inject_dependencies(cls) def _validate_entity_provider_dependencies(self, provider: EntityProvider) -> None: + """Validate dependencies for a single entity provider.""" for dependency in provider.depends_on: if not issubclass(dependency, AppEntities): error = f"[INVALID DEPENDENCY] Invalid dependency {dependency.__name__} in {provider.__class__.__name__}" logger.error(error) raise InvalidDependencyError(error) + if not self.is_within_context(dependency): error = f"[INVALID DEPENDENCY] Dependency {dependency.__name__} not found in the application context" logger.error(error) raise InvalidDependencyError(error) def validate_entity_providers(self) -> None: + """Validate all entity providers in the application context.""" for provider in self.providers: self._validate_entity_provider_dependencies(provider) diff --git a/tests/test_application_context.py b/tests/test_application_context.py index c9e2e19..f306daa 100644 --- a/tests/test_application_context.py +++ b/tests/test_application_context.py @@ -37,21 +37,21 @@ class TestProperties(Properties): app_context.register_properties(TestProperties) assert ( - "TestComponent" in app_context.component_cls_container - and app_context.component_cls_container["TestComponent"] == TestComponent + "TestComponent" in app_context.container_manager.component_cls_container + and app_context.container_manager.component_cls_container["TestComponent"] == TestComponent ) assert ( - "TestController" in app_context.controller_cls_container - and app_context.controller_cls_container["TestController"] == TestController + "TestController" in app_context.container_manager.controller_cls_container + and app_context.container_manager.controller_cls_container["TestController"] == TestController ) assert ( - "TestBeanCollection" in app_context.bean_collection_cls_container - and app_context.bean_collection_cls_container["TestBeanCollection"] + "TestBeanCollection" in app_context.container_manager.bean_collection_cls_container + and app_context.container_manager.bean_collection_cls_container["TestBeanCollection"] == TestBeanCollection ) assert ( - "test_properties" in app_context.properties_cls_container - and app_context.properties_cls_container["test_properties"] + "test_properties" in app_context.container_manager.properties_cls_container + and app_context.container_manager.properties_cls_container["test_properties"] == TestProperties ) @@ -70,10 +70,10 @@ class TestProperties(Properties): app_context.register_bean_collection(TestBeanCollection) app_context.register_properties(TestProperties) - assert "TestComponent" in app_context.component_cls_container - assert "TestController" in app_context.controller_cls_container - assert "TestBeanCollection" in app_context.bean_collection_cls_container - assert "test_properties" in app_context.properties_cls_container + assert "TestComponent" in app_context.container_manager.component_cls_container + assert "TestController" in app_context.container_manager.controller_cls_container + assert "TestBeanCollection" in app_context.container_manager.bean_collection_cls_container + assert "test_properties" in app_context.container_manager.properties_cls_container def test_register_invalid_entities_raises_error( self, app_context: ApplicationContext @@ -128,7 +128,7 @@ class TestProperties(Properties): # Test retrieving singleton components component_instance = TestComponent() - app_context.singleton_component_instance_container["TestComponent"] = ( + app_context.container_manager.singleton_component_instance_container["TestComponent"] = ( component_instance ) retrieved_component = app_context.get_component(TestComponent, None) @@ -136,7 +136,7 @@ class TestProperties(Properties): # Test retrieving singleton beans bean_instance = TestBeanCollection() - app_context.singleton_bean_instance_container["TestBeanCollection"] = ( + app_context.container_manager.singleton_bean_instance_container["TestBeanCollection"] = ( bean_instance ) retrieved_bean = app_context.get_bean(TestBeanCollection, None) @@ -144,7 +144,7 @@ class TestProperties(Properties): # Test retrieving singleton properties properties_instance = TestProperties() - app_context.singleton_properties_instance_container["test_properties"] = ( + app_context.container_manager.singleton_properties_instance_container["test_properties"] = ( properties_instance ) retrieved_properties = app_context.get_properties(TestProperties) diff --git a/tests/test_component_features.py b/tests/test_component_features.py index 3aa76b2..9928e9d 100644 --- a/tests/test_component_features.py +++ b/tests/test_component_features.py @@ -92,11 +92,11 @@ def test_duplicate_component_registration(self, app_context: ApplicationContext) This test verifies that: 1. A component can only be registered once - 2. Attempting to register the same component again raises an error - 3. The error message clearly indicates the duplicate registration + 2. Attempting to register the same component again doesn't raise an error (silent skip) + 3. The component is only registered once in the container The test attempts to register the same component twice and verifies - that an appropriate error is raised. + that it's handled gracefully without errors. """ # Define a component @@ -110,15 +110,15 @@ def process(self) -> str: # Register component first time app_context.register_component(TestService) - app_context.init_ioc_container() - - # Attempt to register same component again should raise error - with pytest.raises( - ValueError, - match="\\[COMPONENT REGISTRATION ERROR\\] Component: TestService already registered", - ): - app_context.register_component(TestService) - app_context.init_ioc_container() + initial_count = len(app_context.container_manager.component_cls_container) + + # Register same component again - should be silently skipped + app_context.register_component(TestService) + final_count = len(app_context.container_manager.component_cls_container) + + # Verify component count didn't change (no duplicate registration) + assert final_count == initial_count + assert "TestService" in app_context.container_manager.component_cls_container def test_component_name_override(self, app_context: ApplicationContext): """ @@ -147,8 +147,8 @@ def process(self) -> str: app_context.init_ioc_container() # Verify component is registered with custom name - assert "CustomServiceName" in app_context.component_cls_container - assert app_context.component_cls_container["CustomServiceName"] == TestService + assert "CustomServiceName" in app_context.container_manager.component_cls_container + assert app_context.container_manager.component_cls_container["CustomServiceName"] == TestService def test_qualifier_with_invalid_component(self, app_context: ApplicationContext): """ From 4ea85cd49157198893681717379d8032eee75b20 Mon Sep 17 00:00:00 2001 From: William Chen <86595028+NFUChen@users.noreply.github.com> Date: Fri, 18 Jul 2025 16:05:14 +0800 Subject: [PATCH 25/42] Add Middleware Support to PySpring Framework (#14) --- py_spring_core/__init__.py | 6 +- .../context/application_context.py | 8 +- .../core/application/py_spring_application.py | 25 +- .../entities/controllers/rest_controller.py | 5 +- .../core/entities/middlewares/middleware.py | 40 ++ .../middlewares/middleware_registry.py | 57 +++ .../interfaces/single_inheritance_required.py | 34 ++ pyproject.toml | 1 + tests/test_middleware.py | 389 ++++++++++++++++++ 9 files changed, 555 insertions(+), 10 deletions(-) create mode 100644 py_spring_core/core/entities/middlewares/middleware.py create mode 100644 py_spring_core/core/entities/middlewares/middleware_registry.py create mode 100644 py_spring_core/core/interfaces/single_inheritance_required.py create mode 100644 tests/test_middleware.py diff --git a/py_spring_core/__init__.py b/py_spring_core/__init__.py index 2340789..5e01580 100644 --- a/py_spring_core/__init__.py +++ b/py_spring_core/__init__.py @@ -10,13 +10,15 @@ PutMapping, ) from py_spring_core.core.entities.entity_provider import EntityProvider +from py_spring_core.core.entities.middlewares.middleware import Middleware +from py_spring_core.core.entities.middlewares.middleware_registry import MiddlewareRegistry from py_spring_core.core.entities.properties.properties import Properties from py_spring_core.core.interfaces.application_context_required import ApplicationContextRequired from py_spring_core.event.application_event_publisher import ApplicationEventPublisher from py_spring_core.event.commons import ApplicationEvent from py_spring_core.event.application_event_handler_registry import EventListener -__version__ = "0.0.18" +__version__ = "0.0.19" __all__ = [ "PySpringApplication", @@ -35,4 +37,6 @@ "ApplicationEventPublisher", "ApplicationEvent", "EventListener", + "Middleware", + "MiddlewareRegistry", ] \ No newline at end of file diff --git a/py_spring_core/core/application/context/application_context.py b/py_spring_core/core/application/context/application_context.py index 729c528..d1e0e50 100644 --- a/py_spring_core/core/application/context/application_context.py +++ b/py_spring_core/core/application/context/application_context.py @@ -4,6 +4,7 @@ from inspect import isclass from typing import ( Annotated, + Any, Callable, Mapping, Optional, @@ -25,6 +26,7 @@ from py_spring_core.core.entities.bean_collection import ( BeanCollection, BeanConflictError, + BeanView, InvalidBeanError, ) from py_spring_core.core.entities.component import Component, ComponentScope @@ -362,7 +364,7 @@ def _inject_bean_collection_dependencies(self, bean_collection_cls: Type[BeanCol ) self.dependency_injector.inject_dependencies(bean_collection_cls) - def _validate_bean_view(self, view, collection_name: str) -> None: + def _validate_bean_view(self, view: BeanView, collection_name: str) -> None: """Validate a bean view before adding it to the container.""" if view.bean_name in self.container_manager.singleton_bean_instance_container: raise BeanConflictError( @@ -564,6 +566,10 @@ def init_ioc_container(self) -> None: # Initialize singleton beans self.bean_manager.init_singleton_beans() + def inject_dependencies_for_external_object(self, object: Type[Any]) -> None: + """Inject dependencies for an external object.""" + self.dependency_injector.inject_dependencies(object) + def inject_dependencies_for_app_entities(self) -> None: """Inject dependencies for all registered app entities.""" containers: list[Mapping[str, Type[AppEntities]]] = [ diff --git a/py_spring_core/core/application/py_spring_application.py b/py_spring_core/core/application/py_spring_application.py index 3ceda7c..391bb57 100644 --- a/py_spring_core/core/application/py_spring_application.py +++ b/py_spring_core/core/application/py_spring_application.py @@ -28,6 +28,7 @@ from py_spring_core.core.entities.controllers.rest_controller import RestController from py_spring_core.core.entities.controllers.route_mapping import RouteMapping from py_spring_core.core.entities.entity_provider import EntityProvider +from py_spring_core.core.entities.middlewares.middleware_registry import MiddlewareRegistry from py_spring_core.core.entities.properties.properties import Properties from py_spring_core.core.interfaces.application_context_required import ApplicationContextRequired from py_spring_core.event.application_event_handler_registry import ApplicationEventHandlerRegistry @@ -212,8 +213,25 @@ def __init_controllers(self) -> None: controller._register_decorated_routes(routes) router = controller.get_router() self.fastapi.include_router(router) - controller.register_middlewares() - + self.__init_middlewares() + logger.debug(f"[CONTROLLER INIT] Controller {name} initialized") + def __init_middlewares(self) -> None: + logger.debug("[MIDDLEWARE INIT] Initialize middlewares...") + self_defined_registry_cls = MiddlewareRegistry.get_subclass() + if self_defined_registry_cls is None: + logger.debug("[MIDDLEWARE INIT] No self defined registry class found") + return + logger.debug(f"[MIDDLEWARE INIT] Self defined registry class: {self_defined_registry_cls.__name__}") + logger.debug(f"[MIDDLEWARE INIT] Inject dependencies for external object: {self_defined_registry_cls.__name__}") + self.app_context.inject_dependencies_for_external_object(self_defined_registry_cls) + registry = self_defined_registry_cls() + + middleware_classes = registry.get_middleware_classes() + for middleware_class in middleware_classes: + logger.debug(f"[MIDDLEWARE INIT] Inject dependencies for middleware: {middleware_class.__name__}") + self.app_context.inject_dependencies_for_external_object(middleware_class) + registry.apply_middlewares(self.fastapi) + logger.debug("[MIDDLEWARE INIT] Middlewares initialized") def __configure_uvicorn_logging(self): """Configure Uvicorn to use Loguru instead of default logging.""" # Configure Uvicorn to use Loguru @@ -239,9 +257,6 @@ def emit(self, record): logging.basicConfig(handlers=[InterceptHandler()], level=log_level, force=True) def __run_server(self) -> None: - - - # Run uvicorn server uvicorn.run( self.fastapi, diff --git a/py_spring_core/core/entities/controllers/rest_controller.py b/py_spring_core/core/entities/controllers/rest_controller.py index 52ce6a5..93e8933 100644 --- a/py_spring_core/core/entities/controllers/rest_controller.py +++ b/py_spring_core/core/entities/controllers/rest_controller.py @@ -3,6 +3,7 @@ from functools import partial from py_spring_core.core.entities.controllers.route_mapping import RouteRegistration +from py_spring_core.core.entities.middlewares.middleware import Middleware class RestController: @@ -15,7 +16,7 @@ class RestController: - Providing access to the FastAPI `APIRouter` and `FastAPI` app instances - Exposing the controller's configuration, including the URL prefix - Subclasses of `RestController` should override the `register_routes` and `register_middlewares` methods to add their own routes and middleware to the controller. + Subclasses of `RestController` should override the `register_routes` methods to add their own routes and middleware to the controller. """ app: FastAPI @@ -53,8 +54,6 @@ def _register_decorated_routes(self, routes: Iterable[RouteRegistration]) -> Non name=route.name, ) - def register_middlewares(self) -> None: ... - def get_router(self) -> APIRouter: return self.router diff --git a/py_spring_core/core/entities/middlewares/middleware.py b/py_spring_core/core/entities/middlewares/middleware.py new file mode 100644 index 0000000..7eab8ac --- /dev/null +++ b/py_spring_core/core/entities/middlewares/middleware.py @@ -0,0 +1,40 @@ +from abc import abstractmethod +from typing import Awaitable, Callable +from fastapi import Request, Response +from starlette.middleware.base import BaseHTTPMiddleware + + + +class Middleware(BaseHTTPMiddleware): + """ + Middleware base class, inherits from FastAPI's BaseHTTPMiddleware + Simpler to use, only need to implement the process_request method + """ + + @abstractmethod + async def process_request(self, request: Request) -> Response | None: + """ + Method to process requests + + Args: + request: FastAPI request object + + Returns: + Response | None: If Response is returned, it will be directly returned to the client + If None is returned, continue to execute the next middleware or route handler + """ + pass + + async def dispatch(self, request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response: + """ + Middleware dispatch method, automatically called by FastAPI + """ + # First execute custom request processing logic + response = await self.process_request(request) + + # If a response is returned, return it directly + if response is not None: + return response + + # Otherwise continue to execute the next middleware or route handler + return await call_next(request) \ No newline at end of file diff --git a/py_spring_core/core/entities/middlewares/middleware_registry.py b/py_spring_core/core/entities/middlewares/middleware_registry.py new file mode 100644 index 0000000..ecff80c --- /dev/null +++ b/py_spring_core/core/entities/middlewares/middleware_registry.py @@ -0,0 +1,57 @@ + + +from abc import ABC, abstractmethod +from typing import Type +from fastapi import FastAPI +from py_spring_core.core.entities.middlewares.middleware import Middleware +from py_spring_core.core.interfaces.single_inheritance_required import SingleInheritanceRequired + + + + +class MiddlewareRegistry(SingleInheritanceRequired["MiddlewareRegistry"], ABC): + """ + Middleware registry for managing all middlewares + + This registry pattern eliminates the need for manual middleware registration. + The framework automatically handles middleware registration and execution order. + + Multiple middleware execution order: + When multiple middlewares are registered through this registry, they are automatically + applied to the FastAPI application in the order they are returned by get_middleware_classes(). + Each middleware wraps the application, forming a stack. The last middleware added is the outermost, + and the first is the innermost. + + On the request path, the outermost middleware runs first. + On the response path, it runs last. + + For example, if get_middleware_classes() returns [MiddlewareA, MiddlewareB] + This results in the following execution order: + Request: MiddlewareB → MiddlewareA → route + Response: route → MiddlewareA → MiddlewareB + This stacking behavior ensures that middlewares are executed in a predictable and controllable order. + """ + + @abstractmethod + def get_middleware_classes(self) -> list[Type[Middleware]]: + """ + Get all registered middleware classes + + Returns: + List[Type[Middleware]]: List of middleware classes + """ + pass + + def apply_middlewares(self, app: FastAPI) -> FastAPI: + """ + Apply middlewares to FastAPI application + + Args: + app: FastAPI application instance + + Returns: + FastAPI: FastAPI instance with applied middlewares + """ + for middleware_class in self.get_middleware_classes(): + app.add_middleware(middleware_class) + return app \ No newline at end of file diff --git a/py_spring_core/core/interfaces/single_inheritance_required.py b/py_spring_core/core/interfaces/single_inheritance_required.py new file mode 100644 index 0000000..03150be --- /dev/null +++ b/py_spring_core/core/interfaces/single_inheritance_required.py @@ -0,0 +1,34 @@ + + +from abc import ABC +from typing import Generic, Optional, Type, TypeVar, cast + +T = TypeVar('T') + +class SingleInheritanceRequired(Generic[T], ABC): + """ + A singleton component is a component that only allow subclasses to be inherited. + """ + + @classmethod + def check_only_one_subclass_allowed(cls) -> None: + """ + Check if the subclass is allowed to be inherited. + """ + class_dict: dict[str, Type[SingleInheritanceRequired[T]]] = {} + for subclass in cls.__subclasses__(): + if subclass.__name__ in class_dict: + continue + class_dict[subclass.__name__] = subclass + if len(class_dict) > 1: + raise ValueError(f"Only one subclass is allowed for {cls.__name__}, but {len(class_dict)} subclasses: {[subclass.__name__ for subclass in class_dict.values()]} found") + + @classmethod + def get_subclass(cls) -> Optional[Type[T]]: + """ + Get the subclass of the component. + """ + cls.check_only_one_subclass_allowed() + if len(cls.__subclasses__()) == 0: + return + return cast(Type[T], cls.__subclasses__()[0]) \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 6506f02..f896924 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -73,6 +73,7 @@ dev = [ "isort>=5.13.2", "pytest>=8.3.2", "pytest-mock>=3.14.0", + "pytest-asyncio>=1.1.0", "types-PyYAML>=6.0.12.20240917", "types-cachetools>=5.5.0.20240820", "mypy>=1.11.2" diff --git a/tests/test_middleware.py b/tests/test_middleware.py new file mode 100644 index 0000000..bb5f368 --- /dev/null +++ b/tests/test_middleware.py @@ -0,0 +1,389 @@ +import pytest +from unittest.mock import Mock, AsyncMock, patch +from fastapi import FastAPI, Request, Response +from fastapi.testclient import TestClient +from starlette.middleware.base import BaseHTTPMiddleware + +from py_spring_core.core.entities.middlewares.middleware import Middleware +from py_spring_core.core.entities.middlewares.middleware_registry import MiddlewareRegistry + + +class TestMiddleware: + """Test suite for the Middleware base class.""" + + @pytest.fixture + def mock_request(self): + """Fixture that provides a mock FastAPI request.""" + request = Mock(spec=Request) + request.method = "GET" + request.url = "http://test.com/api" + return request + + @pytest.fixture + def mock_call_next(self): + """Fixture that provides a mock call_next function.""" + return AsyncMock() + + def test_middleware_inherits_from_base_http_middleware(self): + """ + Test that Middleware class inherits from BaseHTTPMiddleware. + + This test verifies that: + 1. Middleware is a subclass of BaseHTTPMiddleware + 2. The inheritance relationship is correctly established + """ + assert issubclass(Middleware, BaseHTTPMiddleware) + + def test_middleware_is_abstract(self): + """ + Test that Middleware class is abstract and cannot be instantiated directly. + + This test verifies that: + 1. Middleware is an abstract base class + 2. Attempting to instantiate it directly raises an error + """ + # Test that Middleware is abstract by checking it has abstract methods + assert hasattr(Middleware, 'process_request') + assert Middleware.process_request.__isabstractmethod__ + + def test_process_request_is_abstract(self): + """ + Test that process_request method is abstract and must be implemented. + + This test verifies that: + 1. process_request is an abstract method + 2. Subclasses must implement this method + """ + # Create a concrete subclass without implementing process_request + class ConcreteMiddleware(Middleware): + pass + + # Test that the class is abstract by checking it has abstract methods + assert hasattr(ConcreteMiddleware, 'process_request') + # The method should still be abstract since it wasn't implemented + assert ConcreteMiddleware.process_request.__isabstractmethod__ + + @pytest.mark.asyncio + async def test_dispatch_continues_when_process_request_returns_none(self, mock_request, mock_call_next): + """ + Test that dispatch continues to next middleware when process_request returns None. + + This test verifies that: + 1. When process_request returns None, dispatch continues to call_next + 2. The call_next function is called with the correct request + 3. The response from call_next is returned + """ + expected_response = Response(content="test response", status_code=200) + mock_call_next.return_value = expected_response + + class TestMiddleware(Middleware): + async def process_request(self, request: Request) -> Response | None: + return None + + middleware = TestMiddleware(app=Mock()) + result = await middleware.dispatch(mock_request, mock_call_next) + + mock_call_next.assert_called_once_with(mock_request) + assert result == expected_response + + @pytest.mark.asyncio + async def test_dispatch_returns_response_when_process_request_returns_response(self, mock_request, mock_call_next): + """ + Test that dispatch returns response directly when process_request returns a response. + + This test verifies that: + 1. When process_request returns a Response, dispatch returns it directly + 2. call_next is not called when process_request returns a response + 3. The response from process_request is returned unchanged + """ + middleware_response = Response(content="middleware response", status_code=403) + + class TestMiddleware(Middleware): + async def process_request(self, request: Request) -> Response | None: + return middleware_response + + middleware = TestMiddleware(app=Mock()) + result = await middleware.dispatch(mock_request, mock_call_next) + + mock_call_next.assert_not_called() + assert result == middleware_response + + @pytest.mark.asyncio + async def test_dispatch_passes_request_to_process_request(self, mock_request, mock_call_next): + """ + Test that dispatch passes the request to process_request method. + + This test verifies that: + 1. The request object is correctly passed to process_request + 2. The process_request method receives the exact same request object + """ + received_request = None + + class TestMiddleware(Middleware): + async def process_request(self, request: Request) -> Response | None: + nonlocal received_request + received_request = request + return None + + middleware = TestMiddleware(app=Mock()) + await middleware.dispatch(mock_request, mock_call_next) + + assert received_request == mock_request + + +class TestMiddlewareRegistry: + """Test suite for the MiddlewareRegistry abstract class.""" + + @pytest.fixture + def fastapi_app(self): + """Fixture that provides a fresh FastAPI application instance.""" + return FastAPI() + + def test_middleware_registry_is_abstract(self): + """ + Test that MiddlewareRegistry class is abstract and cannot be instantiated directly. + + This test verifies that: + 1. MiddlewareRegistry is an abstract base class + 2. Attempting to instantiate it directly raises an error + """ + # This test verifies that MiddlewareRegistry is abstract + # We can't test direct instantiation because it's abstract + # Instead, we test that it has the abstract method + assert hasattr(MiddlewareRegistry, 'get_middleware_classes') + assert MiddlewareRegistry.get_middleware_classes.__isabstractmethod__ + + def test_get_middleware_classes_is_abstract(self): + """ + Test that get_middleware_classes method is abstract and must be implemented. + + This test verifies that: + 1. get_middleware_classes is an abstract method + 2. Subclasses must implement this method + """ + # Create a concrete subclass without implementing get_middleware_classes + class ConcreteRegistry(MiddlewareRegistry): # type: ignore[abstract] + pass + + with pytest.raises(TypeError): + ConcreteRegistry() # type: ignore[abstract] + + def test_apply_middlewares_adds_middleware_to_app(self, fastapi_app): + """ + Test that apply_middlewares correctly adds middleware classes to FastAPI app. + + This test verifies that: + 1. Middleware classes are added to the FastAPI application + 2. The add_middleware method is called for each middleware class + 3. The app is returned unchanged + """ + class TestMiddleware1(Middleware): + async def process_request(self, request: Request) -> Response | None: + return None + + class TestMiddleware2(Middleware): + async def process_request(self, request: Request) -> Response | None: + return None + + class TestRegistry(MiddlewareRegistry): + def get_middleware_classes(self) -> list[type[Middleware]]: + return [TestMiddleware1, TestMiddleware2] + + # Mock the add_middleware method + with patch.object(fastapi_app, 'add_middleware') as mock_add_middleware: + registry = TestRegistry() + result = registry.apply_middlewares(fastapi_app) + + # Verify add_middleware was called for each middleware class + assert mock_add_middleware.call_count == 2 + mock_add_middleware.assert_any_call(TestMiddleware1) + mock_add_middleware.assert_any_call(TestMiddleware2) + + # Verify the app is returned + assert result == fastapi_app + + def test_apply_middlewares_with_empty_list(self, fastapi_app): + """ + Test that apply_middlewares handles empty middleware list correctly. + + This test verifies that: + 1. When no middlewares are registered, no middleware is added + 2. The app is returned unchanged + 3. No errors occur with empty middleware list + """ + class EmptyRegistry(MiddlewareRegistry): + def get_middleware_classes(self) -> list[type[Middleware]]: + return [] + + with patch.object(fastapi_app, 'add_middleware') as mock_add_middleware: + registry = EmptyRegistry() + result = registry.apply_middlewares(fastapi_app) + + # Verify add_middleware was not called + mock_add_middleware.assert_not_called() + + # Verify the app is returned + assert result == fastapi_app + + def test_apply_middlewares_preserves_app_state(self, fastapi_app): + """ + Test that apply_middlewares preserves the FastAPI app state. + + This test verifies that: + 1. The original app object is returned (same reference) + 2. No app properties are modified during middleware application + """ + class TestMiddleware(Middleware): + async def process_request(self, request: Request) -> Response | None: + return None + + class TestRegistry(MiddlewareRegistry): + def get_middleware_classes(self) -> list[type[Middleware]]: + return [TestMiddleware] + + # Store original app state + original_app_id = id(fastapi_app) + + registry = TestRegistry() + result = registry.apply_middlewares(fastapi_app) + + # Verify same app object is returned + assert id(result) == original_app_id + assert result is fastapi_app + + +class TestMiddlewareIntegration: + """Integration tests for middleware functionality.""" + + @pytest.fixture + def fastapi_app(self): + """Fixture that provides a fresh FastAPI application instance.""" + return FastAPI() + + @pytest.mark.asyncio + async def test_middleware_chain_execution(self, fastapi_app): + """ + Test that multiple middlewares execute in the correct order. + + This test verifies that: + 1. Middlewares are executed in the order they are added + 2. Each middleware can process the request + 3. The chain continues correctly when middlewares return None + """ + execution_order = [] + + class FirstMiddleware(Middleware): + async def process_request(self, request: Request) -> Response | None: + execution_order.append("first") + return None + + class SecondMiddleware(Middleware): + async def process_request(self, request: Request) -> Response | None: + execution_order.append("second") + return None + + class TestRegistry(MiddlewareRegistry): + def get_middleware_classes(self) -> list[type[Middleware]]: + return [FirstMiddleware, SecondMiddleware] + + registry = TestRegistry() + app = registry.apply_middlewares(fastapi_app) + + # Create a test client to trigger middleware execution + from fastapi.testclient import TestClient + + @app.get("/test") + async def test_endpoint(): + return {"message": "test"} + + client = TestClient(app) + response = client.get("/test") + + # Verify middlewares were executed in order (FastAPI uses LIFO - Last In, First Out) + assert execution_order == ["second", "first"] + assert response.status_code == 200 + + @pytest.mark.asyncio + async def test_middleware_early_return(self, fastapi_app): + """ + Test that middleware can return early and prevent further execution. + + This test verifies that: + 1. When a middleware returns a response, subsequent middlewares are not executed + 2. The route handler is not called when middleware returns early + 3. The response from the middleware is returned to the client + """ + execution_order = [] + + class BlockingMiddleware(Middleware): + async def process_request(self, request: Request) -> Response | None: + execution_order.append("blocking") + return Response(content="blocked", status_code=403) + + class SecondMiddleware(Middleware): + async def process_request(self, request: Request) -> Response | None: + execution_order.append("second") + return None + + class TestRegistry(MiddlewareRegistry): + def get_middleware_classes(self) -> list[type[Middleware]]: + return [BlockingMiddleware, SecondMiddleware] + + registry = TestRegistry() + app = registry.apply_middlewares(fastapi_app) + + @app.get("/test") + async def test_endpoint(): + execution_order.append("handler") + return {"message": "test"} + + client = TestClient(app) + response = client.get("/test") + + # Verify only blocking middleware executed (FastAPI uses LIFO - Last In, First Out) + # SecondMiddleware executes first, then BlockingMiddleware returns early + assert execution_order == ["second", "blocking"] + assert response.status_code == 403 + assert response.text == "blocked" + + def test_middleware_registry_single_inheritance(self): + """ + Test that MiddlewareRegistry enforces single inheritance. + + This test verifies that: + 1. MiddlewareRegistry implements SingleInheritanceRequired + 2. Multiple inheritance is prevented + """ + # This test assumes SingleInheritanceRequired prevents multiple inheritance + # The actual behavior depends on the implementation of SingleInheritanceRequired + + class TestRegistry(MiddlewareRegistry): + def get_middleware_classes(self) -> list[type[Middleware]]: + return [] + + # Should be able to create a single inheritance registry + registry = TestRegistry() + assert isinstance(registry, MiddlewareRegistry) + + def test_middleware_type_hints(self): + """ + Test that middleware classes have correct type hints. + + This test verifies that: + 1. get_middleware_classes returns the correct type + 2. process_request has correct parameter and return type hints + """ + class TestMiddleware(Middleware): + async def process_request(self, request: Request) -> Response | None: + return None + + class TestRegistry(MiddlewareRegistry): + def get_middleware_classes(self) -> list[type[Middleware]]: + return [TestMiddleware] + + registry = TestRegistry() + middleware_classes = registry.get_middleware_classes() + + # Verify type hints + assert isinstance(middleware_classes, list) + assert all(issubclass(middleware_class, Middleware) for middleware_class in middleware_classes) \ No newline at end of file From dd318196f61412e66e1c7a75c346d25656d5878a Mon Sep 17 00:00:00 2001 From: William Chen Date: Fri, 18 Jul 2025 16:18:19 +0800 Subject: [PATCH 26/42] [skip ci] Update CI conditions in PyPI deployment workflow to skip builds based on commit messages and pull request titles --- .github/workflows/pypi-deployment.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/pypi-deployment.yaml b/.github/workflows/pypi-deployment.yaml index 8ec2e1f..8b61d92 100644 --- a/.github/workflows/pypi-deployment.yaml +++ b/.github/workflows/pypi-deployment.yaml @@ -8,6 +8,7 @@ on: jobs: publish: runs-on: ubuntu-latest + if: ${{ !contains(github.event.head_commit.message, 'skip ci') && !contains(github.event.pull_request.title, 'skip ci') && (github.ref == 'refs/heads/main' || github.ref_type == 'tag') }} steps: - name: Checkout code From 43068bd1b8abf4fdad862e9e79e477fb17e928db Mon Sep 17 00:00:00 2001 From: William Chen Date: Fri, 18 Jul 2025 16:19:37 +0800 Subject: [PATCH 27/42] [skip ci] ruff & isort formatting --- py_spring_core/__init__.py | 27 +- .../config_file_template_generator.py | 7 +- py_spring_core/core/application/commons.py | 3 +- .../context/application_context.py | 397 +++++++++++------- .../core/application/loguru_config.py | 1 + .../core/application/py_spring_application.py | 94 +++-- .../entities/controllers/rest_controller.py | 6 +- .../core/entities/entity_provider.py | 8 +- .../core/entities/middlewares/middleware.py | 20 +- .../middlewares/middleware_registry.py | 35 +- .../application_context_required.py | 13 +- .../interfaces/single_inheritance_required.py | 13 +- py_spring_core/core/utils.py | 12 +- .../application_event_handler_registry.py | 43 +- .../event/application_event_publisher.py | 14 +- py_spring_core/event/commons.py | 5 +- tests/test_application_context.py | 57 ++- tests/test_bean_collection.py | 3 +- tests/test_component_features.py | 19 +- tests/test_entity_provider.py | 17 +- tests/test_middleware.py | 79 ++-- tests/test_properties_loader.py | 4 +- 22 files changed, 517 insertions(+), 360 deletions(-) diff --git a/py_spring_core/__init__.py b/py_spring_core/__init__.py index 5e01580..903c6f7 100644 --- a/py_spring_core/__init__.py +++ b/py_spring_core/__init__.py @@ -1,22 +1,23 @@ -from py_spring_core.core.application.py_spring_application import PySpringApplication +from py_spring_core.core.application.py_spring_application import \ + PySpringApplication from py_spring_core.core.entities.bean_collection import BeanCollection from py_spring_core.core.entities.component import Component, ComponentScope -from py_spring_core.core.entities.controllers.rest_controller import RestController +from py_spring_core.core.entities.controllers.rest_controller import \ + RestController from py_spring_core.core.entities.controllers.route_mapping import ( - DeleteMapping, - GetMapping, - PatchMapping, - PostMapping, - PutMapping, -) + DeleteMapping, GetMapping, PatchMapping, PostMapping, PutMapping) from py_spring_core.core.entities.entity_provider import EntityProvider from py_spring_core.core.entities.middlewares.middleware import Middleware -from py_spring_core.core.entities.middlewares.middleware_registry import MiddlewareRegistry +from py_spring_core.core.entities.middlewares.middleware_registry import \ + MiddlewareRegistry from py_spring_core.core.entities.properties.properties import Properties -from py_spring_core.core.interfaces.application_context_required import ApplicationContextRequired -from py_spring_core.event.application_event_publisher import ApplicationEventPublisher +from py_spring_core.core.interfaces.application_context_required import \ + ApplicationContextRequired +from py_spring_core.event.application_event_handler_registry import \ + EventListener +from py_spring_core.event.application_event_publisher import \ + ApplicationEventPublisher from py_spring_core.event.commons import ApplicationEvent -from py_spring_core.event.application_event_handler_registry import EventListener __version__ = "0.0.19" @@ -39,4 +40,4 @@ "EventListener", "Middleware", "MiddlewareRegistry", -] \ No newline at end of file +] diff --git a/py_spring_core/commons/config_file_template_generator/config_file_template_generator.py b/py_spring_core/commons/config_file_template_generator/config_file_template_generator.py index 796c1b0..96b85df 100644 --- a/py_spring_core/commons/config_file_template_generator/config_file_template_generator.py +++ b/py_spring_core/commons/config_file_template_generator/config_file_template_generator.py @@ -6,10 +6,9 @@ from pydantic import BaseModel from py_spring_core.commons.config_file_template_generator.templates import ( - app_config_template, - app_properties_template, -) -from py_spring_core.core.application.application_config import ApplicationConfig + app_config_template, app_properties_template) +from py_spring_core.core.application.application_config import \ + ApplicationConfig class ConfigFileTemplateGenerator: diff --git a/py_spring_core/core/application/commons.py b/py_spring_core/core/application/commons.py index eb0b11b..6b4e9a0 100644 --- a/py_spring_core/core/application/commons.py +++ b/py_spring_core/core/application/commons.py @@ -1,6 +1,7 @@ from py_spring_core.core.entities.bean_collection import BeanCollection from py_spring_core.core.entities.component import Component -from py_spring_core.core.entities.controllers.rest_controller import RestController +from py_spring_core.core.entities.controllers.rest_controller import \ + RestController from py_spring_core.core.entities.properties.properties import Properties AppEntities = Component | RestController | BeanCollection | Properties diff --git a/py_spring_core/core/application/context/application_context.py b/py_spring_core/core/application/context/application_context.py index d1e0e50..3d8b043 100644 --- a/py_spring_core/core/application/context/application_context.py +++ b/py_spring_core/core/application/context/application_context.py @@ -1,39 +1,27 @@ -import py_spring_core.core.utils as framework_utils - from abc import ABC from inspect import isclass -from typing import ( - Annotated, - Any, - Callable, - Mapping, - Optional, - Type, - TypeVar, - cast, - get_args, - get_origin, -) +from typing import (Annotated, Any, Callable, Mapping, Optional, Type, TypeVar, + cast, get_args, get_origin) from fastapi import FastAPI from loguru import logger from pydantic import BaseModel +import py_spring_core.core.utils as framework_utils from py_spring_core.core.application.commons import AppEntities -from py_spring_core.core.application.context.application_context_config import ( - ApplicationContextConfig, -) -from py_spring_core.core.entities.bean_collection import ( - BeanCollection, - BeanConflictError, - BeanView, - InvalidBeanError, -) +from py_spring_core.core.application.context.application_context_config import \ + ApplicationContextConfig +from py_spring_core.core.entities.bean_collection import (BeanCollection, + BeanConflictError, + BeanView, + InvalidBeanError) from py_spring_core.core.entities.component import Component, ComponentScope -from py_spring_core.core.entities.controllers.rest_controller import RestController +from py_spring_core.core.entities.controllers.rest_controller import \ + RestController from py_spring_core.core.entities.entity_provider import EntityProvider from py_spring_core.core.entities.properties.properties import Properties -from py_spring_core.core.entities.properties.properties_loader import _PropertiesLoader +from py_spring_core.core.entities.properties.properties_loader import \ + _PropertiesLoader T = TypeVar("T", bound=AppEntities) PT = TypeVar("PT", bound=Properties) @@ -41,16 +29,19 @@ class ComponentNotFoundError(Exception): """Raised when a component is not found in the application context.""" + pass class InvalidDependencyError(Exception): """Raised when a dependency is invalid or not found in the application context.""" + pass class ApplicationContextView(BaseModel): """View model for application context state.""" + config: ApplicationContextConfig component_cls_container: list[str] singleton_component_instance_container: list[str] @@ -58,18 +49,18 @@ class ApplicationContextView(BaseModel): class ContainerManager: """Manages containers for different types of entities in the application context.""" - + def __init__(self): self.component_cls_container: dict[str, Type[Component]] = {} self.controller_cls_container: dict[str, Type[RestController]] = {} self.singleton_component_instance_container: dict[str, Component] = {} - + self.bean_collection_cls_container: dict[str, Type[BeanCollection]] = {} self.singleton_bean_instance_container: dict[str, object] = {} - + self.properties_cls_container: dict[str, Type[Properties]] = {} self.singleton_properties_instance_container: dict[str, Properties] = {} - + def is_entity_in_container(self, entity_cls: Type[AppEntities]) -> bool: """Check if an entity class is registered in any container.""" cls_name = entity_cls.__name__ @@ -83,13 +74,15 @@ def is_entity_in_container(self, entity_cls: Type[AppEntities]) -> bool: class DependencyInjector: """Handles dependency injection for entities in the application context.""" - + def __init__(self, container_manager: ContainerManager): self.container_manager = container_manager self.primitive_types = (bool, str, int, float, type(None)) - self._app_context: Optional['ApplicationContext'] = None - - def _extract_qualifier_from_annotation(self, annotated_type: Type) -> tuple[Type, Optional[str]]: + self._app_context: Optional["ApplicationContext"] = None + + def _extract_qualifier_from_annotation( + self, annotated_type: Type + ) -> tuple[Type, Optional[str]]: """Extract the actual type and qualifier from an Annotated type.""" qualifier = None if get_origin(annotated_type) is Annotated: @@ -98,17 +91,17 @@ def _extract_qualifier_from_annotation(self, annotated_type: Type) -> tuple[Type if len(type_args) > 1: qualifier = type_args[1] return annotated_type, qualifier - + def _inject_properties_dependency( - self, - entity: Type[AppEntities], - attr_name: str, - properties_cls: Type[Properties] + self, + entity: Type[AppEntities], + attr_name: str, + properties_cls: Type[Properties], ) -> bool: """Inject a properties dependency into an entity.""" if self._app_context is None: return False - + optional_properties = self._app_context.get_properties(properties_cls) if optional_properties is None: raise TypeError( @@ -118,23 +111,22 @@ def _inject_properties_dependency( ) setattr(entity, attr_name, optional_properties) return True - + def _try_inject_entity_dependency( - self, - entity: Type[AppEntities], - attr_name: str, - entity_cls: Type[AppEntities], - qualifier: Optional[str] + self, + entity: Type[AppEntities], + attr_name: str, + entity_cls: Type[AppEntities], + qualifier: Optional[str], ) -> bool: """Try to inject an entity dependency using available getters.""" if self._app_context is None: return False - - entity_getters: list[Callable[[Type[AppEntities], Optional[str]], Optional[AppEntities]]] = [ - self._app_context.get_component, - self._app_context.get_bean - ] - + + entity_getters: list[ + Callable[[Type[AppEntities], Optional[str]], Optional[AppEntities]] + ] = [self._app_context.get_component, self._app_context.get_bean] + for getter in entity_getters: optional_entity = getter(entity_cls, qualifier) if optional_entity is not None: @@ -145,12 +137,14 @@ def _try_inject_entity_dependency( ) return True return False - + def inject_dependencies(self, entity: Type[AppEntities]) -> None: """Inject dependencies for a given entity based on its annotations.""" for attr_name, annotated_type in entity.__annotations__.items(): - entity_cls, qualifier = self._extract_qualifier_from_annotation(annotated_type) - + entity_cls, qualifier = self._extract_qualifier_from_annotation( + annotated_type + ) + # Skip primitive types if entity_cls in self.primitive_types: logger.warning( @@ -158,20 +152,22 @@ def inject_dependencies(self, entity: Type[AppEntities]) -> None: f"with dependency: {entity_cls.__name__} because it is primitive type" ) continue - + # Skip non-class types if not isclass(entity_cls): continue - + # Handle Properties injection if issubclass(entity_cls, Properties): if self._inject_properties_dependency(entity, attr_name, entity_cls): continue - + # Try to inject entity dependency - if self._try_inject_entity_dependency(entity, attr_name, entity_cls, qualifier): + if self._try_inject_entity_dependency( + entity, attr_name, entity_cls, qualifier + ): continue - + # If we get here, injection failed error_message = ( f"[DEPENDENCY INJECTION FAILED] Fail to inject dependency for attribute: {attr_name} " @@ -184,47 +180,47 @@ def inject_dependencies(self, entity: Type[AppEntities]) -> None: class ComponentManager: """Manages component registration and instantiation.""" - + def __init__(self, container_manager: ContainerManager): self.container_manager = container_manager - + def _determine_target_cls_name( self, component_cls: Type[T], qualifier: Optional[str] ) -> str: """ Determine the target class name for a given component class. - + Args: component_cls: The component class to determine the name for qualifier: Optional qualifier to use directly - + Returns: The target class name - + Raises: ValueError: If abstract class has no subclasses """ if qualifier is not None: return qualifier - + # If it's not an ABC, return its name directly if not issubclass(component_cls, ABC): return component_cls.get_name() - + # If it's an ABC but has implementations, return its name directly if not component_cls.__abstractmethods__: return component_cls.get_name() - + # For abstract classes that need implementations subclasses = component_cls.__subclasses__() if len(subclasses) == 0: raise ValueError( f"[ABSTRACT CLASS ERROR] Abstract class {component_cls.__name__} has no subclasses" ) - + # Fall back to first subclass if no primary component exists return subclasses[0].get_name() - + def register_component(self, component_cls: Type[Component]) -> None: """Register a component class in the container.""" if not issubclass(component_cls, Component): @@ -232,35 +228,48 @@ def register_component(self, component_cls: Type[Component]) -> None: f"[COMPONENT REGISTRATION ERROR] Component: {component_cls} " f"is not a subclass of Component" ) - + component_cls_name = component_cls.get_name() - existing_component = self.container_manager.component_cls_container.get(component_cls_name) - + existing_component = self.container_manager.component_cls_container.get( + component_cls_name + ) + # Check if it's the same component to avoid duplicate registration - if (existing_component and - existing_component.__name__ == component_cls.__name__ and - existing_component == component_cls): + if ( + existing_component + and existing_component.__name__ == component_cls.__name__ + and existing_component == component_cls + ): return - - self.container_manager.component_cls_container[component_cls_name] = component_cls - - def get_component(self, component_cls: Type[T], qualifier: Optional[str]) -> Optional[T]: + + self.container_manager.component_cls_container[component_cls_name] = ( + component_cls + ) + + def get_component( + self, component_cls: Type[T], qualifier: Optional[str] + ) -> Optional[T]: """Get a component instance by class and optional qualifier.""" if not issubclass(component_cls, (Component, ABC)): return None - + target_cls_name = self._determine_target_cls_name(component_cls, qualifier) - + if target_cls_name not in self.container_manager.component_cls_container: return None - + scope = component_cls.get_scope() match scope: case ComponentScope.Singleton: - return cast(T, self.container_manager.singleton_component_instance_container.get(target_cls_name)) + return cast( + T, + self.container_manager.singleton_component_instance_container.get( + target_cls_name + ), + ) case ComponentScope.Prototype: return cast(T, component_cls()) - + def _init_singleton_component( self, component_cls: Type[Component], component_cls_name: str ) -> Optional[Component]: @@ -279,23 +288,28 @@ def _init_singleton_component( f"{component_cls_name} with error: {error}" ) raise error - - def _get_abstract_class_component_subclasses(self, component_cls: Type[ABC]) -> list[Type[Component]]: + + def _get_abstract_class_component_subclasses( + self, component_cls: Type[ABC] + ) -> list[Type[Component]]: """Get all Component subclasses of an abstract class.""" return [ - subclass for subclass in component_cls.__subclasses__() + subclass + for subclass in component_cls.__subclasses__() if issubclass(subclass, Component) ] - + def _init_abstract_component_subclasses(self, component_cls: Type[ABC]) -> None: """Initialize singleton instances for abstract component subclasses.""" component_classes = self._get_abstract_class_component_subclasses(component_cls) - + for subclass_component_cls in component_classes: self.register_component(subclass_component_cls) - + # Check for unimplemented abstract methods - unimplemented_methods = framework_utils.get_unimplemented_abstract_methods(subclass_component_cls) + unimplemented_methods = framework_utils.get_unimplemented_abstract_methods( + subclass_component_cls + ) if unimplemented_methods: methods_str = ", ".join(unimplemented_methods) message = ( @@ -305,39 +319,56 @@ def _init_abstract_component_subclasses(self, component_cls: Type[ABC]) -> None: ) logger.error(message) raise ValueError(message) - + logger.debug( f"[ABSTRACT CLASS COMPONENT INITIALIZING SINGLETON COMPONENT] " f"Init singleton component: {subclass_component_cls.get_name()}" ) - - instance = self._init_singleton_component(subclass_component_cls, subclass_component_cls.get_name()) + + instance = self._init_singleton_component( + subclass_component_cls, subclass_component_cls.get_name() + ) if instance is not None: - self.container_manager.singleton_component_instance_container[subclass_component_cls.get_name()] = instance - + self.container_manager.singleton_component_instance_container[ + subclass_component_cls.get_name() + ] = instance + def init_singleton_components(self) -> None: """Initialize all singleton components in the container.""" - for component_cls_name, component_cls in self.container_manager.component_cls_container.items(): + for ( + component_cls_name, + component_cls, + ) in self.container_manager.component_cls_container.items(): if component_cls.get_scope() != ComponentScope.Singleton: continue - - logger.debug(f"[INITIALIZING SINGLETON COMPONENT] Init singleton component: {component_cls_name}") - + + logger.debug( + f"[INITIALIZING SINGLETON COMPONENT] Init singleton component: {component_cls_name}" + ) + if issubclass(component_cls, ABC): self._init_abstract_component_subclasses(component_cls) else: - instance = self._init_singleton_component(component_cls, component_cls_name) + instance = self._init_singleton_component( + component_cls, component_cls_name + ) if instance is not None: - self.container_manager.singleton_component_instance_container[component_cls_name] = instance + self.container_manager.singleton_component_instance_container[ + component_cls_name + ] = instance class BeanManager: """Manages bean collection registration and instantiation.""" - - def __init__(self, container_manager: ContainerManager, dependency_injector: DependencyInjector): + + def __init__( + self, + container_manager: ContainerManager, + dependency_injector: DependencyInjector, + ): self.container_manager = container_manager self.dependency_injector = dependency_injector - + def register_bean_collection(self, bean_cls: Type[BeanCollection]) -> None: """Register a bean collection class in the container.""" if not issubclass(bean_cls, BeanCollection): @@ -345,59 +376,74 @@ def register_bean_collection(self, bean_cls: Type[BeanCollection]) -> None: f"[BEAN COLLECTION REGISTRATION ERROR] BeanCollection: {bean_cls} " f"is not a subclass of BeanCollection" ) - + bean_name = bean_cls.get_name() self.container_manager.bean_collection_cls_container[bean_name] = bean_cls - - def get_bean(self, object_cls: Type[T], qualifier: Optional[str] = None) -> Optional[T]: + + def get_bean( + self, object_cls: Type[T], qualifier: Optional[str] = None + ) -> Optional[T]: """Get a bean instance by class and optional qualifier.""" bean_name = object_cls.__name__ if bean_name not in self.container_manager.singleton_bean_instance_container: return None - - return cast(T, self.container_manager.singleton_bean_instance_container.get(bean_name)) - - def _inject_bean_collection_dependencies(self, bean_collection_cls: Type[BeanCollection]) -> None: + + return cast( + T, self.container_manager.singleton_bean_instance_container.get(bean_name) + ) + + def _inject_bean_collection_dependencies( + self, bean_collection_cls: Type[BeanCollection] + ) -> None: """Inject dependencies for a bean collection.""" logger.info( f"[BEAN COLLECTION DEPENDENCY INJECTION] Injecting dependencies for {bean_collection_cls.get_name()}" ) self.dependency_injector.inject_dependencies(bean_collection_cls) - + def _validate_bean_view(self, view: BeanView, collection_name: str) -> None: """Validate a bean view before adding it to the container.""" if view.bean_name in self.container_manager.singleton_bean_instance_container: raise BeanConflictError( f"[BEAN CONFLICTS] Bean: {view.bean_name} already exists under collection: {collection_name}" ) - + if not view.is_valid_bean(): raise InvalidBeanError( f"[INVALID BEAN] Bean name from bean creation func return type: {view.bean_name} " f"does not match the bean object class name: {view.bean.__class__.__name__}" ) - + def init_singleton_beans(self) -> None: """Initialize all singleton beans from registered bean collections.""" - for bean_collection_cls_name, bean_collection_cls in self.container_manager.bean_collection_cls_container.items(): - logger.debug(f"[INITIALIZING SINGLETON BEAN] Init singleton bean: {bean_collection_cls_name}") - + for ( + bean_collection_cls_name, + bean_collection_cls, + ) in self.container_manager.bean_collection_cls_container.items(): + logger.debug( + f"[INITIALIZING SINGLETON BEAN] Init singleton bean: {bean_collection_cls_name}" + ) + collection = bean_collection_cls() self._inject_bean_collection_dependencies(bean_collection_cls) - + bean_views = collection.scan_beans() for view in bean_views: self._validate_bean_view(view, collection.get_name()) - self.container_manager.singleton_bean_instance_container[view.bean_name] = view.bean + self.container_manager.singleton_bean_instance_container[ + view.bean_name + ] = view.bean class PropertiesManager: """Manages properties registration and loading.""" - - def __init__(self, container_manager: ContainerManager, config: ApplicationContextConfig): + + def __init__( + self, container_manager: ContainerManager, config: ApplicationContextConfig + ): self.container_manager = container_manager self.config = config - + def register_properties(self, properties_cls: Type[Properties]) -> None: """Register a properties class in the container.""" if not issubclass(properties_cls, Properties): @@ -405,36 +451,51 @@ def register_properties(self, properties_cls: Type[Properties]) -> None: f"[PROPERTIES REGISTRATION ERROR] Properties: {properties_cls} " f"is not a subclass of Properties" ) - + properties_name = properties_cls.get_key() - self.container_manager.properties_cls_container[properties_name] = properties_cls - + self.container_manager.properties_cls_container[properties_name] = ( + properties_cls + ) + def get_properties(self, properties_cls: Type[PT]) -> Optional[PT]: """Get a properties instance by class.""" properties_cls_name = properties_cls.get_key() if properties_cls_name not in self.container_manager.properties_cls_container: return None - - return cast(PT, self.container_manager.singleton_properties_instance_container.get(properties_cls_name)) - + + return cast( + PT, + self.container_manager.singleton_properties_instance_container.get( + properties_cls_name + ), + ) + def _create_properties_loader(self) -> _PropertiesLoader: """Create a properties loader instance.""" return _PropertiesLoader( - self.config.properties_path, - list(self.container_manager.properties_cls_container.values()) + self.config.properties_path, + list(self.container_manager.properties_cls_container.values()), ) - + def load_properties(self) -> None: """Load all registered properties from configuration files.""" properties_loader = self._create_properties_loader() properties_instance_dict = properties_loader.load_properties() - - for properties_key, properties_cls in self.container_manager.properties_cls_container.items(): - if properties_key in self.container_manager.singleton_properties_instance_container: + + for ( + properties_key, + properties_cls, + ) in self.container_manager.properties_cls_container.items(): + if ( + properties_key + in self.container_manager.singleton_properties_instance_container + ): continue - - logger.debug(f"[INITIALIZING SINGLETON PROPERTIES] Init singleton properties: {properties_key}") - + + logger.debug( + f"[INITIALIZING SINGLETON PROPERTIES] Init singleton properties: {properties_key}" + ) + optional_properties = properties_instance_dict.get(properties_key) if optional_properties is None: raise TypeError( @@ -442,24 +503,28 @@ def load_properties(self) -> None: f"is not found in properties file for class: {properties_cls.get_name()} " f"with key: {properties_cls.get_key()}" ) - - self.container_manager.singleton_properties_instance_container[properties_key] = optional_properties - + + self.container_manager.singleton_properties_instance_container[ + properties_key + ] = optional_properties + # Update the global properties loader reference - _PropertiesLoader.optional_loaded_properties = self.container_manager.singleton_properties_instance_container + _PropertiesLoader.optional_loaded_properties = ( + self.container_manager.singleton_properties_instance_container + ) class ApplicationContext: """ The main entry point for the application's context management. - + This class is responsible for: 1. Registering and managing the lifecycle of components, controllers, bean collections, and properties. 2. Providing methods to retrieve instances of registered components, beans, and properties. 3. Initializing the Inversion of Control (IoC) container by creating singleton instances of registered components. 4. Injecting dependencies for registered components and controllers. - - The ApplicationContext class is designed to follow the Singleton design pattern, ensuring that there is + + The ApplicationContext class is designed to follow the Singleton design pattern, ensuring that there is a single instance of the application context throughout the application's lifetime. """ @@ -468,14 +533,16 @@ def __init__(self, config: ApplicationContextConfig, server: FastAPI) -> None: self.config = config self.all_file_paths: set[str] = set() self.providers: list[EntityProvider] = [] - + # Initialize managers self.container_manager = ContainerManager() self.dependency_injector = DependencyInjector(self.container_manager) self.component_manager = ComponentManager(self.container_manager) - self.bean_manager = BeanManager(self.container_manager, self.dependency_injector) + self.bean_manager = BeanManager( + self.container_manager, self.dependency_injector + ) self.properties_manager = PropertiesManager(self.container_manager, config) - + # Set app context reference for dependency injection self.dependency_injector._app_context = self @@ -487,14 +554,18 @@ def as_view(self) -> ApplicationContextView: """Create a view model of the application context state.""" return ApplicationContextView( config=self.config, - component_cls_container=list(self.container_manager.component_cls_container.keys()), + component_cls_container=list( + self.container_manager.component_cls_container.keys() + ), singleton_component_instance_container=list( self.container_manager.singleton_component_instance_container.keys() ), ) # Component management methods - def get_component(self, component_cls: Type[T], qualifier: Optional[str] = None) -> Optional[T]: + def get_component( + self, component_cls: Type[T], qualifier: Optional[str] = None + ) -> Optional[T]: """Get a component instance by class and optional qualifier.""" return self.component_manager.get_component(component_cls, qualifier) @@ -503,7 +574,9 @@ def register_component(self, component_cls: Type[Component]) -> None: self.component_manager.register_component(component_cls) # Bean management methods - def get_bean(self, object_cls: Type[T], qualifier: Optional[str] = None) -> Optional[T]: + def get_bean( + self, object_cls: Type[T], qualifier: Optional[str] = None + ) -> Optional[T]: """Get a bean instance by class and optional qualifier.""" return self.bean_manager.get_bean(object_cls, qualifier) @@ -532,17 +605,23 @@ def register_controller(self, controller_cls: Type[RestController]) -> None: f"[CONTROLLER REGISTRATION ERROR] Controller: {controller_cls} " f"is not a subclass of RestController" ) - + controller_cls_name = controller_cls.get_name() - self.container_manager.controller_cls_container[controller_cls_name] = controller_cls + self.container_manager.controller_cls_container[controller_cls_name] = ( + controller_cls + ) def get_controller_instances(self) -> list[RestController]: """Get all controller instances.""" - return [cls() for cls in self.container_manager.controller_cls_container.values()] + return [ + cls() for cls in self.container_manager.controller_cls_container.values() + ] def get_singleton_component_instances(self) -> list[Component]: """Get all singleton component instances.""" - return list(self.container_manager.singleton_component_instance_container.values()) + return list( + self.container_manager.singleton_component_instance_container.values() + ) def get_singleton_bean_instances(self) -> list[object]: """Get all singleton bean instances.""" @@ -555,14 +634,14 @@ def is_within_context(self, entity_cls: Type[AppEntities]) -> bool: def init_ioc_container(self) -> None: """ Initialize the IoC (Inversion of Control) container. - + This method creates singleton instances of all registered components and beans, - ensuring that subsequent calls to get_component() for singleton components + ensuring that subsequent calls to get_component() for singleton components will return the same instance, as required by the Singleton design pattern. """ # Initialize singleton components self.component_manager.init_singleton_components() - + # Initialize singleton beans self.bean_manager.init_singleton_beans() @@ -588,7 +667,7 @@ def _validate_entity_provider_dependencies(self, provider: EntityProvider) -> No error = f"[INVALID DEPENDENCY] Invalid dependency {dependency.__name__} in {provider.__class__.__name__}" logger.error(error) raise InvalidDependencyError(error) - + if not self.is_within_context(dependency): error = f"[INVALID DEPENDENCY] Dependency {dependency.__name__} not found in the application context" logger.error(error) diff --git a/py_spring_core/core/application/loguru_config.py b/py_spring_core/core/application/loguru_config.py index d4b16db..3efc541 100644 --- a/py_spring_core/core/application/loguru_config.py +++ b/py_spring_core/core/application/loguru_config.py @@ -18,6 +18,7 @@ class LogFormat(str, Enum): TEXT = "text" JSON = "json" + class LoguruConfig(BaseModel): log_format: str = ( "{time:YYYY-MM-DD HH:mm:ss.SSS} | " diff --git a/py_spring_core/core/application/py_spring_application.py b/py_spring_core/core/application/py_spring_application.py index 391bb57..4f6ba99 100644 --- a/py_spring_core/core/application/py_spring_application.py +++ b/py_spring_core/core/application/py_spring_application.py @@ -7,32 +7,34 @@ from loguru import logger from py_spring_core.commons.class_scanner import ClassScanner -from py_spring_core.commons.config_file_template_generator.config_file_template_generator import ( - ConfigFileTemplateGenerator, -) +from py_spring_core.commons.config_file_template_generator.config_file_template_generator import \ + ConfigFileTemplateGenerator from py_spring_core.commons.file_path_scanner import FilePathScanner from py_spring_core.commons.type_checking_service import TypeCheckingService -from py_spring_core.core.application.application_config import ( - ApplicationConfigRepository, -) +from py_spring_core.core.application.application_config import \ + ApplicationConfigRepository from py_spring_core.core.application.commons import AppEntities -from py_spring_core.core.application.context.application_context import ( - ApplicationContext, -) -from py_spring_core.core.application.context.application_context_config import ( - ApplicationContextConfig, -) +from py_spring_core.core.application.context.application_context import \ + ApplicationContext +from py_spring_core.core.application.context.application_context_config import \ + ApplicationContextConfig from py_spring_core.core.application.loguru_config import LogFormat from py_spring_core.core.entities.bean_collection import BeanCollection -from py_spring_core.core.entities.component import Component, ComponentLifeCycle -from py_spring_core.core.entities.controllers.rest_controller import RestController +from py_spring_core.core.entities.component import (Component, + ComponentLifeCycle) +from py_spring_core.core.entities.controllers.rest_controller import \ + RestController from py_spring_core.core.entities.controllers.route_mapping import RouteMapping from py_spring_core.core.entities.entity_provider import EntityProvider -from py_spring_core.core.entities.middlewares.middleware_registry import MiddlewareRegistry +from py_spring_core.core.entities.middlewares.middleware_registry import \ + MiddlewareRegistry from py_spring_core.core.entities.properties.properties import Properties -from py_spring_core.core.interfaces.application_context_required import ApplicationContextRequired -from py_spring_core.event.application_event_handler_registry import ApplicationEventHandlerRegistry -from py_spring_core.event.application_event_publisher import ApplicationEventPublisher +from py_spring_core.core.interfaces.application_context_required import \ + ApplicationContextRequired +from py_spring_core.event.application_event_handler_registry import \ + ApplicationEventHandlerRegistry +from py_spring_core.event.application_event_publisher import \ + ApplicationEventPublisher class PySpringApplication: @@ -81,8 +83,9 @@ def __init__( properties_path=self.app_config.properties_file_path ) self.fastapi = FastAPI() - self.app_context = ApplicationContext(config=self.app_context_config, server=self.fastapi) - + self.app_context = ApplicationContext( + config=self.app_context_config, server=self.fastapi + ) self.classes_with_handlers: dict[ Type[AppEntities], Callable[[Type[Any]], None] @@ -101,7 +104,7 @@ def __configure_logging(self): config = self.app_config.loguru_config if not config.log_file_path: return - + # Use the format field from config which contains the actual format string logger.add( config.log_file_path, @@ -113,10 +116,7 @@ def __configure_logging(self): self.__configure_uvicorn_logging() def _get_system_managed_classes(self) -> Iterable[Type[Component]]: - return [ - ApplicationEventPublisher, - ApplicationEventHandlerRegistry - ] + return [ApplicationEventPublisher, ApplicationEventHandlerRegistry] def _scan_classes_for_project(self) -> Iterable[Type[object]]: self.app_class_scanner.scan_classes_for_file_paths() @@ -129,7 +129,9 @@ def _register_app_entities(self, classes: Iterable[Type[object]]) -> None: continue handler(_cls) - def _get_all_entities_from_entity_providers(self, entity_providers: Iterable[EntityProvider]) -> Iterable[Type[AppEntities]]: + def _get_all_entities_from_entity_providers( + self, entity_providers: Iterable[EntityProvider] + ) -> Iterable[Type[AppEntities]]: entities: list[Type[AppEntities]] = [] for provider in entity_providers: entities.extend(provider.get_entities()) @@ -165,7 +167,9 @@ def _init_providers(self, providers: Iterable[EntityProvider]) -> None: for provider in providers: provider.provider_init() - def _inject_application_context_to_context_required(self, classes: Iterable[Type[object]]) -> None: + def _inject_application_context_to_context_required( + self, classes: Iterable[Type[object]] + ) -> None: for cls in classes: if not issubclass(cls, ApplicationContextRequired): continue @@ -174,10 +178,17 @@ def _inject_application_context_to_context_required(self, classes: Iterable[Type def _prepare_injected_classes(self) -> Iterable[Type[object]]: scanned_classes = self._scan_classes_for_project() system_managed_classes = self._get_system_managed_classes() - provider_entities = self._get_all_entities_from_entity_providers(self.entity_providers) + provider_entities = self._get_all_entities_from_entity_providers( + self.entity_providers + ) provider_classes = [provider.__class__ for provider in self.entity_providers] # providers typically requires app context, so add to classess to inject - classes_to_inject = [*scanned_classes, *system_managed_classes, *provider_entities, *provider_classes] + classes_to_inject = [ + *scanned_classes, + *system_managed_classes, + *provider_entities, + *provider_classes, + ] return classes_to_inject def __init_app(self) -> None: @@ -208,32 +219,43 @@ def __init_controllers(self) -> None: controllers = self.app_context.get_controller_instances() for controller in controllers: name = controller.__class__.__name__ - routes = RouteMapping.routes.get(name, set()) + routes = RouteMapping.routes.get(name, set()) controller.post_construct() controller._register_decorated_routes(routes) router = controller.get_router() self.fastapi.include_router(router) self.__init_middlewares() logger.debug(f"[CONTROLLER INIT] Controller {name} initialized") + def __init_middlewares(self) -> None: logger.debug("[MIDDLEWARE INIT] Initialize middlewares...") self_defined_registry_cls = MiddlewareRegistry.get_subclass() if self_defined_registry_cls is None: logger.debug("[MIDDLEWARE INIT] No self defined registry class found") return - logger.debug(f"[MIDDLEWARE INIT] Self defined registry class: {self_defined_registry_cls.__name__}") - logger.debug(f"[MIDDLEWARE INIT] Inject dependencies for external object: {self_defined_registry_cls.__name__}") - self.app_context.inject_dependencies_for_external_object(self_defined_registry_cls) + logger.debug( + f"[MIDDLEWARE INIT] Self defined registry class: {self_defined_registry_cls.__name__}" + ) + logger.debug( + f"[MIDDLEWARE INIT] Inject dependencies for external object: {self_defined_registry_cls.__name__}" + ) + self.app_context.inject_dependencies_for_external_object( + self_defined_registry_cls + ) registry = self_defined_registry_cls() middleware_classes = registry.get_middleware_classes() for middleware_class in middleware_classes: - logger.debug(f"[MIDDLEWARE INIT] Inject dependencies for middleware: {middleware_class.__name__}") + logger.debug( + f"[MIDDLEWARE INIT] Inject dependencies for middleware: {middleware_class.__name__}" + ) self.app_context.inject_dependencies_for_external_object(middleware_class) registry.apply_middlewares(self.fastapi) logger.debug("[MIDDLEWARE INIT] Middlewares initialized") + def __configure_uvicorn_logging(self): """Configure Uvicorn to use Loguru instead of default logging.""" + # Configure Uvicorn to use Loguru # Intercept standard logging and redirect to loguru class InterceptHandler(logging.Handler): @@ -250,7 +272,9 @@ def emit(self, record): frame = frame.f_back depth += 1 - logger.opt(depth=depth, exception=record.exc_info).log(level, record.getMessage()) + logger.opt(depth=depth, exception=record.exc_info).log( + level, record.getMessage() + ) # Remove default uvicorn logger and add intercept handler log_level = self.app_config.loguru_config.log_level.value diff --git a/py_spring_core/core/entities/controllers/rest_controller.py b/py_spring_core/core/entities/controllers/rest_controller.py index 93e8933..2ff3a33 100644 --- a/py_spring_core/core/entities/controllers/rest_controller.py +++ b/py_spring_core/core/entities/controllers/rest_controller.py @@ -1,8 +1,10 @@ +from functools import partial from typing import Iterable + from fastapi import APIRouter, FastAPI -from functools import partial -from py_spring_core.core.entities.controllers.route_mapping import RouteRegistration +from py_spring_core.core.entities.controllers.route_mapping import \ + RouteRegistration from py_spring_core.core.entities.middlewares.middleware import Middleware diff --git a/py_spring_core/core/entities/entity_provider.py b/py_spring_core/core/entities/entity_provider.py index 79c21a8..d20c671 100644 --- a/py_spring_core/core/entities/entity_provider.py +++ b/py_spring_core/core/entities/entity_provider.py @@ -4,13 +4,13 @@ from py_spring_core.core.application.commons import AppEntities from py_spring_core.core.entities.bean_collection import BeanCollection from py_spring_core.core.entities.component import Component -from py_spring_core.core.entities.controllers.rest_controller import RestController +from py_spring_core.core.entities.controllers.rest_controller import \ + RestController from py_spring_core.core.entities.properties.properties import Properties try: - from py_spring_core.core.application.context.application_context import ( - ApplicationContext, - ) + from py_spring_core.core.application.context.application_context import \ + ApplicationContext except ImportError: ... diff --git a/py_spring_core/core/entities/middlewares/middleware.py b/py_spring_core/core/entities/middlewares/middleware.py index 7eab8ac..dd89e70 100644 --- a/py_spring_core/core/entities/middlewares/middleware.py +++ b/py_spring_core/core/entities/middlewares/middleware.py @@ -1,40 +1,42 @@ from abc import abstractmethod from typing import Awaitable, Callable + from fastapi import Request, Response from starlette.middleware.base import BaseHTTPMiddleware - class Middleware(BaseHTTPMiddleware): """ Middleware base class, inherits from FastAPI's BaseHTTPMiddleware Simpler to use, only need to implement the process_request method """ - + @abstractmethod async def process_request(self, request: Request) -> Response | None: """ Method to process requests - + Args: request: FastAPI request object - + Returns: Response | None: If Response is returned, it will be directly returned to the client If None is returned, continue to execute the next middleware or route handler """ pass - - async def dispatch(self, request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response: + + async def dispatch( + self, request: Request, call_next: Callable[[Request], Awaitable[Response]] + ) -> Response: """ Middleware dispatch method, automatically called by FastAPI """ # First execute custom request processing logic response = await self.process_request(request) - + # If a response is returned, return it directly if response is not None: return response - + # Otherwise continue to execute the next middleware or route handler - return await call_next(request) \ No newline at end of file + return await call_next(request) diff --git a/py_spring_core/core/entities/middlewares/middleware_registry.py b/py_spring_core/core/entities/middlewares/middleware_registry.py index ecff80c..4829249 100644 --- a/py_spring_core/core/entities/middlewares/middleware_registry.py +++ b/py_spring_core/core/entities/middlewares/middleware_registry.py @@ -1,57 +1,56 @@ - - from abc import ABC, abstractmethod from typing import Type -from fastapi import FastAPI -from py_spring_core.core.entities.middlewares.middleware import Middleware -from py_spring_core.core.interfaces.single_inheritance_required import SingleInheritanceRequired +from fastapi import FastAPI +from py_spring_core.core.entities.middlewares.middleware import Middleware +from py_spring_core.core.interfaces.single_inheritance_required import \ + SingleInheritanceRequired class MiddlewareRegistry(SingleInheritanceRequired["MiddlewareRegistry"], ABC): """ Middleware registry for managing all middlewares - - This registry pattern eliminates the need for manual middleware registration. + + This registry pattern eliminates the need for manual middleware registration. The framework automatically handles middleware registration and execution order. - + Multiple middleware execution order: - When multiple middlewares are registered through this registry, they are automatically + When multiple middlewares are registered through this registry, they are automatically applied to the FastAPI application in the order they are returned by get_middleware_classes(). - Each middleware wraps the application, forming a stack. The last middleware added is the outermost, + Each middleware wraps the application, forming a stack. The last middleware added is the outermost, and the first is the innermost. - + On the request path, the outermost middleware runs first. On the response path, it runs last. - + For example, if get_middleware_classes() returns [MiddlewareA, MiddlewareB] This results in the following execution order: Request: MiddlewareB → MiddlewareA → route Response: route → MiddlewareA → MiddlewareB This stacking behavior ensures that middlewares are executed in a predictable and controllable order. """ - + @abstractmethod def get_middleware_classes(self) -> list[Type[Middleware]]: """ Get all registered middleware classes - + Returns: List[Type[Middleware]]: List of middleware classes """ pass - + def apply_middlewares(self, app: FastAPI) -> FastAPI: """ Apply middlewares to FastAPI application - + Args: app: FastAPI application instance - + Returns: FastAPI: FastAPI instance with applied middlewares """ for middleware_class in self.get_middleware_classes(): app.add_middleware(middleware_class) - return app \ No newline at end of file + return app diff --git a/py_spring_core/core/interfaces/application_context_required.py b/py_spring_core/core/interfaces/application_context_required.py index ea91abd..24ac219 100644 --- a/py_spring_core/core/interfaces/application_context_required.py +++ b/py_spring_core/core/interfaces/application_context_required.py @@ -1,27 +1,28 @@ from typing import Optional -from py_spring_core.core.application.context.application_context import ApplicationContext +from py_spring_core.core.application.context.application_context import \ + ApplicationContext class ApplicationContextRequired: """ A mixin class that provides access to the ApplicationContext for classes that need it. - + This class serves as a base for components that require access to the ApplicationContext. It provides class-level methods to set and retrieve the ApplicationContext instance. - + Usage: class MyComponent(ApplicationContextRequired): def some_method(self): context = self.get_application_context() # Use the context... - + Note: The ApplicationContext must be set before attempting to retrieve it, otherwise a RuntimeError will be raised. """ - _app_context: Optional[ApplicationContext] = None + _app_context: Optional[ApplicationContext] = None @classmethod def set_application_context(cls, application_context: ApplicationContext) -> None: @@ -31,4 +32,4 @@ def set_application_context(cls, application_context: ApplicationContext) -> Non def get_application_context(cls) -> ApplicationContext: if cls._app_context is None: raise RuntimeError("ApplicationContext is not set") - return cls._app_context \ No newline at end of file + return cls._app_context diff --git a/py_spring_core/core/interfaces/single_inheritance_required.py b/py_spring_core/core/interfaces/single_inheritance_required.py index 03150be..211a39b 100644 --- a/py_spring_core/core/interfaces/single_inheritance_required.py +++ b/py_spring_core/core/interfaces/single_inheritance_required.py @@ -1,9 +1,8 @@ - - from abc import ABC from typing import Generic, Optional, Type, TypeVar, cast -T = TypeVar('T') +T = TypeVar("T") + class SingleInheritanceRequired(Generic[T], ABC): """ @@ -21,8 +20,10 @@ def check_only_one_subclass_allowed(cls) -> None: continue class_dict[subclass.__name__] = subclass if len(class_dict) > 1: - raise ValueError(f"Only one subclass is allowed for {cls.__name__}, but {len(class_dict)} subclasses: {[subclass.__name__ for subclass in class_dict.values()]} found") - + raise ValueError( + f"Only one subclass is allowed for {cls.__name__}, but {len(class_dict)} subclasses: {[subclass.__name__ for subclass in class_dict.values()]} found" + ) + @classmethod def get_subclass(cls) -> Optional[Type[T]]: """ @@ -31,4 +32,4 @@ def get_subclass(cls) -> Optional[Type[T]]: cls.check_only_one_subclass_allowed() if len(cls.__subclasses__()) == 0: return - return cast(Type[T], cls.__subclasses__()[0]) \ No newline at end of file + return cast(Type[T], cls.__subclasses__()[0]) diff --git a/py_spring_core/core/utils.py b/py_spring_core/core/utils.py index 6a26683..0f549b7 100644 --- a/py_spring_core/core/utils.py +++ b/py_spring_core/core/utils.py @@ -1,6 +1,6 @@ -from abc import ABC import importlib.util import inspect +from abc import ABC from pathlib import Path from typing import Any, Iterable, Type @@ -83,7 +83,7 @@ def get_unimplemented_abstract_methods(cls: Type[Any]) -> set[str]: Returns a set of abstract method names not implemented in the given class. Args: cls (Type[Any]): A subclass of abc.ABC - + Returns: set[str]: A set of method names that are abstract but not yet implemented """ @@ -95,12 +95,14 @@ def get_unimplemented_abstract_methods(cls: Type[Any]) -> set[str]: abstract_methods: set[str] = set() for base in cls.__mro__: - base_abstracts = getattr(base, '__abstractmethods__', set()) + base_abstracts = getattr(base, "__abstractmethods__", set()) abstract_methods = abstract_methods.union(base_abstracts) implemented_methods: set[str] = { - attr for attr in dir(cls) - if callable(getattr(cls, attr)) and not getattr(getattr(cls, attr), '__isabstractmethod__', False) + attr + for attr in dir(cls) + if callable(getattr(cls, attr)) + and not getattr(getattr(cls, attr), "__isabstractmethod__", False) } return abstract_methods.difference(implemented_methods) diff --git a/py_spring_core/event/application_event_handler_registry.py b/py_spring_core/event/application_event_handler_registry.py index d2c2863..e618434 100644 --- a/py_spring_core/event/application_event_handler_registry.py +++ b/py_spring_core/event/application_event_handler_registry.py @@ -1,4 +1,3 @@ - from threading import Thread from typing import Callable, ClassVar, Type @@ -6,21 +5,25 @@ from pydantic import BaseModel from py_spring_core.core.entities.component import Component -from py_spring_core.core.interfaces.application_context_required import ApplicationContextRequired +from py_spring_core.core.interfaces.application_context_required import \ + ApplicationContextRequired from py_spring_core.event.commons import ApplicationEvent, EventQueue EventHandlerT = Callable[[Component, ApplicationEvent], None] + def EventListener(event_type: Type[ApplicationEvent]) -> Callable: """ The EventListener decorator is used to register an event handler for an application event. It is responsible for binding an event handler to a component and a function. """ + def decorator(func: EventHandlerT) -> None: if not issubclass(event_type, ApplicationEvent): raise ValueError(f"Event type must be a subclass of ApplicationEvent") ApplicationEventHandlerRegistry.register_event_handler(event_type, func) + return decorator @@ -29,6 +32,7 @@ class EventHandler(BaseModel): The EventHandler class is a model that represents an event handler for an application event. It is responsible for binding an event handler to a component and a function. """ + class_name: str func_name: str event_type: Type[ApplicationEvent] @@ -38,7 +42,7 @@ def __eq__(self, other: object) -> bool: if not isinstance(other, EventHandler): return False return self.class_name == other.class_name and self.func_name == other.func_name - + def __hash__(self) -> int: return hash((self.class_name, self.func_name)) @@ -52,7 +56,9 @@ class ApplicationEventHandlerRegistry(Component, ApplicationContextRequired): - Registers event handlers for application events - Binds event handlers to their corresponding components """ + _class_event_handlers: ClassVar[dict[str, list[EventHandler]]] = {} + def __init__(self) -> None: self._event_handlers: dict[str, list[EventHandler]] = {} self._event_message_queue = EventQueue.queue @@ -61,20 +67,21 @@ def post_construct(self) -> None: logger.info("Initializing event handlers...") self._init_event_handlers() logger.info("Starting event message handler thread...") - Thread(target= self._handle_messages, daemon=True).start() + Thread(target=self._handle_messages, daemon=True).start() def _init_event_handlers(self) -> None: app_context = self.get_application_context() # get_name might be different from the class name, so we use the class name for function binding self.component_instance_map = { - component.__class__.__name__: component + component.__class__.__name__: component for component in app_context.get_singleton_component_instances() } self._event_handlers = self._class_event_handlers - @classmethod - def register_event_handler(cls, event_type: Type[ApplicationEvent], handler: EventHandlerT): + def register_event_handler( + cls, event_type: Type[ApplicationEvent], handler: EventHandlerT + ): event_name = event_type.__name__ func_name_parts = handler.__qualname__.split(".") if len(func_name_parts) != 2: @@ -82,26 +89,36 @@ def register_event_handler(cls, event_type: Type[ApplicationEvent], handler: Eve class_name, func_name = func_name_parts if event_name not in cls._class_event_handlers: cls._class_event_handlers[event_name] = [] - event_handler = EventHandler(class_name=class_name, func_name=func_name, event_type=event_type, func=handler) + event_handler = EventHandler( + class_name=class_name, + func_name=func_name, + event_type=event_type, + func=handler, + ) if event_handler not in cls._class_event_handlers[event_name]: cls._class_event_handlers[event_name].append(event_handler) - def get_event_handlers(self, event_type: Type[ApplicationEvent]) -> list[EventHandler]: + def get_event_handlers( + self, event_type: Type[ApplicationEvent] + ) -> list[EventHandler]: event_name = event_type.__name__ handlers = self._event_handlers.get(event_name, []) return handlers - + def _handle_messages(self) -> None: logger.info("Event message handler thread started...") while True: message = self._event_message_queue.get() for handler in self.get_event_handlers(message.__class__): try: - optional_instance = self.component_instance_map.get(handler.class_name, None) + optional_instance = self.component_instance_map.get( + handler.class_name, None + ) if optional_instance is None: - logger.error(f"Component instance not found for handler: {handler.class_name}") + logger.error( + f"Component instance not found for handler: {handler.class_name}" + ) continue handler.func(optional_instance, message) except Exception as error: logger.error(f"Error handling event: {error}") - \ No newline at end of file diff --git a/py_spring_core/event/application_event_publisher.py b/py_spring_core/event/application_event_publisher.py index efc3115..4383d1d 100644 --- a/py_spring_core/event/application_event_publisher.py +++ b/py_spring_core/event/application_event_publisher.py @@ -1,15 +1,13 @@ from typing import TypeVar - from py_spring_core.core.entities.component import Component -from py_spring_core.event.application_event_handler_registry import ApplicationEvent, ApplicationEventHandlerRegistry +from py_spring_core.event.application_event_handler_registry import ( + ApplicationEvent, ApplicationEventHandlerRegistry) from py_spring_core.event.commons import EventQueue T = TypeVar("T", bound=ApplicationEvent) - - class ApplicationEventPublisher(Component): """ The ApplicationEventPublisher is a component that publishes application events. @@ -18,13 +16,9 @@ class ApplicationEventPublisher(Component): The class performs the following key tasks: - Publishes application events to the event message queue """ + def __init__(self): self.event_message_queue = EventQueue.queue - + def publish(self, event: ApplicationEvent) -> None: self.event_message_queue.put(event) - - - - - \ No newline at end of file diff --git a/py_spring_core/event/commons.py b/py_spring_core/event/commons.py index bd76bc4..3ba14a6 100644 --- a/py_spring_core/event/commons.py +++ b/py_spring_core/event/commons.py @@ -2,6 +2,9 @@ from pydantic import BaseModel + class ApplicationEvent(BaseModel): ... + + class EventQueue: - queue: Queue[ApplicationEvent] = Queue() \ No newline at end of file + queue: Queue[ApplicationEvent] = Queue() diff --git a/tests/test_application_context.py b/tests/test_application_context.py index f306daa..cd1a90d 100644 --- a/tests/test_application_context.py +++ b/tests/test_application_context.py @@ -1,13 +1,12 @@ -from fastapi import FastAPI import pytest +from fastapi import FastAPI from py_spring_core.core.application.context.application_context import ( - ApplicationContext, - ApplicationContextConfig, -) + ApplicationContext, ApplicationContextConfig) from py_spring_core.core.entities.bean_collection import BeanCollection from py_spring_core.core.entities.component import Component -from py_spring_core.core.entities.controllers.rest_controller import RestController +from py_spring_core.core.entities.controllers.rest_controller import \ + RestController from py_spring_core.core.entities.properties.properties import Properties @@ -38,20 +37,27 @@ class TestProperties(Properties): assert ( "TestComponent" in app_context.container_manager.component_cls_container - and app_context.container_manager.component_cls_container["TestComponent"] == TestComponent + and app_context.container_manager.component_cls_container["TestComponent"] + == TestComponent ) assert ( "TestController" in app_context.container_manager.controller_cls_container - and app_context.container_manager.controller_cls_container["TestController"] == TestController + and app_context.container_manager.controller_cls_container["TestController"] + == TestController ) assert ( - "TestBeanCollection" in app_context.container_manager.bean_collection_cls_container - and app_context.container_manager.bean_collection_cls_container["TestBeanCollection"] + "TestBeanCollection" + in app_context.container_manager.bean_collection_cls_container + and app_context.container_manager.bean_collection_cls_container[ + "TestBeanCollection" + ] == TestBeanCollection ) assert ( "test_properties" in app_context.container_manager.properties_cls_container - and app_context.container_manager.properties_cls_container["test_properties"] + and app_context.container_manager.properties_cls_container[ + "test_properties" + ] == TestProperties ) @@ -71,9 +77,16 @@ class TestProperties(Properties): app_context.register_properties(TestProperties) assert "TestComponent" in app_context.container_manager.component_cls_container - assert "TestController" in app_context.container_manager.controller_cls_container - assert "TestBeanCollection" in app_context.container_manager.bean_collection_cls_container - assert "test_properties" in app_context.container_manager.properties_cls_container + assert ( + "TestController" in app_context.container_manager.controller_cls_container + ) + assert ( + "TestBeanCollection" + in app_context.container_manager.bean_collection_cls_container + ) + assert ( + "test_properties" in app_context.container_manager.properties_cls_container + ) def test_register_invalid_entities_raises_error( self, app_context: ApplicationContext @@ -128,25 +141,25 @@ class TestProperties(Properties): # Test retrieving singleton components component_instance = TestComponent() - app_context.container_manager.singleton_component_instance_container["TestComponent"] = ( - component_instance - ) + app_context.container_manager.singleton_component_instance_container[ + "TestComponent" + ] = component_instance retrieved_component = app_context.get_component(TestComponent, None) assert retrieved_component is component_instance # Test retrieving singleton beans bean_instance = TestBeanCollection() - app_context.container_manager.singleton_bean_instance_container["TestBeanCollection"] = ( - bean_instance - ) + app_context.container_manager.singleton_bean_instance_container[ + "TestBeanCollection" + ] = bean_instance retrieved_bean = app_context.get_bean(TestBeanCollection, None) assert retrieved_bean is bean_instance # Test retrieving singleton properties properties_instance = TestProperties() - app_context.container_manager.singleton_properties_instance_container["test_properties"] = ( - properties_instance - ) + app_context.container_manager.singleton_properties_instance_container[ + "test_properties" + ] = properties_instance retrieved_properties = app_context.get_properties(TestProperties) assert retrieved_properties is properties_instance diff --git a/tests/test_bean_collection.py b/tests/test_bean_collection.py index 6a35c21..3c73fbe 100644 --- a/tests/test_bean_collection.py +++ b/tests/test_bean_collection.py @@ -1,6 +1,7 @@ import pytest -from py_spring_core.core.entities.bean_collection import BeanCollection, BeanView +from py_spring_core.core.entities.bean_collection import (BeanCollection, + BeanView) from py_spring_core.core.entities.component import Component diff --git a/tests/test_component_features.py b/tests/test_component_features.py index 9928e9d..583110c 100644 --- a/tests/test_component_features.py +++ b/tests/test_component_features.py @@ -1,13 +1,11 @@ from abc import ABC from typing import Annotated -from fastapi import FastAPI import pytest +from fastapi import FastAPI from py_spring_core.core.application.context.application_context import ( - ApplicationContext, - ApplicationContextConfig, -) + ApplicationContext, ApplicationContextConfig) from py_spring_core.core.entities.component import Component, ComponentScope @@ -111,11 +109,11 @@ def process(self) -> str: # Register component first time app_context.register_component(TestService) initial_count = len(app_context.container_manager.component_cls_container) - + # Register same component again - should be silently skipped app_context.register_component(TestService) final_count = len(app_context.container_manager.component_cls_container) - + # Verify component count didn't change (no duplicate registration) assert final_count == initial_count assert "TestService" in app_context.container_manager.component_cls_container @@ -147,8 +145,13 @@ def process(self) -> str: app_context.init_ioc_container() # Verify component is registered with custom name - assert "CustomServiceName" in app_context.container_manager.component_cls_container - assert app_context.container_manager.component_cls_container["CustomServiceName"] == TestService + assert ( + "CustomServiceName" in app_context.container_manager.component_cls_container + ) + assert ( + app_context.container_manager.component_cls_container["CustomServiceName"] + == TestService + ) def test_qualifier_with_invalid_component(self, app_context: ApplicationContext): """ diff --git a/tests/test_entity_provider.py b/tests/test_entity_provider.py index e4a2a58..e632ed4 100644 --- a/tests/test_entity_provider.py +++ b/tests/test_entity_provider.py @@ -1,13 +1,10 @@ -from fastapi import FastAPI import pytest +from fastapi import FastAPI from py_spring_core.core.application.context.application_context import ( - ApplicationContext, - InvalidDependencyError, -) -from py_spring_core.core.application.context.application_context_config import ( - ApplicationContextConfig, -) + ApplicationContext, InvalidDependencyError) +from py_spring_core.core.application.context.application_context_config import \ + ApplicationContextConfig from py_spring_core.core.entities.component import Component from py_spring_core.core.entities.entity_provider import EntityProvider @@ -19,7 +16,7 @@ class TestEntityProvider: @pytest.fixture def test_entity_provider(self): return EntityProvider(depends_on=[TestComponent]) - + @pytest.fixture def server(self) -> FastAPI: return FastAPI() @@ -28,7 +25,9 @@ def server(self) -> FastAPI: def test_app_context( self, test_entity_provider: EntityProvider, server: FastAPI ) -> ApplicationContext: - app_context = ApplicationContext(ApplicationContextConfig(properties_path=""), server=server) + app_context = ApplicationContext( + ApplicationContextConfig(properties_path=""), server=server + ) app_context.providers.append(test_entity_provider) return app_context diff --git a/tests/test_middleware.py b/tests/test_middleware.py index bb5f368..08e6156 100644 --- a/tests/test_middleware.py +++ b/tests/test_middleware.py @@ -1,11 +1,13 @@ +from unittest.mock import AsyncMock, Mock, patch + import pytest -from unittest.mock import Mock, AsyncMock, patch from fastapi import FastAPI, Request, Response from fastapi.testclient import TestClient from starlette.middleware.base import BaseHTTPMiddleware from py_spring_core.core.entities.middlewares.middleware import Middleware -from py_spring_core.core.entities.middlewares.middleware_registry import MiddlewareRegistry +from py_spring_core.core.entities.middlewares.middleware_registry import \ + MiddlewareRegistry class TestMiddleware: @@ -27,7 +29,7 @@ def mock_call_next(self): def test_middleware_inherits_from_base_http_middleware(self): """ Test that Middleware class inherits from BaseHTTPMiddleware. - + This test verifies that: 1. Middleware is a subclass of BaseHTTPMiddleware 2. The inheritance relationship is correctly established @@ -37,37 +39,40 @@ def test_middleware_inherits_from_base_http_middleware(self): def test_middleware_is_abstract(self): """ Test that Middleware class is abstract and cannot be instantiated directly. - + This test verifies that: 1. Middleware is an abstract base class 2. Attempting to instantiate it directly raises an error """ # Test that Middleware is abstract by checking it has abstract methods - assert hasattr(Middleware, 'process_request') + assert hasattr(Middleware, "process_request") assert Middleware.process_request.__isabstractmethod__ def test_process_request_is_abstract(self): """ Test that process_request method is abstract and must be implemented. - + This test verifies that: 1. process_request is an abstract method 2. Subclasses must implement this method """ + # Create a concrete subclass without implementing process_request class ConcreteMiddleware(Middleware): pass # Test that the class is abstract by checking it has abstract methods - assert hasattr(ConcreteMiddleware, 'process_request') + assert hasattr(ConcreteMiddleware, "process_request") # The method should still be abstract since it wasn't implemented assert ConcreteMiddleware.process_request.__isabstractmethod__ @pytest.mark.asyncio - async def test_dispatch_continues_when_process_request_returns_none(self, mock_request, mock_call_next): + async def test_dispatch_continues_when_process_request_returns_none( + self, mock_request, mock_call_next + ): """ Test that dispatch continues to next middleware when process_request returns None. - + This test verifies that: 1. When process_request returns None, dispatch continues to call_next 2. The call_next function is called with the correct request @@ -87,10 +92,12 @@ async def process_request(self, request: Request) -> Response | None: assert result == expected_response @pytest.mark.asyncio - async def test_dispatch_returns_response_when_process_request_returns_response(self, mock_request, mock_call_next): + async def test_dispatch_returns_response_when_process_request_returns_response( + self, mock_request, mock_call_next + ): """ Test that dispatch returns response directly when process_request returns a response. - + This test verifies that: 1. When process_request returns a Response, dispatch returns it directly 2. call_next is not called when process_request returns a response @@ -109,10 +116,12 @@ async def process_request(self, request: Request) -> Response | None: assert result == middleware_response @pytest.mark.asyncio - async def test_dispatch_passes_request_to_process_request(self, mock_request, mock_call_next): + async def test_dispatch_passes_request_to_process_request( + self, mock_request, mock_call_next + ): """ Test that dispatch passes the request to process_request method. - + This test verifies that: 1. The request object is correctly passed to process_request 2. The process_request method receives the exact same request object @@ -142,7 +151,7 @@ def fastapi_app(self): def test_middleware_registry_is_abstract(self): """ Test that MiddlewareRegistry class is abstract and cannot be instantiated directly. - + This test verifies that: 1. MiddlewareRegistry is an abstract base class 2. Attempting to instantiate it directly raises an error @@ -150,17 +159,18 @@ def test_middleware_registry_is_abstract(self): # This test verifies that MiddlewareRegistry is abstract # We can't test direct instantiation because it's abstract # Instead, we test that it has the abstract method - assert hasattr(MiddlewareRegistry, 'get_middleware_classes') + assert hasattr(MiddlewareRegistry, "get_middleware_classes") assert MiddlewareRegistry.get_middleware_classes.__isabstractmethod__ def test_get_middleware_classes_is_abstract(self): """ Test that get_middleware_classes method is abstract and must be implemented. - + This test verifies that: 1. get_middleware_classes is an abstract method 2. Subclasses must implement this method """ + # Create a concrete subclass without implementing get_middleware_classes class ConcreteRegistry(MiddlewareRegistry): # type: ignore[abstract] pass @@ -171,12 +181,13 @@ class ConcreteRegistry(MiddlewareRegistry): # type: ignore[abstract] def test_apply_middlewares_adds_middleware_to_app(self, fastapi_app): """ Test that apply_middlewares correctly adds middleware classes to FastAPI app. - + This test verifies that: 1. Middleware classes are added to the FastAPI application 2. The add_middleware method is called for each middleware class 3. The app is returned unchanged """ + class TestMiddleware1(Middleware): async def process_request(self, request: Request) -> Response | None: return None @@ -190,7 +201,7 @@ def get_middleware_classes(self) -> list[type[Middleware]]: return [TestMiddleware1, TestMiddleware2] # Mock the add_middleware method - with patch.object(fastapi_app, 'add_middleware') as mock_add_middleware: + with patch.object(fastapi_app, "add_middleware") as mock_add_middleware: registry = TestRegistry() result = registry.apply_middlewares(fastapi_app) @@ -198,41 +209,43 @@ def get_middleware_classes(self) -> list[type[Middleware]]: assert mock_add_middleware.call_count == 2 mock_add_middleware.assert_any_call(TestMiddleware1) mock_add_middleware.assert_any_call(TestMiddleware2) - + # Verify the app is returned assert result == fastapi_app def test_apply_middlewares_with_empty_list(self, fastapi_app): """ Test that apply_middlewares handles empty middleware list correctly. - + This test verifies that: 1. When no middlewares are registered, no middleware is added 2. The app is returned unchanged 3. No errors occur with empty middleware list """ + class EmptyRegistry(MiddlewareRegistry): def get_middleware_classes(self) -> list[type[Middleware]]: return [] - with patch.object(fastapi_app, 'add_middleware') as mock_add_middleware: + with patch.object(fastapi_app, "add_middleware") as mock_add_middleware: registry = EmptyRegistry() result = registry.apply_middlewares(fastapi_app) # Verify add_middleware was not called mock_add_middleware.assert_not_called() - + # Verify the app is returned assert result == fastapi_app def test_apply_middlewares_preserves_app_state(self, fastapi_app): """ Test that apply_middlewares preserves the FastAPI app state. - + This test verifies that: 1. The original app object is returned (same reference) 2. No app properties are modified during middleware application """ + class TestMiddleware(Middleware): async def process_request(self, request: Request) -> Response | None: return None @@ -243,7 +256,7 @@ def get_middleware_classes(self) -> list[type[Middleware]]: # Store original app state original_app_id = id(fastapi_app) - + registry = TestRegistry() result = registry.apply_middlewares(fastapi_app) @@ -264,7 +277,7 @@ def fastapi_app(self): async def test_middleware_chain_execution(self, fastapi_app): """ Test that multiple middlewares execute in the correct order. - + This test verifies that: 1. Middlewares are executed in the order they are added 2. Each middleware can process the request @@ -291,7 +304,7 @@ def get_middleware_classes(self) -> list[type[Middleware]]: # Create a test client to trigger middleware execution from fastapi.testclient import TestClient - + @app.get("/test") async def test_endpoint(): return {"message": "test"} @@ -307,7 +320,7 @@ async def test_endpoint(): async def test_middleware_early_return(self, fastapi_app): """ Test that middleware can return early and prevent further execution. - + This test verifies that: 1. When a middleware returns a response, subsequent middlewares are not executed 2. The route handler is not called when middleware returns early @@ -349,14 +362,14 @@ async def test_endpoint(): def test_middleware_registry_single_inheritance(self): """ Test that MiddlewareRegistry enforces single inheritance. - + This test verifies that: 1. MiddlewareRegistry implements SingleInheritanceRequired 2. Multiple inheritance is prevented """ # This test assumes SingleInheritanceRequired prevents multiple inheritance # The actual behavior depends on the implementation of SingleInheritanceRequired - + class TestRegistry(MiddlewareRegistry): def get_middleware_classes(self) -> list[type[Middleware]]: return [] @@ -368,11 +381,12 @@ def get_middleware_classes(self) -> list[type[Middleware]]: def test_middleware_type_hints(self): """ Test that middleware classes have correct type hints. - + This test verifies that: 1. get_middleware_classes returns the correct type 2. process_request has correct parameter and return type hints """ + class TestMiddleware(Middleware): async def process_request(self, request: Request) -> Response | None: return None @@ -386,4 +400,7 @@ def get_middleware_classes(self) -> list[type[Middleware]]: # Verify type hints assert isinstance(middleware_classes, list) - assert all(issubclass(middleware_class, Middleware) for middleware_class in middleware_classes) \ No newline at end of file + assert all( + issubclass(middleware_class, Middleware) + for middleware_class in middleware_classes + ) diff --git a/tests/test_properties_loader.py b/tests/test_properties_loader.py index db0e606..f79c970 100644 --- a/tests/test_properties_loader.py +++ b/tests/test_properties_loader.py @@ -3,9 +3,7 @@ from py_spring_core.core.entities.properties.properties import Properties from py_spring_core.core.entities.properties.properties_loader import ( - InvalidPropertiesKeyError, - _PropertiesLoader, -) + InvalidPropertiesKeyError, _PropertiesLoader) class TestPropertiesLoader: From 2f33a64488cffd734a0802aa155a314c04bf737c Mon Sep 17 00:00:00 2001 From: William Chen Date: Fri, 18 Jul 2025 16:21:15 +0800 Subject: [PATCH 28/42] [skip ci] ruff formatting --- py_spring_core/__init__.py | 29 ++++++------ .../config_file_template_generator.py | 7 +-- py_spring_core/core/application/commons.py | 3 +- .../context/application_context.py | 35 ++++++++++----- .../core/application/py_spring_application.py | 44 ++++++++++--------- .../entities/controllers/rest_controller.py | 3 +- .../core/entities/entity_provider.py | 8 ++-- .../middlewares/middleware_registry.py | 5 ++- .../application_context_required.py | 5 ++- .../application_event_handler_registry.py | 5 ++- .../event/application_event_publisher.py | 4 +- tests/test_application_context.py | 7 +-- tests/test_bean_collection.py | 3 +- tests/test_component_features.py | 4 +- tests/test_entity_provider.py | 9 ++-- tests/test_middleware.py | 5 ++- tests/test_properties_loader.py | 4 +- 17 files changed, 105 insertions(+), 75 deletions(-) diff --git a/py_spring_core/__init__.py b/py_spring_core/__init__.py index 903c6f7..6c7775d 100644 --- a/py_spring_core/__init__.py +++ b/py_spring_core/__init__.py @@ -1,22 +1,25 @@ -from py_spring_core.core.application.py_spring_application import \ - PySpringApplication +from py_spring_core.core.application.py_spring_application import PySpringApplication from py_spring_core.core.entities.bean_collection import BeanCollection from py_spring_core.core.entities.component import Component, ComponentScope -from py_spring_core.core.entities.controllers.rest_controller import \ - RestController +from py_spring_core.core.entities.controllers.rest_controller import RestController from py_spring_core.core.entities.controllers.route_mapping import ( - DeleteMapping, GetMapping, PatchMapping, PostMapping, PutMapping) + DeleteMapping, + GetMapping, + PatchMapping, + PostMapping, + PutMapping, +) from py_spring_core.core.entities.entity_provider import EntityProvider from py_spring_core.core.entities.middlewares.middleware import Middleware -from py_spring_core.core.entities.middlewares.middleware_registry import \ - MiddlewareRegistry +from py_spring_core.core.entities.middlewares.middleware_registry import ( + MiddlewareRegistry, +) from py_spring_core.core.entities.properties.properties import Properties -from py_spring_core.core.interfaces.application_context_required import \ - ApplicationContextRequired -from py_spring_core.event.application_event_handler_registry import \ - EventListener -from py_spring_core.event.application_event_publisher import \ - ApplicationEventPublisher +from py_spring_core.core.interfaces.application_context_required import ( + ApplicationContextRequired, +) +from py_spring_core.event.application_event_handler_registry import EventListener +from py_spring_core.event.application_event_publisher import ApplicationEventPublisher from py_spring_core.event.commons import ApplicationEvent __version__ = "0.0.19" diff --git a/py_spring_core/commons/config_file_template_generator/config_file_template_generator.py b/py_spring_core/commons/config_file_template_generator/config_file_template_generator.py index 96b85df..796c1b0 100644 --- a/py_spring_core/commons/config_file_template_generator/config_file_template_generator.py +++ b/py_spring_core/commons/config_file_template_generator/config_file_template_generator.py @@ -6,9 +6,10 @@ from pydantic import BaseModel from py_spring_core.commons.config_file_template_generator.templates import ( - app_config_template, app_properties_template) -from py_spring_core.core.application.application_config import \ - ApplicationConfig + app_config_template, + app_properties_template, +) +from py_spring_core.core.application.application_config import ApplicationConfig class ConfigFileTemplateGenerator: diff --git a/py_spring_core/core/application/commons.py b/py_spring_core/core/application/commons.py index 6b4e9a0..eb0b11b 100644 --- a/py_spring_core/core/application/commons.py +++ b/py_spring_core/core/application/commons.py @@ -1,7 +1,6 @@ from py_spring_core.core.entities.bean_collection import BeanCollection from py_spring_core.core.entities.component import Component -from py_spring_core.core.entities.controllers.rest_controller import \ - RestController +from py_spring_core.core.entities.controllers.rest_controller import RestController from py_spring_core.core.entities.properties.properties import Properties AppEntities = Component | RestController | BeanCollection | Properties diff --git a/py_spring_core/core/application/context/application_context.py b/py_spring_core/core/application/context/application_context.py index 3d8b043..a51d541 100644 --- a/py_spring_core/core/application/context/application_context.py +++ b/py_spring_core/core/application/context/application_context.py @@ -1,7 +1,17 @@ from abc import ABC from inspect import isclass -from typing import (Annotated, Any, Callable, Mapping, Optional, Type, TypeVar, - cast, get_args, get_origin) +from typing import ( + Annotated, + Any, + Callable, + Mapping, + Optional, + Type, + TypeVar, + cast, + get_args, + get_origin, +) from fastapi import FastAPI from loguru import logger @@ -9,19 +19,20 @@ import py_spring_core.core.utils as framework_utils from py_spring_core.core.application.commons import AppEntities -from py_spring_core.core.application.context.application_context_config import \ - ApplicationContextConfig -from py_spring_core.core.entities.bean_collection import (BeanCollection, - BeanConflictError, - BeanView, - InvalidBeanError) +from py_spring_core.core.application.context.application_context_config import ( + ApplicationContextConfig, +) +from py_spring_core.core.entities.bean_collection import ( + BeanCollection, + BeanConflictError, + BeanView, + InvalidBeanError, +) from py_spring_core.core.entities.component import Component, ComponentScope -from py_spring_core.core.entities.controllers.rest_controller import \ - RestController +from py_spring_core.core.entities.controllers.rest_controller import RestController from py_spring_core.core.entities.entity_provider import EntityProvider from py_spring_core.core.entities.properties.properties import Properties -from py_spring_core.core.entities.properties.properties_loader import \ - _PropertiesLoader +from py_spring_core.core.entities.properties.properties_loader import _PropertiesLoader T = TypeVar("T", bound=AppEntities) PT = TypeVar("PT", bound=Properties) diff --git a/py_spring_core/core/application/py_spring_application.py b/py_spring_core/core/application/py_spring_application.py index 4f6ba99..e3f8c42 100644 --- a/py_spring_core/core/application/py_spring_application.py +++ b/py_spring_core/core/application/py_spring_application.py @@ -7,34 +7,38 @@ from loguru import logger from py_spring_core.commons.class_scanner import ClassScanner -from py_spring_core.commons.config_file_template_generator.config_file_template_generator import \ - ConfigFileTemplateGenerator +from py_spring_core.commons.config_file_template_generator.config_file_template_generator import ( + ConfigFileTemplateGenerator, +) from py_spring_core.commons.file_path_scanner import FilePathScanner from py_spring_core.commons.type_checking_service import TypeCheckingService -from py_spring_core.core.application.application_config import \ - ApplicationConfigRepository +from py_spring_core.core.application.application_config import ( + ApplicationConfigRepository, +) from py_spring_core.core.application.commons import AppEntities -from py_spring_core.core.application.context.application_context import \ - ApplicationContext -from py_spring_core.core.application.context.application_context_config import \ - ApplicationContextConfig +from py_spring_core.core.application.context.application_context import ( + ApplicationContext, +) +from py_spring_core.core.application.context.application_context_config import ( + ApplicationContextConfig, +) from py_spring_core.core.application.loguru_config import LogFormat from py_spring_core.core.entities.bean_collection import BeanCollection -from py_spring_core.core.entities.component import (Component, - ComponentLifeCycle) -from py_spring_core.core.entities.controllers.rest_controller import \ - RestController +from py_spring_core.core.entities.component import Component, ComponentLifeCycle +from py_spring_core.core.entities.controllers.rest_controller import RestController from py_spring_core.core.entities.controllers.route_mapping import RouteMapping from py_spring_core.core.entities.entity_provider import EntityProvider -from py_spring_core.core.entities.middlewares.middleware_registry import \ - MiddlewareRegistry +from py_spring_core.core.entities.middlewares.middleware_registry import ( + MiddlewareRegistry, +) from py_spring_core.core.entities.properties.properties import Properties -from py_spring_core.core.interfaces.application_context_required import \ - ApplicationContextRequired -from py_spring_core.event.application_event_handler_registry import \ - ApplicationEventHandlerRegistry -from py_spring_core.event.application_event_publisher import \ - ApplicationEventPublisher +from py_spring_core.core.interfaces.application_context_required import ( + ApplicationContextRequired, +) +from py_spring_core.event.application_event_handler_registry import ( + ApplicationEventHandlerRegistry, +) +from py_spring_core.event.application_event_publisher import ApplicationEventPublisher class PySpringApplication: diff --git a/py_spring_core/core/entities/controllers/rest_controller.py b/py_spring_core/core/entities/controllers/rest_controller.py index 2ff3a33..93f6c6f 100644 --- a/py_spring_core/core/entities/controllers/rest_controller.py +++ b/py_spring_core/core/entities/controllers/rest_controller.py @@ -3,8 +3,7 @@ from fastapi import APIRouter, FastAPI -from py_spring_core.core.entities.controllers.route_mapping import \ - RouteRegistration +from py_spring_core.core.entities.controllers.route_mapping import RouteRegistration from py_spring_core.core.entities.middlewares.middleware import Middleware diff --git a/py_spring_core/core/entities/entity_provider.py b/py_spring_core/core/entities/entity_provider.py index d20c671..79c21a8 100644 --- a/py_spring_core/core/entities/entity_provider.py +++ b/py_spring_core/core/entities/entity_provider.py @@ -4,13 +4,13 @@ from py_spring_core.core.application.commons import AppEntities from py_spring_core.core.entities.bean_collection import BeanCollection from py_spring_core.core.entities.component import Component -from py_spring_core.core.entities.controllers.rest_controller import \ - RestController +from py_spring_core.core.entities.controllers.rest_controller import RestController from py_spring_core.core.entities.properties.properties import Properties try: - from py_spring_core.core.application.context.application_context import \ - ApplicationContext + from py_spring_core.core.application.context.application_context import ( + ApplicationContext, + ) except ImportError: ... diff --git a/py_spring_core/core/entities/middlewares/middleware_registry.py b/py_spring_core/core/entities/middlewares/middleware_registry.py index 4829249..348ba00 100644 --- a/py_spring_core/core/entities/middlewares/middleware_registry.py +++ b/py_spring_core/core/entities/middlewares/middleware_registry.py @@ -4,8 +4,9 @@ from fastapi import FastAPI from py_spring_core.core.entities.middlewares.middleware import Middleware -from py_spring_core.core.interfaces.single_inheritance_required import \ - SingleInheritanceRequired +from py_spring_core.core.interfaces.single_inheritance_required import ( + SingleInheritanceRequired, +) class MiddlewareRegistry(SingleInheritanceRequired["MiddlewareRegistry"], ABC): diff --git a/py_spring_core/core/interfaces/application_context_required.py b/py_spring_core/core/interfaces/application_context_required.py index 24ac219..83d7729 100644 --- a/py_spring_core/core/interfaces/application_context_required.py +++ b/py_spring_core/core/interfaces/application_context_required.py @@ -1,7 +1,8 @@ from typing import Optional -from py_spring_core.core.application.context.application_context import \ - ApplicationContext +from py_spring_core.core.application.context.application_context import ( + ApplicationContext, +) class ApplicationContextRequired: diff --git a/py_spring_core/event/application_event_handler_registry.py b/py_spring_core/event/application_event_handler_registry.py index e618434..7d5d0fb 100644 --- a/py_spring_core/event/application_event_handler_registry.py +++ b/py_spring_core/event/application_event_handler_registry.py @@ -5,8 +5,9 @@ from pydantic import BaseModel from py_spring_core.core.entities.component import Component -from py_spring_core.core.interfaces.application_context_required import \ - ApplicationContextRequired +from py_spring_core.core.interfaces.application_context_required import ( + ApplicationContextRequired, +) from py_spring_core.event.commons import ApplicationEvent, EventQueue EventHandlerT = Callable[[Component, ApplicationEvent], None] diff --git a/py_spring_core/event/application_event_publisher.py b/py_spring_core/event/application_event_publisher.py index 4383d1d..3012d75 100644 --- a/py_spring_core/event/application_event_publisher.py +++ b/py_spring_core/event/application_event_publisher.py @@ -2,7 +2,9 @@ from py_spring_core.core.entities.component import Component from py_spring_core.event.application_event_handler_registry import ( - ApplicationEvent, ApplicationEventHandlerRegistry) + ApplicationEvent, + ApplicationEventHandlerRegistry, +) from py_spring_core.event.commons import EventQueue T = TypeVar("T", bound=ApplicationEvent) diff --git a/tests/test_application_context.py b/tests/test_application_context.py index cd1a90d..52eaa7e 100644 --- a/tests/test_application_context.py +++ b/tests/test_application_context.py @@ -2,11 +2,12 @@ from fastapi import FastAPI from py_spring_core.core.application.context.application_context import ( - ApplicationContext, ApplicationContextConfig) + ApplicationContext, + ApplicationContextConfig, +) from py_spring_core.core.entities.bean_collection import BeanCollection from py_spring_core.core.entities.component import Component -from py_spring_core.core.entities.controllers.rest_controller import \ - RestController +from py_spring_core.core.entities.controllers.rest_controller import RestController from py_spring_core.core.entities.properties.properties import Properties diff --git a/tests/test_bean_collection.py b/tests/test_bean_collection.py index 3c73fbe..6a35c21 100644 --- a/tests/test_bean_collection.py +++ b/tests/test_bean_collection.py @@ -1,7 +1,6 @@ import pytest -from py_spring_core.core.entities.bean_collection import (BeanCollection, - BeanView) +from py_spring_core.core.entities.bean_collection import BeanCollection, BeanView from py_spring_core.core.entities.component import Component diff --git a/tests/test_component_features.py b/tests/test_component_features.py index 583110c..a696176 100644 --- a/tests/test_component_features.py +++ b/tests/test_component_features.py @@ -5,7 +5,9 @@ from fastapi import FastAPI from py_spring_core.core.application.context.application_context import ( - ApplicationContext, ApplicationContextConfig) + ApplicationContext, + ApplicationContextConfig, +) from py_spring_core.core.entities.component import Component, ComponentScope diff --git a/tests/test_entity_provider.py b/tests/test_entity_provider.py index e632ed4..bda16b8 100644 --- a/tests/test_entity_provider.py +++ b/tests/test_entity_provider.py @@ -2,9 +2,12 @@ from fastapi import FastAPI from py_spring_core.core.application.context.application_context import ( - ApplicationContext, InvalidDependencyError) -from py_spring_core.core.application.context.application_context_config import \ - ApplicationContextConfig + ApplicationContext, + InvalidDependencyError, +) +from py_spring_core.core.application.context.application_context_config import ( + ApplicationContextConfig, +) from py_spring_core.core.entities.component import Component from py_spring_core.core.entities.entity_provider import EntityProvider diff --git a/tests/test_middleware.py b/tests/test_middleware.py index 08e6156..1d71808 100644 --- a/tests/test_middleware.py +++ b/tests/test_middleware.py @@ -6,8 +6,9 @@ from starlette.middleware.base import BaseHTTPMiddleware from py_spring_core.core.entities.middlewares.middleware import Middleware -from py_spring_core.core.entities.middlewares.middleware_registry import \ - MiddlewareRegistry +from py_spring_core.core.entities.middlewares.middleware_registry import ( + MiddlewareRegistry, +) class TestMiddleware: diff --git a/tests/test_properties_loader.py b/tests/test_properties_loader.py index f79c970..db0e606 100644 --- a/tests/test_properties_loader.py +++ b/tests/test_properties_loader.py @@ -3,7 +3,9 @@ from py_spring_core.core.entities.properties.properties import Properties from py_spring_core.core.entities.properties.properties_loader import ( - InvalidPropertiesKeyError, _PropertiesLoader) + InvalidPropertiesKeyError, + _PropertiesLoader, +) class TestPropertiesLoader: From 3a3be87811e8731cad75d311ab89ff87ba738f90 Mon Sep 17 00:00:00 2001 From: William Chen <86595028+NFUChen@users.noreply.github.com> Date: Fri, 18 Jul 2025 17:07:07 +0800 Subject: [PATCH 29/42] Add should_skip method to Middleware class for conditional request processing (#15) --- .../core/application/py_spring_application.py | 2 +- .../core/entities/middlewares/middleware.py | 16 + tests/test_middleware.py | 227 ++++++++++++++ tests/test_middleware_single_execution.py | 291 ++++++++++++++++++ 4 files changed, 535 insertions(+), 1 deletion(-) create mode 100644 tests/test_middleware_single_execution.py diff --git a/py_spring_core/core/application/py_spring_application.py b/py_spring_core/core/application/py_spring_application.py index e3f8c42..f02adf3 100644 --- a/py_spring_core/core/application/py_spring_application.py +++ b/py_spring_core/core/application/py_spring_application.py @@ -228,7 +228,6 @@ def __init_controllers(self) -> None: controller._register_decorated_routes(routes) router = controller.get_router() self.fastapi.include_router(router) - self.__init_middlewares() logger.debug(f"[CONTROLLER INIT] Controller {name} initialized") def __init_middlewares(self) -> None: @@ -298,6 +297,7 @@ def run(self) -> None: self.__configure_logging() self.__init_app() self.__init_controllers() + self.__init_middlewares() if self.app_config.server_config.enabled: self.__run_server() finally: diff --git a/py_spring_core/core/entities/middlewares/middleware.py b/py_spring_core/core/entities/middlewares/middleware.py index dd89e70..d861309 100644 --- a/py_spring_core/core/entities/middlewares/middleware.py +++ b/py_spring_core/core/entities/middlewares/middleware.py @@ -11,6 +11,18 @@ class Middleware(BaseHTTPMiddleware): Simpler to use, only need to implement the process_request method """ + def should_skip(self, request: Request) -> bool: + """ + Method to determine if the middleware should be skipped + + Args: + request: FastAPI request object + + Returns: + bool: True if the middleware should be skipped, False otherwise, default is False + """ + return False + @abstractmethod async def process_request(self, request: Request) -> Response | None: """ @@ -30,7 +42,11 @@ async def dispatch( ) -> Response: """ Middleware dispatch method, automatically called by FastAPI + If should_skip returns True, the middleware will be skipped """ + if self.should_skip(request): + return await call_next(request) + # First execute custom request processing logic response = await self.process_request(request) diff --git a/tests/test_middleware.py b/tests/test_middleware.py index 1d71808..e692bcc 100644 --- a/tests/test_middleware.py +++ b/tests/test_middleware.py @@ -140,6 +140,233 @@ async def process_request(self, request: Request) -> Response | None: assert received_request == mock_request + def test_should_skip_default_returns_false(self, mock_request): + """ + Test that should_skip method returns False by default. + + This test verifies that: + 1. The default implementation of should_skip returns False + 2. This allows the middleware to process all requests by default + """ + middleware = Middleware(app=Mock()) + result = middleware.should_skip(mock_request) + assert result is False + + def test_should_skip_can_be_overridden(self, mock_request): + """ + Test that should_skip method can be overridden in subclasses. + + This test verifies that: + 1. Subclasses can override should_skip to provide custom skip logic + 2. The overridden method is called with the correct request parameter + """ + class SkippingMiddleware(Middleware): + def should_skip(self, request: Request) -> bool: + return request.method == "GET" + + middleware = SkippingMiddleware(app=Mock()) + result = middleware.should_skip(mock_request) + assert result is True + + @pytest.mark.asyncio + async def test_dispatch_skips_middleware_when_should_skip_returns_true( + self, mock_request, mock_call_next + ): + """ + Test that dispatch skips middleware processing when should_skip returns True. + + This test verifies that: + 1. When should_skip returns True, process_request is not called + 2. The request is passed directly to call_next + 3. The response from call_next is returned + """ + expected_response = Response(content="skipped response", status_code=200) + mock_call_next.return_value = expected_response + + class SkippingMiddleware(Middleware): + def should_skip(self, request: Request) -> bool: + return True + + async def process_request(self, request: Request) -> Response | None: + # This should never be called when should_skip returns True + raise AssertionError("process_request should not be called") + + middleware = SkippingMiddleware(app=Mock()) + result = await middleware.dispatch(mock_request, mock_call_next) + + # Verify call_next was called with the request + mock_call_next.assert_called_once_with(mock_request) + # Verify the response from call_next is returned + assert result == expected_response + + @pytest.mark.asyncio + async def test_dispatch_processes_middleware_when_should_skip_returns_false( + self, mock_request, mock_call_next + ): + """ + Test that dispatch processes middleware when should_skip returns False. + + This test verifies that: + 1. When should_skip returns False, process_request is called + 2. The middleware processing logic is executed + 3. The normal dispatch flow continues + """ + expected_response = Response(content="processed response", status_code=200) + mock_call_next.return_value = expected_response + + process_request_called = False + + class ProcessingMiddleware(Middleware): + def should_skip(self, request: Request) -> bool: + return False + + async def process_request(self, request: Request) -> Response | None: + nonlocal process_request_called + process_request_called = True + return None + + middleware = ProcessingMiddleware(app=Mock()) + result = await middleware.dispatch(mock_request, mock_call_next) + + # Verify process_request was called + assert process_request_called is True + # Verify call_next was called + mock_call_next.assert_called_once_with(mock_request) + # Verify the response from call_next is returned + assert result == expected_response + + def test_should_skip_receives_correct_request_parameter(self, mock_request): + """ + Test that should_skip method receives the correct request parameter. + + This test verifies that: + 1. The should_skip method receives the exact same request object + 2. The request parameter is passed correctly + """ + received_request = None + + class TestMiddleware(Middleware): + def should_skip(self, request: Request) -> bool: + nonlocal received_request + received_request = request + return False + + middleware = TestMiddleware(app=Mock()) + middleware.should_skip(mock_request) + + assert received_request == mock_request + + @pytest.mark.asyncio + async def test_dispatch_with_conditional_skip_logic(self, mock_request, mock_call_next): + """ + Test dispatch with conditional skip logic based on request properties. + + This test verifies that: + 1. should_skip can use request properties to make skip decisions + 2. The skip logic works correctly in the dispatch flow + 3. Both skip and process paths work as expected + """ + expected_response = Response(content="test response", status_code=200) + mock_call_next.return_value = expected_response + + class ConditionalMiddleware(Middleware): + def should_skip(self, request: Request) -> bool: + # Skip GET requests, process others + return request.method == "GET" + + async def process_request(self, request: Request) -> Response | None: + # This should only be called for non-GET requests + return Response(content="processed", status_code=202) + + middleware = ConditionalMiddleware(app=Mock()) + + # Test with GET request (should skip) + mock_request.method = "GET" + result = await middleware.dispatch(mock_request, mock_call_next) + + assert result == expected_response + mock_call_next.assert_called_once_with(mock_request) + + # Reset mock for next test + mock_call_next.reset_mock() + mock_call_next.return_value = expected_response + + # Test with POST request (should process) + mock_request.method = "POST" + result = await middleware.dispatch(mock_request, mock_call_next) + + assert result.body == b"processed" + assert result.status_code == 202 + mock_call_next.assert_not_called() + + @pytest.mark.asyncio + async def test_dispatch_with_url_based_skip_logic(self, mock_request, mock_call_next): + """ + Test dispatch with URL-based skip logic. + + This test verifies that: + 1. should_skip can use request URL to make skip decisions + 2. URL-based filtering works correctly + 3. The middleware processes only relevant requests + """ + expected_response = Response(content="test response", status_code=200) + mock_call_next.return_value = expected_response + + class URLBasedMiddleware(Middleware): + def should_skip(self, request: Request) -> bool: + # Skip requests to /health endpoint + return str(request.url).endswith("/health") + + async def process_request(self, request: Request) -> Response | None: + return Response(content="processed", status_code=202) + + middleware = URLBasedMiddleware(app=Mock()) + + # Test with health endpoint (should skip) + mock_request.url = "http://test.com/health" + result = await middleware.dispatch(mock_request, mock_call_next) + + assert result == expected_response + mock_call_next.assert_called_once_with(mock_request) + + # Reset mock for next test + mock_call_next.reset_mock() + mock_call_next.return_value = expected_response + + # Test with other endpoint (should process) + mock_request.url = "http://test.com/api/users" + result = await middleware.dispatch(mock_request, mock_call_next) + + assert result.body == b"processed" + assert result.status_code == 202 + mock_call_next.assert_not_called() + + def test_should_skip_method_signature(self): + """ + Test that should_skip method has the correct signature. + + This test verifies that: + 1. should_skip is an instance method + 2. It takes a Request parameter + 3. It returns a boolean value + """ + middleware = Middleware(app=Mock()) + + # Check that should_skip is a method + assert hasattr(middleware, 'should_skip') + assert callable(middleware.should_skip) + + # Check that it's an instance method (not a class method or static method) + import inspect + sig = inspect.signature(middleware.should_skip) + params = list(sig.parameters.keys()) + + # Should have 'request' parameter (self is automatically handled by Python) + assert params == ['request'] + + # Check return type annotation + assert sig.return_annotation == bool + class TestMiddlewareRegistry: """Test suite for the MiddlewareRegistry abstract class.""" diff --git a/tests/test_middleware_single_execution.py b/tests/test_middleware_single_execution.py new file mode 100644 index 0000000..5ab7697 --- /dev/null +++ b/tests/test_middleware_single_execution.py @@ -0,0 +1,291 @@ +import pytest +from unittest.mock import AsyncMock, Mock +from fastapi import FastAPI, Request, Response +from fastapi.testclient import TestClient +from collections import defaultdict +import threading +import time + +from py_spring_core.core.entities.middlewares.middleware import Middleware +from py_spring_core.core.entities.middlewares.middleware_registry import MiddlewareRegistry + + +class SingleExecutionMiddleware(Middleware): + """ + Test middleware that ensures it only processes each request once + by tracking processed requests and preventing duplicate processing + """ + + def __init__(self, app): + super().__init__(app) + self.processed_requests = set() + self.execution_count = defaultdict(int) + self._lock = threading.Lock() + + async def process_request(self, request: Request) -> Response | None: + # Create a unique identifier for this request + request_id = id(request) + + # Use thread lock to ensure thread safety + with self._lock: + # Check if this request has already been processed + if request_id in self.processed_requests: + # This request has already been processed, skip it + return None + + # Mark this request as processed + self.processed_requests.add(request_id) + self.execution_count[request_id] += 1 + + # Return None to continue processing + return None + + +class TestMiddlewareSingleExecution: + """Test suite for ensuring middleware executes only once per request.""" + + @pytest.fixture + def mock_request(self): + """Fixture that provides a mock FastAPI request.""" + request = Mock(spec=Request) + request.method = "GET" + request.url = "http://test.com/api" + return request + + @pytest.fixture + def mock_call_next(self): + """Fixture that provides a mock call_next function.""" + return AsyncMock() + + @pytest.mark.asyncio + async def test_middleware_executes_only_once_per_request(self, mock_request, mock_call_next): + """ + Test that middleware executes only once per request. + + This test verifies that: + 1. When dispatch is called multiple times with the same request, + process_request logic is only executed once + 2. The execution count for the request remains at 1 + 3. The request is only processed once + """ + expected_response = Response(content="test response", status_code=200) + mock_call_next.return_value = expected_response + + middleware = SingleExecutionMiddleware(app=Mock()) + + # Call dispatch multiple times with the same request + for _ in range(3): + result = await middleware.dispatch(mock_request, mock_call_next) + + # Verify that process_request logic was only executed once for this request + request_id = id(mock_request) + assert middleware.execution_count[request_id] == 1 + + # Verify that the request was added to processed set + assert request_id in middleware.processed_requests + + # Verify call_next was called the expected number of times + assert mock_call_next.call_count == 3 + + # Verify the response is correct + assert result == expected_response + + @pytest.mark.asyncio + async def test_middleware_executes_once_per_different_request(self, mock_call_next): + """ + Test that middleware executes once for each different request. + + This test verifies that: + 1. Each unique request is processed exactly once + 2. Different requests are tracked separately + 3. The execution count is correct for each request + """ + expected_response = Response(content="test response", status_code=200) + mock_call_next.return_value = expected_response + + middleware = SingleExecutionMiddleware(app=Mock()) + + # Create multiple different requests + requests = [] + for i in range(3): + request = Mock(spec=Request) + request.method = "GET" + request.url = f"http://test.com/api/{i}" + requests.append(request) + + # Process each request once + for request in requests: + await middleware.dispatch(request, mock_call_next) + + # Verify each request was processed exactly once + for request in requests: + request_id = id(request) + assert middleware.execution_count[request_id] == 1 + assert request_id in middleware.processed_requests + + # Verify total number of processed requests + assert len(middleware.processed_requests) == 3 + assert sum(middleware.execution_count.values()) == 3 + + @pytest.mark.asyncio + async def test_middleware_handles_early_return_correctly(self, mock_request, mock_call_next): + """ + Test that middleware handles early return correctly while maintaining single execution. + + This test verifies that: + 1. When middleware returns a response early, it still counts as executed once + 2. The request is tracked even when middleware returns early + 3. call_next is not called when middleware returns early + """ + middleware_response = Response(content="middleware response", status_code=403) + + class EarlyReturnMiddleware(SingleExecutionMiddleware): + async def process_request(self, request: Request) -> Response | None: + # Call parent to track execution + await super().process_request(request) + # Return early response + return middleware_response + + middleware = EarlyReturnMiddleware(app=Mock()) + + # Call dispatch + result = await middleware.dispatch(mock_request, mock_call_next) + + # Verify that process_request was called once + request_id = id(mock_request) + assert middleware.execution_count[request_id] == 1 + assert request_id in middleware.processed_requests + + # Verify call_next was not called (early return) + mock_call_next.assert_not_called() + + # Verify the early response is returned + assert result == middleware_response + + @pytest.mark.asyncio + async def test_middleware_with_skip_logic_maintains_single_execution(self, mock_request, mock_call_next): + """ + Test that middleware with skip logic maintains single execution tracking. + + This test verifies that: + 1. When should_skip returns True, the request is still tracked + 2. The execution count remains at 0 for skipped requests + 3. The request is still added to the processed set + """ + expected_response = Response(content="skipped response", status_code=200) + mock_call_next.return_value = expected_response + + class SkippingMiddleware(SingleExecutionMiddleware): + def should_skip(self, request: Request) -> bool: + return True + + async def process_request(self, request: Request) -> Response | None: + # This should never be called when should_skip returns True + raise AssertionError("process_request should not be called when should_skip returns True") + + middleware = SkippingMiddleware(app=Mock()) + + # Call dispatch + result = await middleware.dispatch(mock_request, mock_call_next) + + # Verify that process_request was not called (due to skip) + request_id = id(mock_request) + assert request_id not in middleware.processed_requests + + # Verify call_next was called + mock_call_next.assert_called_once_with(mock_request) + assert result == expected_response + + @pytest.mark.asyncio + async def test_middleware_thread_safety(self, mock_call_next): + """ + Test that middleware is thread-safe when processing multiple requests concurrently. + + This test verifies that: + 1. Multiple threads can safely access the middleware + 2. Each request is processed exactly once even under concurrent access + 3. No race conditions occur + """ + expected_response = Response(content="test response", status_code=200) + mock_call_next.return_value = expected_response + + middleware = SingleExecutionMiddleware(app=Mock()) + + # Create multiple requests + requests = [] + for i in range(10): + request = Mock(spec=Request) + request.method = "GET" + request.url = f"http://test.com/api/{i}" + requests.append(request) + + # Process requests concurrently + import asyncio + + async def process_request(request): + return await middleware.dispatch(request, mock_call_next) + + # Process all requests concurrently + results = await asyncio.gather(*[process_request(req) for req in requests]) + + # Verify each request was processed exactly once + for request in requests: + request_id = id(request) + assert middleware.execution_count[request_id] == 1 + assert request_id in middleware.processed_requests + + # Verify all responses are correct + for result in results: + assert result == expected_response + + # Verify total number of processed requests + assert len(middleware.processed_requests) == 10 + assert sum(middleware.execution_count.values()) == 10 + + @pytest.mark.asyncio + async def test_middleware_integration_with_fastapi(self): + """ + Test middleware single execution in a real FastAPI application. + + This test verifies that: + 1. Middleware executes only once per request in a real FastAPI app + 2. Multiple requests are handled correctly + 3. The tracking works across different endpoints + """ + app = FastAPI() + + class TestMiddleware(SingleExecutionMiddleware): + pass + + class TestRegistry(MiddlewareRegistry): + def get_middleware_classes(self) -> list[type[Middleware]]: + return [TestMiddleware] + + # Apply middleware to app + registry = TestRegistry() + app = registry.apply_middlewares(app) + + # Create test endpoints + @app.get("/test1") + async def test_endpoint1(): + return {"message": "test1"} + + @app.get("/test2") + async def test_endpoint2(): + return {"message": "test2"} + + # Create test client + client = TestClient(app) + + # Make multiple requests + response1 = client.get("/test1") + response2 = client.get("/test2") + response3 = client.get("/test1") # Same endpoint as first request + + # Verify responses + assert response1.status_code == 200 + assert response2.status_code == 200 + assert response3.status_code == 200 + + # The key point is that the middleware doesn't interfere with normal operation + # and each request is processed exactly once through the middleware chain \ No newline at end of file From 1c23e9b6995d25accbadf9c5f3e39b7e6ba24031 Mon Sep 17 00:00:00 2001 From: William Chen Date: Fri, 18 Jul 2025 17:11:16 +0800 Subject: [PATCH 30/42] Update version to 0.0.20 in __init__.py --- py_spring_core/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py_spring_core/__init__.py b/py_spring_core/__init__.py index 6c7775d..acde155 100644 --- a/py_spring_core/__init__.py +++ b/py_spring_core/__init__.py @@ -22,7 +22,7 @@ from py_spring_core.event.application_event_publisher import ApplicationEventPublisher from py_spring_core.event.commons import ApplicationEvent -__version__ = "0.0.19" +__version__ = "0.0.20" __all__ = [ "PySpringApplication", From a04c98aa5a9a53e055be90f7520735c4c33f4fdd Mon Sep 17 00:00:00 2001 From: William Chen Date: Fri, 18 Jul 2025 17:14:04 +0800 Subject: [PATCH 31/42] [skip ci] Add test execution step to PyPI deployment workflow --- .github/workflows/pypi-deployment.yaml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/pypi-deployment.yaml b/.github/workflows/pypi-deployment.yaml index 8b61d92..7c8bbb9 100644 --- a/.github/workflows/pypi-deployment.yaml +++ b/.github/workflows/pypi-deployment.yaml @@ -28,6 +28,10 @@ jobs: run: | pdm install --prod + - name: Run tests + run: | + pdm run pytest + - name: Build the package run: | pdm build From 3e89037d31a65cfb7deecad45e9cd0b843700bd8 Mon Sep 17 00:00:00 2001 From: William Chen <86595028+NFUChen@users.noreply.github.com> Date: Fri, 18 Jul 2025 17:17:26 +0800 Subject: [PATCH 32/42] Update version to 0.0.21 in __init__.py (#16) --- py_spring_core/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py_spring_core/__init__.py b/py_spring_core/__init__.py index acde155..49290b2 100644 --- a/py_spring_core/__init__.py +++ b/py_spring_core/__init__.py @@ -22,7 +22,7 @@ from py_spring_core.event.application_event_publisher import ApplicationEventPublisher from py_spring_core.event.commons import ApplicationEvent -__version__ = "0.0.20" +__version__ = "0.0.21" __all__ = [ "PySpringApplication", From 89e86942b643853e76aac5313098074d8622b047 Mon Sep 17 00:00:00 2001 From: William Chen Date: Fri, 18 Jul 2025 17:19:48 +0800 Subject: [PATCH 33/42] Update PyPI deployment workflow to install all dependencies instead of only production ones --- .github/workflows/pypi-deployment.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/pypi-deployment.yaml b/.github/workflows/pypi-deployment.yaml index 7c8bbb9..16f19ed 100644 --- a/.github/workflows/pypi-deployment.yaml +++ b/.github/workflows/pypi-deployment.yaml @@ -26,7 +26,7 @@ jobs: - name: Install dependencies run: | - pdm install --prod + pdm install - name: Run tests run: | From 40bd4d3198484d7b4ca51df23da5ec3dc718d397 Mon Sep 17 00:00:00 2001 From: William Chen <86595028+NFUChen@users.noreply.github.com> Date: Sun, 20 Jul 2025 00:44:24 +0800 Subject: [PATCH 34/42] # Module Cache Mechanism Implementation (#17) --- py_spring_core/__init__.py | 2 +- py_spring_core/commons/class_scanner.py | 34 ++- py_spring_core/commons/module_importer.py | 135 +++++++++++ .../core/application/application_config.py | 1 + .../core/application/py_spring_application.py | 4 +- py_spring_core/core/utils.py | 72 ++---- tests/test_component_features.py | 1 - tests/test_module_import_cache.py | 166 +++++++++++++ tests/test_module_importer.py | 227 ++++++++++++++++++ 9 files changed, 573 insertions(+), 69 deletions(-) create mode 100644 py_spring_core/commons/module_importer.py create mode 100644 tests/test_module_import_cache.py create mode 100644 tests/test_module_importer.py diff --git a/py_spring_core/__init__.py b/py_spring_core/__init__.py index 49290b2..306667d 100644 --- a/py_spring_core/__init__.py +++ b/py_spring_core/__init__.py @@ -22,7 +22,7 @@ from py_spring_core.event.application_event_publisher import ApplicationEventPublisher from py_spring_core.event.commons import ApplicationEvent -__version__ = "0.0.21" +__version__ = "0.0.22" __all__ = [ "PySpringApplication", diff --git a/py_spring_core/commons/class_scanner.py b/py_spring_core/commons/class_scanner.py index cb13bff..178d366 100644 --- a/py_spring_core/commons/class_scanner.py +++ b/py_spring_core/commons/class_scanner.py @@ -1,9 +1,11 @@ import ast -import importlib.util -from typing import Iterable, Optional, Type +import re +from typing import Any, Iterable, Type from loguru import logger +from .module_importer import ModuleImporter + class ClassScanner: """ @@ -17,6 +19,8 @@ class ClassScanner: def __init__(self, file_paths: Iterable[str]) -> None: self.file_paths = file_paths self.scanned_classes: dict[str, dict[str, Type[object]]] = {} + # Use ModuleImporter for handling module imports + self._module_importer = ModuleImporter() def extract_classes_from_file(self, file_path: str) -> dict[str, Type[object]]: with open(file_path, "r") as file: @@ -44,8 +48,12 @@ def _extract_classes_from_file_content( return class_objects - def scan_classes_for_file_paths(self) -> None: + def scan_classes_for_file_paths(self, exclude_file_patterns: Iterable[str]) -> None: for file_path in self.file_paths: + if any(re.match(pattern, file_path) for pattern in exclude_file_patterns): + logger.debug(f"[EXCLUDED FILE] {file_path} is excluded") + continue + object_cls_dict: dict[str, Type[object]] = self.extract_classes_from_file( file_path ) @@ -53,14 +61,12 @@ def scan_classes_for_file_paths(self) -> None: def import_class_from_file( self, file_path: str, class_name: str - ) -> Optional[Type[object]]: - spec = importlib.util.spec_from_file_location(class_name, file_path) - if spec is None: + ) -> Type[object] | None: + # Use ModuleImporter to handle module import + module = self._module_importer.import_module_from_path(file_path) + if module is None: return None - module = importlib.util.module_from_spec(spec) - if spec.loader is None: - return None - spec.loader.exec_module(module) + cls = getattr(module, class_name, None) return cls @@ -79,3 +85,11 @@ def display_classes(self) -> None: repr += f" Class: {class_name}\n" logger.debug(repr) + + def clear_module_cache(self) -> None: + """Clear the module cache. Useful for testing or when you need to force re-import.""" + self._module_importer.clear_cache() + + def get_cache_size(self) -> int: + """Get the number of cached modules.""" + return self._module_importer.get_cache_size() diff --git a/py_spring_core/commons/module_importer.py b/py_spring_core/commons/module_importer.py new file mode 100644 index 0000000..44a6573 --- /dev/null +++ b/py_spring_core/commons/module_importer.py @@ -0,0 +1,135 @@ +import importlib.util +import inspect +from pathlib import Path +from typing import Any, Iterable, Type, Optional + +from loguru import logger + + +class ModuleImporter: + """ + A class that handles dynamic module importing with caching capabilities. + Provides functionality to import modules from file paths and extract classes from them. + """ + + def __init__(self) -> None: + # Module cache to prevent duplicate imports + self._module_cache: dict[str, Any] = {} + + def import_module_from_path(self, file_path: str) -> Optional[Any]: + """ + Import a module from a file path with caching. + + Args: + file_path (str): The file path of the module to import. + + Returns: + Optional[Any]: The imported module or None if import fails. + """ + resolved_path = Path(file_path).resolve() + cache_key = str(resolved_path) + module_name = resolved_path.stem + + # Check if module is already cached + if cache_key in self._module_cache: + logger.debug(f"[MODULE CACHE] Using cached module: {module_name}") + return self._module_cache[cache_key] + + logger.info(f"[MODULE IMPORT] Import module path: {resolved_path}") + + # Create a module specification + spec = importlib.util.spec_from_file_location(module_name, resolved_path) + if spec is None: + logger.warning(f"[MODULE IMPORT] Could not create spec for {module_name}") + return None + + # Create a new module based on the specification + module = importlib.util.module_from_spec(spec) + if spec.loader is None: + logger.warning(f"[MODULE IMPORT] No loader found for {module_name}") + return None + + # Execute the module in its own namespace + logger.info(f"[MODULE IMPORT] Import module: {module_name}") + try: + spec.loader.exec_module(module) + logger.success(f"[MODULE IMPORT] Successfully imported {module_name}") + # Cache the module + self._module_cache[cache_key] = module + return module + except Exception as error: + logger.warning(f"[MODULE IMPORT] Failed to import {module_name}: {error}") + return None + + def extract_classes_from_module(self, module: Any) -> list[Type[object]]: + """ + Extract all classes from a module. + + Args: + module (Any): The module to extract classes from. + + Returns: + list[Type[object]]: List of classes found in the module. + """ + loaded_classes = [] + for attr in dir(module): + obj = getattr(module, attr) + if attr.startswith("__"): + continue + if not inspect.isclass(obj): + continue + loaded_classes.append(obj) + return loaded_classes + + def import_classes_from_paths( + self, + file_paths: Iterable[str], + target_subclasses: Iterable[Type[object]] = [], + ignore_errors: bool = True + ) -> set[Type[object]]: + """ + Import classes from multiple file paths with optional filtering. + + Args: + file_paths (Iterable[str]): The file paths of the modules to import. + target_subclasses (Iterable[Type[object]], optional): Target subclasses to filter. Defaults to []. + ignore_errors (bool, optional): Whether to ignore import errors. Defaults to True. + + Returns: + set[Type[object]]: Set of imported classes. + """ + all_loaded_classes: list[Type[object]] = [] + + for file_path in file_paths: + module = self.import_module_from_path(file_path) + if module is None: + if not ignore_errors: + raise ImportError(f"Failed to import module from {file_path}") + continue + + loaded_classes = self.extract_classes_from_module(module) + all_loaded_classes.extend(loaded_classes) + + returned_target_classes: set[Type[object]] = set() + + # If no target subclasses specified, return all loaded classes + if not target_subclasses: + returned_target_classes = set(all_loaded_classes) + else: + # Filter classes based on target subclasses + for target_cls in target_subclasses: + for loaded_class in all_loaded_classes: + if loaded_class in target_subclasses: + continue + if issubclass(loaded_class, target_cls): + returned_target_classes.add(loaded_class) + + return returned_target_classes + + def clear_cache(self) -> None: + """Clear the module cache. Useful for testing or when you need to force re-import.""" + self._module_cache.clear() + + def get_cache_size(self) -> int: + """Get the number of cached modules.""" + return len(self._module_cache) \ No newline at end of file diff --git a/py_spring_core/core/application/application_config.py b/py_spring_core/core/application/application_config.py index 0cd55cd..e86b00c 100644 --- a/py_spring_core/core/application/application_config.py +++ b/py_spring_core/core/application/application_config.py @@ -36,6 +36,7 @@ class ApplicationConfig(BaseModel): model_config = ConfigDict(protected_namespaces=()) app_src_target_dir: str + exclude_file_patterns: list[str] = Field(default_factory=lambda: [r".*/models\.py$"]) server_config: ServerConfig properties_file_path: str loguru_config: LoguruConfig diff --git a/py_spring_core/core/application/py_spring_application.py b/py_spring_core/core/application/py_spring_application.py index f02adf3..48f1850 100644 --- a/py_spring_core/core/application/py_spring_application.py +++ b/py_spring_core/core/application/py_spring_application.py @@ -123,7 +123,9 @@ def _get_system_managed_classes(self) -> Iterable[Type[Component]]: return [ApplicationEventPublisher, ApplicationEventHandlerRegistry] def _scan_classes_for_project(self) -> Iterable[Type[object]]: - self.app_class_scanner.scan_classes_for_file_paths() + self.app_class_scanner.scan_classes_for_file_paths( + self.app_config.exclude_file_patterns + ) return self.app_class_scanner.get_classes() def _register_app_entities(self, classes: Iterable[Type[object]]) -> None: diff --git a/py_spring_core/core/utils.py b/py_spring_core/core/utils.py index 0f549b7..5421096 100644 --- a/py_spring_core/core/utils.py +++ b/py_spring_core/core/utils.py @@ -1,11 +1,14 @@ -import importlib.util import inspect from abc import ABC -from pathlib import Path from typing import Any, Iterable, Type from loguru import logger +from ..commons.module_importer import ModuleImporter + +# Global module importer instance +_module_importer = ModuleImporter() + def dynamically_import_modules( module_paths: Iterable[str], @@ -18,64 +21,21 @@ def dynamically_import_modules( Args: module_paths (Iterable[str]): The file paths of the modules to import. is_ignore_error (bool, optional): Whether to ignore any errors that occur during the import process. Defaults to True. + target_subclasses (Iterable[Type[object]], optional): Target subclasses to filter. Defaults to []. Raises: Exception: If an error occurs during the import process and `is_ignore_error` is False. """ - all_loaded_classes: list[Type[object]] = [] - - for module_path in module_paths: - file_path = Path(module_path).resolve() - module_name = file_path.stem - logger.info(f"[MODULE IMPORT] Import module path: {file_path}") - # Create a module specification - spec = importlib.util.spec_from_file_location(module_name, file_path) - if spec is None: - logger.warning( - f"[DYNAMICALLY MODULE IMPORT] Could not create spec for {module_name}" - ) - continue - - # Create a new module based on the specification - module = importlib.util.module_from_spec(spec) - if spec.loader is None: - logger.warning( - f"[DYNAMICALLY MODULE IMPORT] No loader found for {module_name}" - ) - continue - - # Execute the module in its own namespace - - logger.info(f"[DYNAMICALLY MODULE IMPORT] Import module: {module_name}") - try: - spec.loader.exec_module(module) - logger.success( - f"[DYNAMICALLY MODULE IMPORT] Successfully imported {module_name}" - ) - except Exception as error: - logger.warning(error) - if not is_ignore_error: - raise error - - loaded_classes = [] - for attr in dir(module): - obj = getattr(module, attr) - if attr.startswith("__"): - continue - if not inspect.isclass(obj): - continue - loaded_classes.append(obj) - all_loaded_classes.extend(loaded_classes) - - returned_target_classes: set[Type[object]] = set() - for target_cls in target_subclasses: - for loaded_class in all_loaded_classes: - if loaded_class in target_subclasses: - continue - if issubclass(loaded_class, target_cls): - returned_target_classes.add(loaded_class) - - return returned_target_classes + return _module_importer.import_classes_from_paths( + file_paths=module_paths, + target_subclasses=target_subclasses, + ignore_errors=is_ignore_error + ) + + +def clear_module_cache() -> None: + """Clear the global module cache. Useful for testing or when you need to force re-import.""" + _module_importer.clear_cache() def get_unimplemented_abstract_methods(cls: Type[Any]) -> set[str]: diff --git a/tests/test_component_features.py b/tests/test_component_features.py index a696176..4221ac9 100644 --- a/tests/test_component_features.py +++ b/tests/test_component_features.py @@ -1,4 +1,3 @@ -from abc import ABC from typing import Annotated import pytest diff --git a/tests/test_module_import_cache.py b/tests/test_module_import_cache.py new file mode 100644 index 0000000..7a10127 --- /dev/null +++ b/tests/test_module_import_cache.py @@ -0,0 +1,166 @@ +import tempfile +import os +from pathlib import Path +from unittest.mock import patch + + +from py_spring_core.core.utils import dynamically_import_modules, clear_module_cache +from py_spring_core.commons.class_scanner import ClassScanner + + +class TestModuleImportCache: + """Test module import caching to prevent duplicate imports.""" + + def setup_method(self): + """Clear module cache before each test.""" + clear_module_cache() + + def test_dynamically_import_modules_cache(self): + """Test that dynamically_import_modules uses cache to prevent duplicate imports.""" + # Create a temporary Python file with a class + with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f: + f.write(""" +class TestClass: + def __init__(self): + self.value = "test" +""") + temp_file_path = f.name + + try: + # First import - get all classes without filtering + classes1 = dynamically_import_modules([temp_file_path], target_subclasses=[]) + + # Second import of the same file + classes2 = dynamically_import_modules([temp_file_path], target_subclasses=[]) + + # Both should return the same classes + assert len(classes1) == 1 + assert len(classes2) == 1 + assert classes1 == classes2 + + # The class should be the same object (not a duplicate) + class1 = list(classes1)[0] + class2 = list(classes2)[0] + assert class1 is class2 + + finally: + os.unlink(temp_file_path) + + def test_class_scanner_cache(self): + """Test that ClassScanner uses cache to prevent duplicate imports.""" + # Create a temporary Python file with a class + with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f: + f.write(""" +class TestClass: + def __init__(self): + self.value = "test" +""") + temp_file_path = f.name + + try: + scanner = ClassScanner([temp_file_path]) + + # First scan + scanner.scan_classes_for_file_paths([]) + classes1 = list(scanner.get_classes()) + + # Second scan + scanner.scan_classes_for_file_paths([]) + classes2 = list(scanner.get_classes()) + + # Both should return the same classes + assert len(classes1) == 1 + assert len(classes2) == 1 + assert classes1 == classes2 + + # The class should be the same object (not a duplicate) + assert classes1[0] is classes2[0] + + finally: + os.unlink(temp_file_path) + + def test_clear_module_cache(self): + """Test that clear_module_cache works correctly.""" + # Create a temporary Python file + with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f: + f.write(""" +class TestClass: + pass +""") + temp_file_path = f.name + + try: + # First import + classes1 = dynamically_import_modules([temp_file_path], target_subclasses=[]) + + # Clear cache + clear_module_cache() + + # Second import after clearing cache + classes2 = dynamically_import_modules([temp_file_path], target_subclasses=[]) + + # Classes should be different objects after clearing cache + class1 = list(classes1)[0] + class2 = list(classes2)[0] + assert class1 is not class2 + + finally: + os.unlink(temp_file_path) + + def test_class_scanner_clear_cache(self): + """Test that ClassScanner clear_module_cache works correctly.""" + # Create a temporary Python file + with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f: + f.write(""" +class TestClass: + pass +""") + temp_file_path = f.name + + try: + scanner = ClassScanner([temp_file_path]) + + # First scan + scanner.scan_classes_for_file_paths([]) + classes1 = list(scanner.get_classes()) + + # Clear cache + scanner.clear_module_cache() + + # Second scan after clearing cache + scanner.scan_classes_for_file_paths([]) + classes2 = list(scanner.get_classes()) + + # Classes should be different objects after clearing cache + assert classes1[0] is not classes2[0] + + finally: + os.unlink(temp_file_path) + + @patch('py_spring_core.commons.module_importer.logger') + def test_cache_logging(self, mock_logger): + """Test that cache usage is properly logged.""" + # Create a temporary Python file + with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f: + f.write(""" +class TestClass: + pass +""") + temp_file_path = f.name + + try: + # First import (should log import) + dynamically_import_modules([temp_file_path], target_subclasses=[]) + + # Second import (should log cache usage) + dynamically_import_modules([temp_file_path], target_subclasses=[]) + + # Check that debug log for cache usage was called + # Get the actual module name from the temp file + module_name = Path(temp_file_path).stem + mock_logger.debug.assert_called_with( + f"[MODULE CACHE] Using cached module: {module_name}" + ) + + finally: + os.unlink(temp_file_path) \ No newline at end of file diff --git a/tests/test_module_importer.py b/tests/test_module_importer.py new file mode 100644 index 0000000..eea1fd3 --- /dev/null +++ b/tests/test_module_importer.py @@ -0,0 +1,227 @@ +import tempfile +import os +from pathlib import Path +from typing import Type + +import pytest + +from py_spring_core.commons.module_importer import ModuleImporter + + +class TestModuleImporter: + """Test ModuleImporter class functionality.""" + + def setup_method(self): + """Clear module cache before each test.""" + self.importer = ModuleImporter() + + def test_import_module_from_path(self): + """Test importing a module from a file path.""" + # Create a temporary Python file + with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f: + f.write(""" +class TestClass: + def __init__(self): + self.value = "test" + +def test_function(): + return "test" +""") + temp_file_path = f.name + + try: + module = self.importer.import_module_from_path(temp_file_path) + + assert module is not None + assert hasattr(module, 'TestClass') + assert hasattr(module, 'test_function') + + # Test that the class can be instantiated + test_instance = module.TestClass() + assert test_instance.value == "test" + + finally: + os.unlink(temp_file_path) + + def test_import_module_from_path_caching(self): + """Test that modules are cached and reused.""" + # Create a temporary Python file + with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f: + f.write(""" +class TestClass: + pass +""") + temp_file_path = f.name + + try: + # First import + module1 = self.importer.import_module_from_path(temp_file_path) + cache_size1 = self.importer.get_cache_size() + + # Second import + module2 = self.importer.import_module_from_path(temp_file_path) + cache_size2 = self.importer.get_cache_size() + + # Should be the same module object + assert module1 is module2 + # Cache size should not increase + assert cache_size1 == cache_size2 + + finally: + os.unlink(temp_file_path) + + def test_extract_classes_from_module(self): + """Test extracting classes from a module.""" + # Create a temporary Python file + with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f: + f.write(""" +class ClassA: + pass + +class ClassB: + pass + +def function(): + pass + +CONSTANT = 42 +""") + temp_file_path = f.name + + try: + module = self.importer.import_module_from_path(temp_file_path) + classes = self.importer.extract_classes_from_module(module) + + # Should find 2 classes + assert len(classes) == 2 + + # Check class names + class_names = [cls.__name__ for cls in classes] + assert 'ClassA' in class_names + assert 'ClassB' in class_names + + finally: + os.unlink(temp_file_path) + + def test_import_classes_from_paths(self): + """Test importing classes from multiple file paths.""" + # Create temporary Python files + temp_files = [] + try: + for i in range(2): + with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f: + f.write(f""" +class Class{i}: + def __init__(self): + self.id = {i} +""") + temp_files.append(f.name) + + # Import classes from both files + classes = self.importer.import_classes_from_paths(temp_files) + + # Should find 2 classes + assert len(classes) == 2 + + # Check that classes can be instantiated + class_list = list(classes) + instance0 = class_list[0]() + instance1 = class_list[1]() + + assert hasattr(instance0, 'id') + assert hasattr(instance1, 'id') + + finally: + for temp_file in temp_files: + os.unlink(temp_file) + + def test_import_classes_from_paths_with_filtering(self): + """Test importing classes with target subclass filtering.""" + # Create a temporary Python file with different types of classes + with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f: + f.write(""" +class BaseClass: + pass + +class SubClass(BaseClass): + pass + +class UnrelatedClass: + pass +""") + temp_file_path = f.name + + try: + # First import to get the BaseClass + all_classes = self.importer.import_classes_from_paths([temp_file_path]) + base_class = None + for cls in all_classes: + if cls.__name__ == 'BaseClass': + base_class = cls + break + + assert base_class is not None + + # Import only subclasses of BaseClass + classes = self.importer.import_classes_from_paths( + [temp_file_path], + target_subclasses=[base_class] + ) + + # Should only find SubClass + assert len(classes) == 1 + class_list = list(classes) + assert class_list[0].__name__ == 'SubClass' + + finally: + os.unlink(temp_file_path) + + def test_clear_cache(self): + """Test clearing the module cache.""" + # Create a temporary Python file + with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f: + f.write(""" +class TestClass: + pass +""") + temp_file_path = f.name + + try: + # Import module + module1 = self.importer.import_module_from_path(temp_file_path) + assert self.importer.get_cache_size() == 1 + + # Clear cache + self.importer.clear_cache() + assert self.importer.get_cache_size() == 0 + + # Import again (should be a different object) + module2 = self.importer.import_module_from_path(temp_file_path) + assert module1 is not module2 + + finally: + os.unlink(temp_file_path) + + def test_import_nonexistent_file(self): + """Test importing a non-existent file.""" + result = self.importer.import_module_from_path("/nonexistent/file.py") + assert result is None + + def test_import_invalid_python_file(self): + """Test importing a file with invalid Python syntax.""" + with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f: + f.write(""" +class TestClass: + def __init__(self): + # Invalid syntax + if True + pass +""") + temp_file_path = f.name + + try: + result = self.importer.import_module_from_path(temp_file_path) + assert result is None + + finally: + os.unlink(temp_file_path) \ No newline at end of file From b987d2ae45203342ea8cb1da9abf049ef6ec12a5 Mon Sep 17 00:00:00 2001 From: William Chen <86595028+NFUChen@users.noreply.github.com> Date: Sun, 20 Jul 2025 23:23:03 +0800 Subject: [PATCH 35/42] Improve Error Handling in Module Importer (#18) --- py_spring_core/__init__.py | 2 +- py_spring_core/commons/module_importer.py | 10 +++------- py_spring_core/core/utils.py | 2 -- tests/test_module_importer.py | 10 +++++----- 4 files changed, 9 insertions(+), 15 deletions(-) diff --git a/py_spring_core/__init__.py b/py_spring_core/__init__.py index 306667d..c114088 100644 --- a/py_spring_core/__init__.py +++ b/py_spring_core/__init__.py @@ -22,7 +22,7 @@ from py_spring_core.event.application_event_publisher import ApplicationEventPublisher from py_spring_core.event.commons import ApplicationEvent -__version__ = "0.0.22" +__version__ = "0.0.23" __all__ = [ "PySpringApplication", diff --git a/py_spring_core/commons/module_importer.py b/py_spring_core/commons/module_importer.py index 44a6573..6b20108 100644 --- a/py_spring_core/commons/module_importer.py +++ b/py_spring_core/commons/module_importer.py @@ -59,7 +59,7 @@ def import_module_from_path(self, file_path: str) -> Optional[Any]: return module except Exception as error: logger.warning(f"[MODULE IMPORT] Failed to import {module_name}: {error}") - return None + raise error def extract_classes_from_module(self, module: Any) -> list[Type[object]]: """ @@ -84,8 +84,7 @@ def extract_classes_from_module(self, module: Any) -> list[Type[object]]: def import_classes_from_paths( self, file_paths: Iterable[str], - target_subclasses: Iterable[Type[object]] = [], - ignore_errors: bool = True + target_subclasses: Iterable[Type[object]] = [] ) -> set[Type[object]]: """ Import classes from multiple file paths with optional filtering. @@ -93,7 +92,6 @@ def import_classes_from_paths( Args: file_paths (Iterable[str]): The file paths of the modules to import. target_subclasses (Iterable[Type[object]], optional): Target subclasses to filter. Defaults to []. - ignore_errors (bool, optional): Whether to ignore import errors. Defaults to True. Returns: set[Type[object]]: Set of imported classes. @@ -103,9 +101,7 @@ def import_classes_from_paths( for file_path in file_paths: module = self.import_module_from_path(file_path) if module is None: - if not ignore_errors: - raise ImportError(f"Failed to import module from {file_path}") - continue + raise ImportError(f"Failed to import module from {file_path}") loaded_classes = self.extract_classes_from_module(module) all_loaded_classes.extend(loaded_classes) diff --git a/py_spring_core/core/utils.py b/py_spring_core/core/utils.py index 5421096..16da747 100644 --- a/py_spring_core/core/utils.py +++ b/py_spring_core/core/utils.py @@ -12,7 +12,6 @@ def dynamically_import_modules( module_paths: Iterable[str], - is_ignore_error: bool = True, target_subclasses: Iterable[Type[object]] = [], ) -> set[Type[object]]: """ @@ -29,7 +28,6 @@ def dynamically_import_modules( return _module_importer.import_classes_from_paths( file_paths=module_paths, target_subclasses=target_subclasses, - ignore_errors=is_ignore_error ) diff --git a/tests/test_module_importer.py b/tests/test_module_importer.py index eea1fd3..8e50ac7 100644 --- a/tests/test_module_importer.py +++ b/tests/test_module_importer.py @@ -204,8 +204,8 @@ class TestClass: def test_import_nonexistent_file(self): """Test importing a non-existent file.""" - result = self.importer.import_module_from_path("/nonexistent/file.py") - assert result is None + with pytest.raises(FileNotFoundError): + self.importer.import_module_from_path("/nonexistent/file.py") def test_import_invalid_python_file(self): """Test importing a file with invalid Python syntax.""" @@ -213,15 +213,15 @@ def test_import_invalid_python_file(self): f.write(""" class TestClass: def __init__(self): - # Invalid syntax + # Invalid syntax - missing colon if True pass """) temp_file_path = f.name try: - result = self.importer.import_module_from_path(temp_file_path) - assert result is None + with pytest.raises(SyntaxError): + self.importer.import_module_from_path(temp_file_path) finally: os.unlink(temp_file_path) \ No newline at end of file From 070d556cc3a133c39e8d3461363bdcc01db5fc1f Mon Sep 17 00:00:00 2001 From: William Chen <86595028+NFUChen@users.noreply.github.com> Date: Mon, 21 Jul 2025 15:56:08 +0800 Subject: [PATCH 36/42] Implement Graceful Shutdown Handler with Configurable Timeout (#19) --- py_spring_core/__init__.py | 8 +- .../templates.py | 1 + .../core/application/application_config.py | 15 + py_spring_core/core/application/commons.py | 4 +- .../context/application_context.py | 6 +- .../core/application/py_spring_application.py | 91 ++++- .../{ => bean_collection}/bean_collection.py | 0 .../entities/{ => component}/component.py | 0 .../{ => entity_provider}/entity_provider.py | 4 +- .../interfaces/graceful_shutdown_handler.py | 143 ++++++++ .../application_event_handler_registry.py | 2 +- .../event/application_event_publisher.py | 2 +- tests/test_application_config.py | 146 ++++++++ tests/test_application_context.py | 4 +- tests/test_bean_collection.py | 4 +- tests/test_component_features.py | 2 +- tests/test_entity_provider.py | 4 +- tests/test_framework_utils.py | 226 ++++++++++++ tests/test_graceful_shutdown_handler.py | 340 ++++++++++++++++++ 19 files changed, 965 insertions(+), 37 deletions(-) rename py_spring_core/core/entities/{ => bean_collection}/bean_collection.py (100%) rename py_spring_core/core/entities/{ => component}/component.py (100%) rename py_spring_core/core/entities/{ => entity_provider}/entity_provider.py (89%) create mode 100644 py_spring_core/core/interfaces/graceful_shutdown_handler.py create mode 100644 tests/test_application_config.py create mode 100644 tests/test_framework_utils.py create mode 100644 tests/test_graceful_shutdown_handler.py diff --git a/py_spring_core/__init__.py b/py_spring_core/__init__.py index c114088..76d0d00 100644 --- a/py_spring_core/__init__.py +++ b/py_spring_core/__init__.py @@ -1,6 +1,6 @@ from py_spring_core.core.application.py_spring_application import PySpringApplication -from py_spring_core.core.entities.bean_collection import BeanCollection -from py_spring_core.core.entities.component import Component, ComponentScope +from py_spring_core.core.entities.bean_collection.bean_collection import BeanCollection +from py_spring_core.core.entities.component.component import Component, ComponentScope from py_spring_core.core.entities.controllers.rest_controller import RestController from py_spring_core.core.entities.controllers.route_mapping import ( DeleteMapping, @@ -9,7 +9,7 @@ PostMapping, PutMapping, ) -from py_spring_core.core.entities.entity_provider import EntityProvider +from py_spring_core.core.entities.entity_provider.entity_provider import EntityProvider from py_spring_core.core.entities.middlewares.middleware import Middleware from py_spring_core.core.entities.middlewares.middleware_registry import ( MiddlewareRegistry, @@ -22,7 +22,7 @@ from py_spring_core.event.application_event_publisher import ApplicationEventPublisher from py_spring_core.event.commons import ApplicationEvent -__version__ = "0.0.23" +__version__ = "0.0.24" __all__ = [ "PySpringApplication", diff --git a/py_spring_core/commons/config_file_template_generator/templates.py b/py_spring_core/commons/config_file_template_generator/templates.py index 61a81be..6160d15 100644 --- a/py_spring_core/commons/config_file_template_generator/templates.py +++ b/py_spring_core/commons/config_file_template_generator/templates.py @@ -6,6 +6,7 @@ "properties_file_path": "./application-properties.json", "loguru_config": {"log_file_path": "./logs/app.log", "log_level": "DEBUG"}, "type_checking_mode": "strict", + "shutdown_config": {"timeout_seconds": 30.0, "enabled": True}, } app_properties_template: dict[str, Any] = {} diff --git a/py_spring_core/core/application/application_config.py b/py_spring_core/core/application/application_config.py index e86b00c..8816185 100644 --- a/py_spring_core/core/application/application_config.py +++ b/py_spring_core/core/application/application_config.py @@ -21,6 +21,19 @@ class ServerConfig(BaseModel): enabled: bool = Field(default=True) +class ShutdownConfig(BaseModel): + """ + Represents the configuration for graceful shutdown. + + Attributes: + timeout_seconds: The maximum time in seconds to wait for graceful shutdown before forcing termination. + enabled: A boolean flag indicating whether graceful shutdown timeout is enabled. + """ + + timeout_seconds: float = Field(default=30.0, description="Timeout in seconds for graceful shutdown") + enabled: bool = Field(default=True, description="Whether graceful shutdown timeout is enabled") + + class ApplicationConfig(BaseModel): """ Represents the configuration for the application. @@ -31,6 +44,7 @@ class ApplicationConfig(BaseModel): sqlalchemy_database_uri: The URI for the SQLAlchemy database connection. properties_file_path: The file path for the application properties. model_file_postfix_patterns: A list of file name patterns for model (for table creation) files. + shutdown_config: The configuration for graceful shutdown. """ model_config = ConfigDict(protected_namespaces=()) @@ -40,6 +54,7 @@ class ApplicationConfig(BaseModel): server_config: ServerConfig properties_file_path: str loguru_config: LoguruConfig + shutdown_config: ShutdownConfig = Field(default_factory=ShutdownConfig) class ApplicationConfigRepository(JsonConfigRepository[ApplicationConfig]): diff --git a/py_spring_core/core/application/commons.py b/py_spring_core/core/application/commons.py index eb0b11b..3d2ebab 100644 --- a/py_spring_core/core/application/commons.py +++ b/py_spring_core/core/application/commons.py @@ -1,5 +1,5 @@ -from py_spring_core.core.entities.bean_collection import BeanCollection -from py_spring_core.core.entities.component import Component +from py_spring_core.core.entities.bean_collection.bean_collection import BeanCollection +from py_spring_core.core.entities.component.component import Component from py_spring_core.core.entities.controllers.rest_controller import RestController from py_spring_core.core.entities.properties.properties import Properties diff --git a/py_spring_core/core/application/context/application_context.py b/py_spring_core/core/application/context/application_context.py index a51d541..4d43f0a 100644 --- a/py_spring_core/core/application/context/application_context.py +++ b/py_spring_core/core/application/context/application_context.py @@ -22,15 +22,15 @@ from py_spring_core.core.application.context.application_context_config import ( ApplicationContextConfig, ) -from py_spring_core.core.entities.bean_collection import ( +from py_spring_core.core.entities.bean_collection.bean_collection import ( BeanCollection, BeanConflictError, BeanView, InvalidBeanError, ) -from py_spring_core.core.entities.component import Component, ComponentScope +from py_spring_core.core.entities.component.component import Component, ComponentScope from py_spring_core.core.entities.controllers.rest_controller import RestController -from py_spring_core.core.entities.entity_provider import EntityProvider +from py_spring_core.core.entities.entity_provider.entity_provider import EntityProvider from py_spring_core.core.entities.properties.properties import Properties from py_spring_core.core.entities.properties.properties_loader import _PropertiesLoader diff --git a/py_spring_core/core/application/py_spring_application.py b/py_spring_core/core/application/py_spring_application.py index 48f1850..cc39d1c 100644 --- a/py_spring_core/core/application/py_spring_application.py +++ b/py_spring_core/core/application/py_spring_application.py @@ -1,6 +1,6 @@ import logging import os -from typing import Any, Callable, Iterable, Type +from typing import Any, Callable, Iterable, Optional, Type import uvicorn from fastapi import APIRouter, FastAPI @@ -23,11 +23,12 @@ ApplicationContextConfig, ) from py_spring_core.core.application.loguru_config import LogFormat -from py_spring_core.core.entities.bean_collection import BeanCollection -from py_spring_core.core.entities.component import Component, ComponentLifeCycle +from py_spring_core.core.entities.bean_collection.bean_collection import BeanCollection +from py_spring_core.core.entities.component.component import Component, ComponentLifeCycle from py_spring_core.core.entities.controllers.rest_controller import RestController from py_spring_core.core.entities.controllers.route_mapping import RouteMapping -from py_spring_core.core.entities.entity_provider import EntityProvider +from py_spring_core.core.entities.entity_provider.entity_provider import EntityProvider +from py_spring_core.core.entities.middlewares.middleware import Middleware from py_spring_core.core.entities.middlewares.middleware_registry import ( MiddlewareRegistry, ) @@ -35,11 +36,14 @@ from py_spring_core.core.interfaces.application_context_required import ( ApplicationContextRequired, ) +from py_spring_core.core.interfaces.graceful_shutdown_handler import GracefulShutdownHandler +from py_spring_core.core.interfaces.single_inheritance_required import SingleInheritanceRequired from py_spring_core.event.application_event_handler_registry import ( ApplicationEventHandlerRegistry, ) from py_spring_core.event.application_event_publisher import ApplicationEventPublisher +import py_spring_core.core.utils as framework_utils class PySpringApplication: """ @@ -102,6 +106,7 @@ def __init__( self.type_checking_service = TypeCheckingService( self.app_config.app_src_target_dir ) + self.shutdown_handler: Optional[GracefulShutdownHandler] = None def __configure_logging(self): """Applies the logging configuration using Loguru.""" @@ -232,31 +237,76 @@ def __init_controllers(self) -> None: self.fastapi.include_router(router) logger.debug(f"[CONTROLLER INIT] Controller {name} initialized") - def __init_middlewares(self) -> None: - logger.debug("[MIDDLEWARE INIT] Initialize middlewares...") - self_defined_registry_cls = MiddlewareRegistry.get_subclass() - if self_defined_registry_cls is None: - logger.debug("[MIDDLEWARE INIT] No self defined registry class found") - return + def _init_external_handler(self, base_class: Type[SingleInheritanceRequired]) -> Type[Any] | None: + """Initialize an external handler (middleware registry or graceful shutdown handler). + + Args: + base_class: The base class to get subclass from + handler_type: The type of handler for logging purposes + + Returns: + The initialized handler class or None if no handler is found + + Raises: + RuntimeError: If the handler has unimplemented abstract methods + """ + handler_type = base_class.__name__ + self_defined_handler_cls = base_class.get_subclass() + if self_defined_handler_cls is None: + logger.debug(f"[{handler_type} INIT] No self defined {handler_type.lower()} class found") + return None + + unimplemented_abstract_methods = framework_utils.get_unimplemented_abstract_methods(self_defined_handler_cls) + if len(unimplemented_abstract_methods) > 0: + error_message = f"[{handler_type} INIT] Self defined {handler_type.lower()} class: {self_defined_handler_cls.__name__} has unimplemented abstract methods: {unimplemented_abstract_methods}" + logger.error(error_message) + raise RuntimeError(error_message) + logger.debug( - f"[MIDDLEWARE INIT] Self defined registry class: {self_defined_registry_cls.__name__}" + f"[{handler_type} INIT] Self defined {handler_type.lower()} class: {self_defined_handler_cls.__name__}" ) logger.debug( - f"[MIDDLEWARE INIT] Inject dependencies for external object: {self_defined_registry_cls.__name__}" + f"[{handler_type} INIT] Inject dependencies for external object: {self_defined_handler_cls.__name__}" ) self.app_context.inject_dependencies_for_external_object( - self_defined_registry_cls + self_defined_handler_cls ) - registry = self_defined_registry_cls() + return self_defined_handler_cls - middleware_classes = registry.get_middleware_classes() + def __init_middlewares(self) -> None: + handler_type = MiddlewareRegistry.__name__ + logger.debug(f"[{handler_type} INIT] Initialize middlewares...") + registry_cls = self._init_external_handler(MiddlewareRegistry) + if registry_cls is None: + return + + registry: MiddlewareRegistry = registry_cls() + middleware_classes: list[Type[Middleware]] = registry.get_middleware_classes() for middleware_class in middleware_classes: logger.debug( - f"[MIDDLEWARE INIT] Inject dependencies for middleware: {middleware_class.__name__}" + f"[{handler_type} INIT] Inject dependencies for middleware: {middleware_class.__name__}" ) self.app_context.inject_dependencies_for_external_object(middleware_class) registry.apply_middlewares(self.fastapi) - logger.debug("[MIDDLEWARE INIT] Middlewares initialized") + logger.debug(f"[{handler_type} INIT] Middlewares initialized") + + def __init_graceful_shutdown(self) -> None: + handler_type = GracefulShutdownHandler.__name__ + logger.debug(f"[{handler_type} INIT] Initialize graceful shutdown...") + handler_cls: Optional[Type[GracefulShutdownHandler]] = self._init_external_handler(GracefulShutdownHandler) + if handler_cls is None: + return + + # Get shutdown configuration + shutdown_config = self.app_config.shutdown_config + logger.debug(f"[{handler_type} INIT] Shutdown timeout: {shutdown_config.timeout_seconds}s, enabled: {shutdown_config.enabled}") + + # Initialize handler with timeout configuration + self.shutdown_handler = handler_cls( + timeout_seconds=shutdown_config.timeout_seconds, + timeout_enabled=shutdown_config.enabled + ) + logger.debug(f"[{handler_type} INIT] Graceful shutdown initialized") def __configure_uvicorn_logging(self): """Configure Uvicorn to use Loguru instead of default logging.""" @@ -300,7 +350,14 @@ def run(self) -> None: self.__init_app() self.__init_controllers() self.__init_middlewares() + self.__init_graceful_shutdown() if self.app_config.server_config.enabled: self.__run_server() finally: + # Handle component lifecycle destruction self._handle_singleton_components_life_cycle(ComponentLifeCycle.Destruction) + # Handle graceful shutdown completion + if self.shutdown_handler: + self.shutdown_handler.complete_shutdown() + + \ No newline at end of file diff --git a/py_spring_core/core/entities/bean_collection.py b/py_spring_core/core/entities/bean_collection/bean_collection.py similarity index 100% rename from py_spring_core/core/entities/bean_collection.py rename to py_spring_core/core/entities/bean_collection/bean_collection.py diff --git a/py_spring_core/core/entities/component.py b/py_spring_core/core/entities/component/component.py similarity index 100% rename from py_spring_core/core/entities/component.py rename to py_spring_core/core/entities/component/component.py diff --git a/py_spring_core/core/entities/entity_provider.py b/py_spring_core/core/entities/entity_provider/entity_provider.py similarity index 89% rename from py_spring_core/core/entities/entity_provider.py rename to py_spring_core/core/entities/entity_provider/entity_provider.py index 79c21a8..bdaa29c 100644 --- a/py_spring_core/core/entities/entity_provider.py +++ b/py_spring_core/core/entities/entity_provider/entity_provider.py @@ -2,8 +2,8 @@ from typing import Any, Optional, Type from py_spring_core.core.application.commons import AppEntities -from py_spring_core.core.entities.bean_collection import BeanCollection -from py_spring_core.core.entities.component import Component +from py_spring_core.core.entities.bean_collection.bean_collection import BeanCollection +from py_spring_core.core.entities.component.component import Component from py_spring_core.core.entities.controllers.rest_controller import RestController from py_spring_core.core.entities.properties.properties import Properties diff --git a/py_spring_core/core/interfaces/graceful_shutdown_handler.py b/py_spring_core/core/interfaces/graceful_shutdown_handler.py new file mode 100644 index 0000000..504026f --- /dev/null +++ b/py_spring_core/core/interfaces/graceful_shutdown_handler.py @@ -0,0 +1,143 @@ + + +from abc import ABC, abstractmethod +from enum import Enum, auto +import os +import signal +import threading +import time +from types import FrameType +from typing import Optional + +from loguru import logger + +from py_spring_core.core.interfaces.single_inheritance_required import SingleInheritanceRequired + +class ShutdownType(Enum): + MANUAL = auto() # e.g., Ctrl+C + SIGTERM = auto() # e.g., docker stop, systemctl stop + TIMEOUT = auto() # e.g., shutdown triggered by time constraint + ERROR = auto() # e.g., unrecoverable fault + UNKNOWN = auto() + + +class GracefulShutdownHandler(SingleInheritanceRequired, ABC): + """ + A mixin class that provides a method to handle graceful shutdown. + """ + + def __init__(self, timeout_seconds: float, timeout_enabled: bool) -> None: + self._shutdown_event = threading.Event() + self._shutdown_type: Optional[ShutdownType] = None + self._timeout_seconds = timeout_seconds + self._timeout_enabled = timeout_enabled + self._timeout_timer: Optional[threading.Timer] = None + self._shutdown_start_time: Optional[float] = None + + signal.signal(signal.SIGINT, self._handle_sigint) + signal.signal(signal.SIGTERM, self._handle_sigterm) + + def _handle_sigint(self, signum: int, frame: Optional[FrameType]) -> None: + try: + # Check if shutdown is already in progress to prevent duplicate execution + if self._shutdown_event.is_set(): + logger.debug("[Signal] SIGINT ignored - shutdown already in progress") + return + + logger.info("[Signal] SIGINT received") + self._shutdown_type = ShutdownType.MANUAL + self._shutdown_event.set() + self._start_shutdown_timer() + self.on_shutdown(ShutdownType.MANUAL) + except Exception as error: + self.on_error(error) + + def _handle_sigterm(self, signum: int, frame: Optional[FrameType]) -> None: + try: + # Check if shutdown is already in progress to prevent duplicate execution + if self._shutdown_event.is_set(): + logger.debug("[Signal] SIGTERM ignored - shutdown already in progress") + return + + logger.info("[Signal] SIGTERM received") + self._shutdown_type = ShutdownType.SIGTERM + self._shutdown_event.set() + self._start_shutdown_timer() + self.on_shutdown(ShutdownType.SIGTERM) + except Exception as error: + self.on_error(error) + + def _start_shutdown_timer(self) -> None: + """Start the shutdown timeout timer if enabled.""" + if not self._timeout_enabled: + return + + self._shutdown_start_time = time.time() + logger.info(f"[Shutdown Timer] Starting shutdown timer for {self._timeout_seconds} seconds") + + self._timeout_timer = threading.Timer(self._timeout_seconds, self._handle_timeout) + self._timeout_timer.daemon = True + self._timeout_timer.start() + + def _handle_timeout(self) -> None: + """Handle shutdown timeout.""" + try: + if not self._shutdown_event.is_set(): + return # Shutdown was not initiated, ignore timeout + + logger.info(f"[Shutdown Timer] Shutdown timeout reached after {self._timeout_seconds} seconds") + self._shutdown_type = ShutdownType.TIMEOUT + self.on_timeout() + except Exception as error: + self.on_error(error) + finally: + logger.critical(f"[Shutdown Timer] Timer exited with grace period of {self._timeout_seconds} seconds, exiting application") + os._exit(0) + + def complete_shutdown(self) -> None: + """Mark shutdown as complete and cancel timeout timer.""" + if self._timeout_timer and self._timeout_timer.is_alive(): + self._timeout_timer.cancel() + + if self._shutdown_start_time: + elapsed = time.time() - self._shutdown_start_time + logger.success(f"[Shutdown Timer] Shutdown completed successfully in {elapsed:.2f} seconds") + + def is_shutdown(self) -> bool: + return self._shutdown_event.is_set() + + def get_type(self) -> Optional[ShutdownType]: + return self._shutdown_type + + def get_timeout_seconds(self) -> float: + return self._timeout_seconds + + def is_timeout_enabled(self) -> bool: + return self._timeout_enabled + + def get_shutdown_elapsed_time(self) -> Optional[float]: + """Get the elapsed time since shutdown started.""" + if self._shutdown_start_time is None: + return None + return time.time() - self._shutdown_start_time + + @abstractmethod + def on_shutdown(self, shutdown_type: ShutdownType) -> None: + """ + Handle shutdown. + """ + ... + + @abstractmethod + def on_timeout(self) -> None: + """ + Handle timeout. + """ + ... + + @abstractmethod + def on_error(self, error: Exception) -> None: + """ + Handle error. + """ + ... \ No newline at end of file diff --git a/py_spring_core/event/application_event_handler_registry.py b/py_spring_core/event/application_event_handler_registry.py index 7d5d0fb..14aeaee 100644 --- a/py_spring_core/event/application_event_handler_registry.py +++ b/py_spring_core/event/application_event_handler_registry.py @@ -4,7 +4,7 @@ from loguru import logger from pydantic import BaseModel -from py_spring_core.core.entities.component import Component +from py_spring_core.core.entities.component.component import Component from py_spring_core.core.interfaces.application_context_required import ( ApplicationContextRequired, ) diff --git a/py_spring_core/event/application_event_publisher.py b/py_spring_core/event/application_event_publisher.py index 3012d75..dccbe88 100644 --- a/py_spring_core/event/application_event_publisher.py +++ b/py_spring_core/event/application_event_publisher.py @@ -1,6 +1,6 @@ from typing import TypeVar -from py_spring_core.core.entities.component import Component +from py_spring_core.core.entities.component.component import Component from py_spring_core.event.application_event_handler_registry import ( ApplicationEvent, ApplicationEventHandlerRegistry, diff --git a/tests/test_application_config.py b/tests/test_application_config.py new file mode 100644 index 0000000..248f883 --- /dev/null +++ b/tests/test_application_config.py @@ -0,0 +1,146 @@ +import pytest +from pydantic import ValidationError + +from py_spring_core.core.application.application_config import ( + ApplicationConfig, + ShutdownConfig, + ServerConfig, + LoguruConfig, +) + + +class TestShutdownConfig: + """Test suite for the ShutdownConfig class.""" + + def test_shutdown_config_defaults(self): + """Test that ShutdownConfig has correct default values.""" + config = ShutdownConfig() + + assert config.timeout_seconds == 30.0 + assert config.enabled is True + + def test_shutdown_config_custom_values(self): + """Test ShutdownConfig with custom values.""" + config = ShutdownConfig(timeout_seconds=60.0, enabled=False) + + assert config.timeout_seconds == 60.0 + assert config.enabled is False + + def test_shutdown_config_validation(self): + """Test ShutdownConfig validation.""" + # Valid timeout values + config = ShutdownConfig(timeout_seconds=0.1) + assert config.timeout_seconds == 0.1 + + config = ShutdownConfig(timeout_seconds=1000.0) + assert config.timeout_seconds == 1000.0 + + def test_shutdown_config_type_validation(self): + """Test that ShutdownConfig validates types correctly.""" + # Should accept float or int for timeout_seconds + config = ShutdownConfig(timeout_seconds=30) + assert config.timeout_seconds == 30.0 + + config = ShutdownConfig(timeout_seconds=30.5) + assert config.timeout_seconds == 30.5 + + # Should accept boolean for enabled + config = ShutdownConfig(enabled=True) + assert config.enabled is True + + config = ShutdownConfig(enabled=False) + assert config.enabled is False + + def test_shutdown_config_serialization(self): + """Test that ShutdownConfig can be serialized/deserialized.""" + config = ShutdownConfig(timeout_seconds=45.0, enabled=True) + + # Test model dump + config_dict = config.model_dump() + expected = {"timeout_seconds": 45.0, "enabled": True} + assert config_dict == expected + + # Test reconstruction from dict + new_config = ShutdownConfig(**config_dict) + assert new_config.timeout_seconds == 45.0 + assert new_config.enabled is True + + +class TestApplicationConfigWithShutdown: + """Test suite for ApplicationConfig with shutdown configuration.""" + + def test_application_config_with_default_shutdown(self): + """Test that ApplicationConfig includes default shutdown config.""" + config = ApplicationConfig( + app_src_target_dir="./src", + server_config=ServerConfig(host="localhost", port=8000), + properties_file_path="./app.properties", + loguru_config=LoguruConfig(log_file_path="./logs/app.log") + ) + + # Should have default shutdown config + assert config.shutdown_config is not None + assert config.shutdown_config.timeout_seconds == 30.0 + assert config.shutdown_config.enabled is True + + def test_application_config_with_custom_shutdown(self): + """Test ApplicationConfig with custom shutdown configuration.""" + custom_shutdown = ShutdownConfig(timeout_seconds=60.0, enabled=False) + + config = ApplicationConfig( + app_src_target_dir="./src", + server_config=ServerConfig(host="localhost", port=8000), + properties_file_path="./app.properties", + loguru_config=LoguruConfig(log_file_path="./logs/app.log"), + shutdown_config=custom_shutdown + ) + + assert config.shutdown_config.timeout_seconds == 60.0 + assert config.shutdown_config.enabled is False + + def test_application_config_serialization_with_shutdown(self): + """Test ApplicationConfig serialization includes shutdown config.""" + config = ApplicationConfig( + app_src_target_dir="./src", + server_config=ServerConfig(host="localhost", port=8000), + properties_file_path="./app.properties", + loguru_config=LoguruConfig(log_file_path="./logs/app.log"), + shutdown_config=ShutdownConfig(timeout_seconds=45.0, enabled=True) + ) + + config_dict = config.model_dump() + + assert "shutdown_config" in config_dict + assert config_dict["shutdown_config"]["timeout_seconds"] == 45.0 + assert config_dict["shutdown_config"]["enabled"] is True + + def test_application_config_from_dict_with_shutdown(self): + """Test ApplicationConfig reconstruction from dict with shutdown config.""" + config_dict = { + "app_src_target_dir": "./src", + "server_config": {"host": "localhost", "port": 8000}, + "properties_file_path": "./app.properties", + "loguru_config": {"log_file_path": "./logs/app.log", "log_level": "INFO"}, + "shutdown_config": {"timeout_seconds": 25.0, "enabled": False} + } + + config = ApplicationConfig(**config_dict) + + assert config.app_src_target_dir == "./src" + assert config.shutdown_config.timeout_seconds == 25.0 + assert config.shutdown_config.enabled is False + + def test_application_config_without_shutdown_config_in_dict(self): + """Test that ApplicationConfig uses default shutdown when not provided.""" + config_dict = { + "app_src_target_dir": "./src", + "server_config": {"host": "localhost", "port": 8000}, + "properties_file_path": "./app.properties", + "loguru_config": {"log_file_path": "./logs/app.log", "log_level": "INFO"} + } + + config = ApplicationConfig(**config_dict) + + # Should use default shutdown config + assert config.shutdown_config.timeout_seconds == 30.0 + assert config.shutdown_config.enabled is True \ No newline at end of file diff --git a/tests/test_application_context.py b/tests/test_application_context.py index 52eaa7e..9da4569 100644 --- a/tests/test_application_context.py +++ b/tests/test_application_context.py @@ -5,8 +5,8 @@ ApplicationContext, ApplicationContextConfig, ) -from py_spring_core.core.entities.bean_collection import BeanCollection -from py_spring_core.core.entities.component import Component +from py_spring_core.core.entities.bean_collection.bean_collection import BeanCollection +from py_spring_core.core.entities.component.component import Component from py_spring_core.core.entities.controllers.rest_controller import RestController from py_spring_core.core.entities.properties.properties import Properties diff --git a/tests/test_bean_collection.py b/tests/test_bean_collection.py index 6a35c21..d8f5e9e 100644 --- a/tests/test_bean_collection.py +++ b/tests/test_bean_collection.py @@ -1,7 +1,7 @@ import pytest -from py_spring_core.core.entities.bean_collection import BeanCollection, BeanView -from py_spring_core.core.entities.component import Component +from py_spring_core.core.entities.bean_collection.bean_collection import BeanCollection, BeanView +from py_spring_core.core.entities.component.component import Component class TestBeanView: diff --git a/tests/test_component_features.py b/tests/test_component_features.py index 4221ac9..1aa26b1 100644 --- a/tests/test_component_features.py +++ b/tests/test_component_features.py @@ -7,7 +7,7 @@ ApplicationContext, ApplicationContextConfig, ) -from py_spring_core.core.entities.component import Component, ComponentScope +from py_spring_core.core.entities.component.component import Component, ComponentScope class TestComponentFeatures: diff --git a/tests/test_entity_provider.py b/tests/test_entity_provider.py index bda16b8..3279053 100644 --- a/tests/test_entity_provider.py +++ b/tests/test_entity_provider.py @@ -8,8 +8,8 @@ from py_spring_core.core.application.context.application_context_config import ( ApplicationContextConfig, ) -from py_spring_core.core.entities.component import Component -from py_spring_core.core.entities.entity_provider import EntityProvider +from py_spring_core.core.entities.component.component import Component +from py_spring_core.core.entities.entity_provider.entity_provider import EntityProvider class TestComponent(Component): ... diff --git a/tests/test_framework_utils.py b/tests/test_framework_utils.py new file mode 100644 index 0000000..8c83ce6 --- /dev/null +++ b/tests/test_framework_utils.py @@ -0,0 +1,226 @@ +import pytest +from abc import ABC, abstractmethod +from typing import Type, Any + +from py_spring_core.core.utils import get_unimplemented_abstract_methods + + +class TestFrameworkUtils: + """Test suite for framework utility functions.""" + + def test_get_unimplemented_abstract_methods_with_concrete_class(self): + """Test that fully implemented classes return empty set.""" + + class AbstractBase(ABC): + @abstractmethod + def method_a(self) -> None: + pass + + @abstractmethod + def method_b(self) -> str: + pass + + class ConcreteImpl(AbstractBase): + def method_a(self) -> None: + pass + + def method_b(self) -> str: + return "implemented" + + unimplemented = get_unimplemented_abstract_methods(ConcreteImpl) + assert unimplemented == set() + + def test_get_unimplemented_abstract_methods_with_partial_implementation(self): + """Test that partially implemented classes return missing methods.""" + + class AbstractBase(ABC): + @abstractmethod + def method_a(self) -> None: + pass + + @abstractmethod + def method_b(self) -> str: + pass + + @abstractmethod + def method_c(self) -> int: + pass + + class PartialImpl(AbstractBase): + def method_a(self) -> None: + pass + + # method_b and method_c are not implemented + + unimplemented = get_unimplemented_abstract_methods(PartialImpl) + assert unimplemented == {"method_b", "method_c"} + + def test_get_unimplemented_abstract_methods_with_no_implementation(self): + """Test that classes with no implementations return all abstract methods.""" + + class AbstractBase(ABC): + @abstractmethod + def method_a(self) -> None: + pass + + @abstractmethod + def method_b(self) -> str: + pass + + class NoImpl(AbstractBase): + # No methods implemented + pass + + unimplemented = get_unimplemented_abstract_methods(NoImpl) + assert unimplemented == {"method_a", "method_b"} + + def test_get_unimplemented_abstract_methods_with_multiple_inheritance(self): + """Test with multiple inheritance from abstract classes.""" + + class AbstractA(ABC): + @abstractmethod + def method_a(self) -> None: + pass + + class AbstractB(ABC): + @abstractmethod + def method_b(self) -> str: + pass + + class MultipleInheritance(AbstractA, AbstractB): + def method_a(self) -> None: + pass + # method_b is not implemented + + unimplemented = get_unimplemented_abstract_methods(MultipleInheritance) + assert unimplemented == {"method_b"} + + def test_get_unimplemented_abstract_methods_with_inheritance_chain(self): + """Test with inheritance chain where parent implements some methods.""" + + class AbstractBase(ABC): + @abstractmethod + def method_a(self) -> None: + pass + + @abstractmethod + def method_b(self) -> str: + pass + + @abstractmethod + def method_c(self) -> int: + pass + + class PartialParent(AbstractBase): + def method_a(self) -> None: + pass + # method_b and method_c still abstract + + class ChildImpl(PartialParent): + def method_b(self) -> str: + return "implemented" + # method_c still not implemented + + unimplemented = get_unimplemented_abstract_methods(ChildImpl) + assert unimplemented == {"method_c"} + + def test_get_unimplemented_abstract_methods_with_no_abstract_methods(self): + """Test with class that has no abstract methods.""" + + class NonAbstractBase(ABC): + def regular_method(self) -> None: + pass + + class RegularClass(NonAbstractBase): + def another_method(self) -> str: + return "normal" + + unimplemented = get_unimplemented_abstract_methods(RegularClass) + assert unimplemented == set() + + def test_get_unimplemented_abstract_methods_type_error_non_class(self): + """Test that function raises TypeError for non-class types.""" + + with pytest.raises(TypeError, match="Expected a class type"): + get_unimplemented_abstract_methods("not a class") # type: ignore + + with pytest.raises(TypeError, match="Expected a class type"): + get_unimplemented_abstract_methods(42) # type: ignore + + def test_get_unimplemented_abstract_methods_type_error_non_abc(self): + """Test that function raises TypeError for non-ABC classes.""" + + class RegularClass: + def some_method(self) -> None: + pass + + with pytest.raises(TypeError, match="Expected a subclass of abc.ABC"): + get_unimplemented_abstract_methods(RegularClass) + + def test_get_unimplemented_abstract_methods_with_property_abstracts(self): + """Test with abstract properties.""" + + class AbstractWithProperty(ABC): + @property + @abstractmethod + def abstract_property(self) -> str: + pass + + @abstractmethod + def abstract_method(self) -> None: + pass + + class PartialPropertyImpl(AbstractWithProperty): + @property + def abstract_property(self) -> str: + return "implemented" + # abstract_method not implemented + + unimplemented = get_unimplemented_abstract_methods(PartialPropertyImpl) + # Properties might be included in the abstract methods set, so we check that abstract_method is there + # and abstract_property is not (since it's implemented) + assert "abstract_method" in unimplemented + assert len([method for method in unimplemented if "abstract_property" not in method]) >= 1 + + def test_get_unimplemented_abstract_methods_with_staticmethod_classmethod(self): + """Test with abstract static and class methods.""" + + class AbstractWithMethods(ABC): + @staticmethod + @abstractmethod + def abstract_static() -> str: + pass + + @classmethod + @abstractmethod + def abstract_class(cls) -> str: + pass + + @abstractmethod + def abstract_instance(self) -> None: + pass + + class PartialMethodImpl(AbstractWithMethods): + @staticmethod + def abstract_static() -> str: + return "static implemented" + + # abstract_class and abstract_instance not implemented + + unimplemented = get_unimplemented_abstract_methods(PartialMethodImpl) + assert unimplemented == {"abstract_class", "abstract_instance"} + + def test_get_unimplemented_abstract_methods_real_world_example(self): + """Test with a real-world like example similar to GracefulShutdownHandler.""" + + from py_spring_core.core.interfaces.graceful_shutdown_handler import GracefulShutdownHandler + + class IncompleteShutdownHandler(GracefulShutdownHandler): + def on_shutdown(self, shutdown_type) -> None: + pass + # on_timeout and on_error not implemented + + unimplemented = get_unimplemented_abstract_methods(IncompleteShutdownHandler) + assert "on_timeout" in unimplemented + assert "on_error" in unimplemented + assert "on_shutdown" not in unimplemented \ No newline at end of file diff --git a/tests/test_graceful_shutdown_handler.py b/tests/test_graceful_shutdown_handler.py new file mode 100644 index 0000000..76fb29e --- /dev/null +++ b/tests/test_graceful_shutdown_handler.py @@ -0,0 +1,340 @@ +import signal +import threading +import time +from typing import Optional +from unittest.mock import MagicMock, patch +import pytest + +from py_spring_core.core.interfaces.graceful_shutdown_handler import ( + GracefulShutdownHandler, + ShutdownType, +) + + +class TestGracefulShutdownHandler: + """Test suite for the graceful shutdown handler functionality.""" + + def setup_method(self): + """Reset any signal handlers before each test.""" + signal.signal(signal.SIGINT, signal.SIG_DFL) + signal.signal(signal.SIGTERM, signal.SIG_DFL) + + def test_graceful_shutdown_handler_interface(self): + """Test that GracefulShutdownHandler enforces abstract methods.""" + with pytest.raises(TypeError, match="Can't instantiate abstract class"): + # Should not be able to instantiate abstract class directly + GracefulShutdownHandler(timeout_seconds=30.0, timeout_enabled=True) # type: ignore + + def test_concrete_implementation_creation(self): + """Test that a concrete implementation can be created successfully.""" + + class TestShutdownHandler(GracefulShutdownHandler): + def __init__(self, timeout_seconds: float, timeout_enabled: bool): + self.shutdown_calls = [] + self.timeout_calls = [] + self.error_calls = [] + super().__init__(timeout_seconds, timeout_enabled) + + def on_shutdown(self, shutdown_type: ShutdownType) -> None: + self.shutdown_calls.append(shutdown_type) + + def on_timeout(self) -> None: + self.timeout_calls.append("timeout") + + def on_error(self, error: Exception) -> None: + self.error_calls.append(error) + + # Should be able to create concrete implementation + handler = TestShutdownHandler(timeout_seconds=5.0, timeout_enabled=True) + assert handler.get_timeout_seconds() == 5.0 + assert handler.is_timeout_enabled() is True + assert handler.is_shutdown() is False + assert handler.get_type() is None + + def test_sigint_handling(self): + """Test SIGINT signal handling.""" + + class TestShutdownHandler(GracefulShutdownHandler): + def __init__(self, timeout_seconds: float, timeout_enabled: bool): + self.shutdown_calls = [] + self.timeout_calls = [] + self.error_calls = [] + super().__init__(timeout_seconds, timeout_enabled) + + def on_shutdown(self, shutdown_type: ShutdownType) -> None: + self.shutdown_calls.append(shutdown_type) + + def on_timeout(self) -> None: + self.timeout_calls.append("timeout") + + def on_error(self, error: Exception) -> None: + self.error_calls.append(error) + + handler = TestShutdownHandler(timeout_seconds=1.0, timeout_enabled=False) + + # Simulate SIGINT + handler._handle_sigint(signal.SIGINT, None) + + assert handler.is_shutdown() is True + assert handler.get_type() == ShutdownType.MANUAL + assert len(handler.shutdown_calls) == 1 + assert handler.shutdown_calls[0] == ShutdownType.MANUAL + + def test_sigterm_handling(self): + """Test SIGTERM signal handling.""" + + class TestShutdownHandler(GracefulShutdownHandler): + def __init__(self, timeout_seconds: float, timeout_enabled: bool): + self.shutdown_calls = [] + self.timeout_calls = [] + self.error_calls = [] + super().__init__(timeout_seconds, timeout_enabled) + + def on_shutdown(self, shutdown_type: ShutdownType) -> None: + self.shutdown_calls.append(shutdown_type) + + def on_timeout(self) -> None: + self.timeout_calls.append("timeout") + + def on_error(self, error: Exception) -> None: + self.error_calls.append(error) + + handler = TestShutdownHandler(timeout_seconds=1.0, timeout_enabled=False) + + # Simulate SIGTERM + handler._handle_sigterm(signal.SIGTERM, None) + + assert handler.is_shutdown() is True + assert handler.get_type() == ShutdownType.SIGTERM + assert len(handler.shutdown_calls) == 1 + assert handler.shutdown_calls[0] == ShutdownType.SIGTERM + + def test_duplicate_signal_handling(self): + """Test that duplicate signals are ignored.""" + + class TestShutdownHandler(GracefulShutdownHandler): + def __init__(self, timeout_seconds: float, timeout_enabled: bool): + self.shutdown_calls = [] + self.timeout_calls = [] + self.error_calls = [] + super().__init__(timeout_seconds, timeout_enabled) + + def on_shutdown(self, shutdown_type: ShutdownType) -> None: + self.shutdown_calls.append(shutdown_type) + + def on_timeout(self) -> None: + self.timeout_calls.append("timeout") + + def on_error(self, error: Exception) -> None: + self.error_calls.append(error) + + handler = TestShutdownHandler(timeout_seconds=1.0, timeout_enabled=False) + + # First signal + handler._handle_sigint(signal.SIGINT, None) + assert len(handler.shutdown_calls) == 1 + + # Second signal should be ignored + handler._handle_sigint(signal.SIGINT, None) + assert len(handler.shutdown_calls) == 1 # Still only one call + + @patch('os._exit') + def test_timeout_functionality(self, mock_exit): + """Test shutdown timeout functionality.""" + + class TestShutdownHandler(GracefulShutdownHandler): + def __init__(self, timeout_seconds: float, timeout_enabled: bool): + self.shutdown_calls = [] + self.timeout_calls = [] + self.error_calls = [] + super().__init__(timeout_seconds, timeout_enabled) + + def on_shutdown(self, shutdown_type: ShutdownType) -> None: + self.shutdown_calls.append(shutdown_type) + + def on_timeout(self) -> None: + self.timeout_calls.append("timeout") + + def on_error(self, error: Exception) -> None: + self.error_calls.append(error) + + # Test with timeout enabled + handler = TestShutdownHandler(timeout_seconds=0.1, timeout_enabled=True) + + # Trigger shutdown + handler._handle_sigint(signal.SIGINT, None) + + # Wait for timeout to trigger + time.sleep(0.2) + + # Should have called timeout and os._exit + assert len(handler.timeout_calls) == 1 + mock_exit.assert_called_once_with(0) + + def test_timeout_disabled(self): + """Test that timeout doesn't trigger when disabled.""" + + class TestShutdownHandler(GracefulShutdownHandler): + def __init__(self, timeout_seconds: float, timeout_enabled: bool): + self.shutdown_calls = [] + self.timeout_calls = [] + self.error_calls = [] + super().__init__(timeout_seconds, timeout_enabled) + + def on_shutdown(self, shutdown_type: ShutdownType) -> None: + self.shutdown_calls.append(shutdown_type) + + def on_timeout(self) -> None: + self.timeout_calls.append("timeout") + + def on_error(self, error: Exception) -> None: + self.error_calls.append(error) + + # Test with timeout disabled + handler = TestShutdownHandler(timeout_seconds=0.1, timeout_enabled=False) + + # Trigger shutdown + handler._handle_sigint(signal.SIGINT, None) + + # Wait for potential timeout + time.sleep(0.2) + + # Should not have called timeout + assert len(handler.timeout_calls) == 0 + + def test_complete_shutdown(self): + """Test shutdown completion functionality.""" + + class TestShutdownHandler(GracefulShutdownHandler): + def __init__(self, timeout_seconds: float, timeout_enabled: bool): + self.shutdown_calls = [] + self.timeout_calls = [] + self.error_calls = [] + super().__init__(timeout_seconds, timeout_enabled) + + def on_shutdown(self, shutdown_type: ShutdownType) -> None: + self.shutdown_calls.append(shutdown_type) + + def on_timeout(self) -> None: + self.timeout_calls.append("timeout") + + def on_error(self, error: Exception) -> None: + self.error_calls.append(error) + + handler = TestShutdownHandler(timeout_seconds=10.0, timeout_enabled=True) + + # Trigger shutdown + handler._handle_sigint(signal.SIGINT, None) + + # Complete shutdown before timeout + handler.complete_shutdown() + + # Wait to ensure timeout doesn't trigger + time.sleep(0.1) + + # Should not have called timeout since shutdown was completed + assert len(handler.timeout_calls) == 0 + + def test_error_handling_in_signals(self): + """Test error handling in signal handlers.""" + + class TestShutdownHandler(GracefulShutdownHandler): + def __init__(self, timeout_seconds: float, timeout_enabled: bool): + self.shutdown_calls = [] + self.timeout_calls = [] + self.error_calls = [] + super().__init__(timeout_seconds, timeout_enabled) + + def on_shutdown(self, shutdown_type: ShutdownType) -> None: + raise RuntimeError("Test error in shutdown") + + def on_timeout(self) -> None: + self.timeout_calls.append("timeout") + + def on_error(self, error: Exception) -> None: + self.error_calls.append(error) + + handler = TestShutdownHandler(timeout_seconds=1.0, timeout_enabled=False) + + # Signal should trigger error handling + handler._handle_sigint(signal.SIGINT, None) + + assert len(handler.error_calls) == 1 + assert isinstance(handler.error_calls[0], RuntimeError) + + def test_shutdown_elapsed_time(self): + """Test shutdown elapsed time tracking.""" + + class TestShutdownHandler(GracefulShutdownHandler): + def __init__(self, timeout_seconds: float, timeout_enabled: bool): + super().__init__(timeout_seconds, timeout_enabled) + + def on_shutdown(self, shutdown_type: ShutdownType) -> None: + pass + + def on_timeout(self) -> None: + pass + + def on_error(self, error: Exception) -> None: + pass + + # Use a longer timeout to ensure timer starts + handler = TestShutdownHandler(timeout_seconds=30.0, timeout_enabled=True) + + # Before shutdown + assert handler.get_shutdown_elapsed_time() is None + + # Trigger shutdown - this should start the timer and set _shutdown_start_time + handler._handle_sigint(signal.SIGINT, None) + + # Small delay + time.sleep(0.1) + + # Should have elapsed time + elapsed = handler.get_shutdown_elapsed_time() + assert elapsed is not None + assert elapsed >= 0.1 + + def test_shutdown_types_enum(self): + """Test ShutdownType enum values.""" + assert ShutdownType.MANUAL is not None + assert ShutdownType.SIGTERM is not None + assert ShutdownType.TIMEOUT is not None + assert ShutdownType.ERROR is not None + assert ShutdownType.UNKNOWN is not None + + # Ensure all types are unique + types = [ShutdownType.MANUAL, ShutdownType.SIGTERM, ShutdownType.TIMEOUT, + ShutdownType.ERROR, ShutdownType.UNKNOWN] + assert len(set(types)) == len(types) + + @patch('os._exit') + def test_timeout_force_exit(self, mock_exit): + """Test that timeout eventually forces exit.""" + + class TestShutdownHandler(GracefulShutdownHandler): + def __init__(self, timeout_seconds: float, timeout_enabled: bool): + self.shutdown_calls = [] + self.timeout_calls = [] + self.error_calls = [] + super().__init__(timeout_seconds, timeout_enabled) + + def on_shutdown(self, shutdown_type: ShutdownType) -> None: + self.shutdown_calls.append(shutdown_type) + + def on_timeout(self) -> None: + self.timeout_calls.append("timeout") + + def on_error(self, error: Exception) -> None: + self.error_calls.append(error) + + handler = TestShutdownHandler(timeout_seconds=0.1, timeout_enabled=True) + + # Trigger shutdown + handler._handle_sigint(signal.SIGINT, None) + + # Wait for timeout to trigger + time.sleep(0.2) + + # Should have called os._exit + mock_exit.assert_called_once_with(0) \ No newline at end of file From e61c6876d8c13fa3779ea0dbd06d215a426bbc04 Mon Sep 17 00:00:00 2001 From: William Chen Date: Mon, 21 Jul 2025 16:15:05 +0800 Subject: [PATCH 37/42] Add GracefulShutdownHandler to module exports in __init__.py --- py_spring_core/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/py_spring_core/__init__.py b/py_spring_core/__init__.py index 76d0d00..16825f6 100644 --- a/py_spring_core/__init__.py +++ b/py_spring_core/__init__.py @@ -18,6 +18,7 @@ from py_spring_core.core.interfaces.application_context_required import ( ApplicationContextRequired, ) +from py_spring_core.core.interfaces.graceful_shutdown_handler import GracefulShutdownHandler from py_spring_core.event.application_event_handler_registry import EventListener from py_spring_core.event.application_event_publisher import ApplicationEventPublisher from py_spring_core.event.commons import ApplicationEvent @@ -43,4 +44,5 @@ "EventListener", "Middleware", "MiddlewareRegistry", + "GracefulShutdownHandler", ] From 81437195f142ebfb43dd6c3e4a9ffc4f89ff5aac Mon Sep 17 00:00:00 2001 From: William Chen Date: Mon, 21 Jul 2025 16:16:01 +0800 Subject: [PATCH 38/42] Enhance GracefulShutdownHandler import by adding ShutdownType to module exports in __init__.py --- py_spring_core/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/py_spring_core/__init__.py b/py_spring_core/__init__.py index 16825f6..ede435a 100644 --- a/py_spring_core/__init__.py +++ b/py_spring_core/__init__.py @@ -18,7 +18,7 @@ from py_spring_core.core.interfaces.application_context_required import ( ApplicationContextRequired, ) -from py_spring_core.core.interfaces.graceful_shutdown_handler import GracefulShutdownHandler +from py_spring_core.core.interfaces.graceful_shutdown_handler import GracefulShutdownHandler, ShutdownType from py_spring_core.event.application_event_handler_registry import EventListener from py_spring_core.event.application_event_publisher import ApplicationEventPublisher from py_spring_core.event.commons import ApplicationEvent @@ -45,4 +45,5 @@ "Middleware", "MiddlewareRegistry", "GracefulShutdownHandler", + "ShutdownType", ] From 9854de00fa2eba669fe2184b061dc2c7fe5e2769 Mon Sep 17 00:00:00 2001 From: William Chen <86595028+NFUChen@users.noreply.github.com> Date: Tue, 22 Jul 2025 20:52:33 +0800 Subject: [PATCH 39/42] Feature: Middleware Enhancement (#20) --- py_spring_core/__init__.py | 7 +- .../core/application/py_spring_application.py | 20 +- .../middlewares/middleware_registry.py | 164 ++++++- py_spring_core/exception_handler/decorator.py | 18 + .../exception_handler_registry.py | 25 + tests/test_middleware.py | 436 ++++++++++++++---- 6 files changed, 559 insertions(+), 111 deletions(-) create mode 100644 py_spring_core/exception_handler/decorator.py create mode 100644 py_spring_core/exception_handler/exception_handler_registry.py diff --git a/py_spring_core/__init__.py b/py_spring_core/__init__.py index ede435a..412ce0a 100644 --- a/py_spring_core/__init__.py +++ b/py_spring_core/__init__.py @@ -12,7 +12,7 @@ from py_spring_core.core.entities.entity_provider.entity_provider import EntityProvider from py_spring_core.core.entities.middlewares.middleware import Middleware from py_spring_core.core.entities.middlewares.middleware_registry import ( - MiddlewareRegistry, + MiddlewareRegistry, MiddlewareConfiguration ) from py_spring_core.core.entities.properties.properties import Properties from py_spring_core.core.interfaces.application_context_required import ( @@ -44,6 +44,7 @@ "EventListener", "Middleware", "MiddlewareRegistry", + "MiddlewareConfiguration", "GracefulShutdownHandler", - "ShutdownType", -] + "ShutdownType" +] \ No newline at end of file diff --git a/py_spring_core/core/application/py_spring_application.py b/py_spring_core/core/application/py_spring_application.py index cc39d1c..cb1dd64 100644 --- a/py_spring_core/core/application/py_spring_application.py +++ b/py_spring_core/core/application/py_spring_application.py @@ -1,6 +1,6 @@ import logging import os -from typing import Any, Callable, Iterable, Optional, Type +from typing import Any, Callable, Iterable, Optional, Type, TypeVar import uvicorn from fastapi import APIRouter, FastAPI @@ -30,6 +30,7 @@ from py_spring_core.core.entities.entity_provider.entity_provider import EntityProvider from py_spring_core.core.entities.middlewares.middleware import Middleware from py_spring_core.core.entities.middlewares.middleware_registry import ( + MiddlewareConfiguration, MiddlewareRegistry, ) from py_spring_core.core.entities.properties.properties import Properties @@ -45,6 +46,8 @@ import py_spring_core.core.utils as framework_utils + +SingleInheritanceRequiredT = TypeVar("SingleInheritanceRequiredT", bound=SingleInheritanceRequired) class PySpringApplication: """ The PySpringApplication class is the main entry point for the PySpring application. @@ -237,7 +240,7 @@ def __init_controllers(self) -> None: self.fastapi.include_router(router) logger.debug(f"[CONTROLLER INIT] Controller {name} initialized") - def _init_external_handler(self, base_class: Type[SingleInheritanceRequired]) -> Type[Any] | None: + def _init_external_handler(self, base_class: Type[SingleInheritanceRequiredT]) -> Type[SingleInheritanceRequiredT] | None: """Initialize an external handler (middleware registry or graceful shutdown handler). Args: @@ -276,12 +279,15 @@ def _init_external_handler(self, base_class: Type[SingleInheritanceRequired]) -> def __init_middlewares(self) -> None: handler_type = MiddlewareRegistry.__name__ logger.debug(f"[{handler_type} INIT] Initialize middlewares...") - registry_cls = self._init_external_handler(MiddlewareRegistry) - if registry_cls is None: + middeware_configuration_cls = self._init_external_handler(MiddlewareConfiguration) + if middeware_configuration_cls is None: return - - registry: MiddlewareRegistry = registry_cls() + registry = MiddlewareRegistry() + logger.info(f"[{handler_type} INIT] Setup middlewares for registry: {registry.__class__.__name__}") + middeware_configuration_cls().configure_middlewares(registry) + logger.info(f"[{handler_type} INIT] Middlewares setup for registry: {registry.__class__.__name__} completed") middleware_classes: list[Type[Middleware]] = registry.get_middleware_classes() + logger.info(f"[{handler_type} INIT] Middleware classes: {', '.join([middleware_class.__name__ for middleware_class in middleware_classes])}") for middleware_class in middleware_classes: logger.debug( f"[{handler_type} INIT] Inject dependencies for middleware: {middleware_class.__name__}" @@ -305,7 +311,7 @@ def __init_graceful_shutdown(self) -> None: self.shutdown_handler = handler_cls( timeout_seconds=shutdown_config.timeout_seconds, timeout_enabled=shutdown_config.enabled - ) + ) # type: ignore logger.debug(f"[{handler_type} INIT] Graceful shutdown initialized") def __configure_uvicorn_logging(self): diff --git a/py_spring_core/core/entities/middlewares/middleware_registry.py b/py_spring_core/core/entities/middlewares/middleware_registry.py index 348ba00..6fd7d93 100644 --- a/py_spring_core/core/entities/middlewares/middleware_registry.py +++ b/py_spring_core/core/entities/middlewares/middleware_registry.py @@ -9,7 +9,7 @@ ) -class MiddlewareRegistry(SingleInheritanceRequired["MiddlewareRegistry"], ABC): +class MiddlewareRegistry: """ Middleware registry for managing all middlewares @@ -31,20 +31,153 @@ class MiddlewareRegistry(SingleInheritanceRequired["MiddlewareRegistry"], ABC): Response: route → MiddlewareA → MiddlewareB This stacking behavior ensures that middlewares are executed in a predictable and controllable order. """ - - @abstractmethod + + def __init__(self): + """ + Initialize the middleware registry. + """ + self._middlewares: list[Type[Middleware]] = [] + + def add_middleware(self, middleware_class: Type[Middleware]) -> None: + """ + Add middleware to the end of the list. + + Args: + middleware_class: The middleware class to add + + Raises: + ValueError: If middleware is already registered + """ + if middleware_class in self._middlewares: + raise ValueError(f"Middleware {middleware_class.__name__} is already registered") + self._middlewares.append(middleware_class) + + def add_at_index(self, index: int, middleware_class: Type[Middleware]) -> None: + """ + Insert middleware at a specific index position. + + Args: + index: The position to insert at (0-based) + middleware_class: The middleware class to add + + Raises: + ValueError: If middleware is already registered or index is invalid + """ + if middleware_class in self._middlewares: + raise ValueError(f"Middleware {middleware_class.__name__} is already registered") + if index < 0 or index > len(self._middlewares): + raise ValueError(f"Index {index} is out of range (0-{len(self._middlewares)})") + self._middlewares.insert(index, middleware_class) + + def add_before(self, target_middleware: Type[Middleware], middleware_class: Type[Middleware]) -> None: + """ + Insert middleware before the target middleware. + + Args: + target_middleware: The middleware to insert before + middleware_class: The middleware class to add + + Raises: + ValueError: If middleware is already registered or target not found + """ + if middleware_class in self._middlewares: + raise ValueError(f"Middleware {middleware_class.__name__} is already registered") + if target_middleware not in self._middlewares: + raise ValueError(f"Target middleware {target_middleware.__name__} not found") + index = self._middlewares.index(target_middleware) + self._middlewares.insert(index, middleware_class) + + def add_after(self, target_middleware: Type[Middleware], middleware_class: Type[Middleware]) -> None: + """ + Insert middleware after the target middleware. + + Args: + target_middleware: The middleware to insert after + middleware_class: The middleware class to add + + Raises: + ValueError: If middleware is already registered or target not found + """ + if middleware_class in self._middlewares: + raise ValueError(f"Middleware {middleware_class.__name__} is already registered") + if target_middleware not in self._middlewares: + raise ValueError(f"Target middleware {target_middleware.__name__} not found") + index = self._middlewares.index(target_middleware) + self._middlewares.insert(index + 1, middleware_class) + + def remove_middleware(self, middleware_class: Type[Middleware]) -> None: + """ + Remove a middleware from the registry. + + Args: + middleware_class: The middleware class to remove + + Raises: + ValueError: If middleware is not found + """ + if middleware_class not in self._middlewares: + raise ValueError(f"Middleware {middleware_class.__name__} not found") + self._middlewares.remove(middleware_class) + + def clear_middlewares(self) -> None: + """Remove all middlewares from the registry.""" + self._middlewares.clear() + + def has_middleware(self, middleware_class: Type[Middleware]) -> bool: + """ + Check if a middleware is registered. + + Args: + middleware_class: The middleware class to check + + Returns: + bool: True if middleware is registered, False otherwise + """ + return middleware_class in self._middlewares + + def get_middleware_count(self) -> int: + """ + Get the number of registered middlewares. + + Returns: + int: Number of registered middlewares + """ + return len(self._middlewares) + + def get_middleware_index(self, middleware_class: Type[Middleware]) -> int: + """ + Get the index of a middleware in the registry. + + Args: + middleware_class: The middleware class to find + + Returns: + int: The index of the middleware + + Raises: + ValueError: If middleware is not found + """ + if middleware_class not in self._middlewares: + raise ValueError(f"Middleware {middleware_class.__name__} not found") + return self._middlewares.index(middleware_class) + + + def get_middleware_classes(self) -> list[Type[Middleware]]: """ - Get all registered middleware classes - + Get all registered middleware classes. + Returns: - List[Type[Middleware]]: List of middleware classes + List[Type[Middleware]]: List of middleware classes in registration order """ - pass - + return self._middlewares.copy() + def apply_middlewares(self, app: FastAPI) -> FastAPI: """ - Apply middlewares to FastAPI application + Apply middlewares to FastAPI application. + + Iterates through all registered middlewares and applies them to the FastAPI + application instance in the order they were registered. Args: app: FastAPI application instance @@ -55,3 +188,16 @@ def apply_middlewares(self, app: FastAPI) -> FastAPI: for middleware_class in self.get_middleware_classes(): app.add_middleware(middleware_class) return app + + + +class MiddlewareConfiguration(SingleInheritanceRequired["MiddlewareConfiguration"]): + """ + Middleware configuration for managing middleware registration and execution order. + """ + + def configure_middlewares(self, registry: MiddlewareRegistry) -> None: + """ + Setup middlewares for the registry. + """ + pass \ No newline at end of file diff --git a/py_spring_core/exception_handler/decorator.py b/py_spring_core/exception_handler/decorator.py new file mode 100644 index 0000000..3d60b4e --- /dev/null +++ b/py_spring_core/exception_handler/decorator.py @@ -0,0 +1,18 @@ + + +from functools import wraps +from typing import Any, Callable, Type + + +from py_spring_core.exception_handler.exception_handler_registry import ExceptionHandlerRegistry + + +def ExceptionHandler(exception_cls: Type[Exception]) -> Callable[[Callable[[Exception], Any]], Callable]: + def decorator(func: Callable[[Exception], Any]) -> Callable: + @wraps(func) + def wrapper(*args, **kwargs) -> Any: + return func(*args, **kwargs) + ExceptionHandlerRegistry.register(exception_cls, wrapper) + return wrapper + + return decorator \ No newline at end of file diff --git a/py_spring_core/exception_handler/exception_handler_registry.py b/py_spring_core/exception_handler/exception_handler_registry.py new file mode 100644 index 0000000..2861975 --- /dev/null +++ b/py_spring_core/exception_handler/exception_handler_registry.py @@ -0,0 +1,25 @@ + + +from typing import Any, Callable, Type, TypeVar + +from loguru import logger + +E = TypeVar('E', bound=Exception) + +class ExceptionHandlerRegistry: + _handlers: dict[str, Callable[[Any], Any]] = {} + + @classmethod + def register(cls, exception_cls: Type[E], handler: Callable[[E], Any]): + key = exception_cls.__name__ + logger.debug(f"Registering exception handler for {key}: {handler.__name__}") + if key in cls._handlers: + error_message = f"Exception handler for {exception_cls} already registered" + logger.error(error_message) + raise RuntimeError(error_message) + + cls._handlers[exception_cls.__name__] = handler + + @classmethod + def get_handler(cls, exception_cls: Type[E]) -> Callable[[E], Any]: + return cls._handlers[exception_cls.__name__] \ No newline at end of file diff --git a/tests/test_middleware.py b/tests/test_middleware.py index e692bcc..979a0d1 100644 --- a/tests/test_middleware.py +++ b/tests/test_middleware.py @@ -8,6 +8,7 @@ from py_spring_core.core.entities.middlewares.middleware import Middleware from py_spring_core.core.entities.middlewares.middleware_registry import ( MiddlewareRegistry, + MiddlewareConfiguration, ) @@ -148,7 +149,11 @@ def test_should_skip_default_returns_false(self, mock_request): 1. The default implementation of should_skip returns False 2. This allows the middleware to process all requests by default """ - middleware = Middleware(app=Mock()) + class TestMiddleware(Middleware): + async def process_request(self, request: Request) -> Response | None: + return None + + middleware = TestMiddleware(app=Mock()) result = middleware.should_skip(mock_request) assert result is False @@ -164,6 +169,9 @@ class SkippingMiddleware(Middleware): def should_skip(self, request: Request) -> bool: return request.method == "GET" + async def process_request(self, request: Request) -> Response | None: + return None + middleware = SkippingMiddleware(app=Mock()) result = middleware.should_skip(mock_request) assert result is True @@ -251,6 +259,9 @@ def should_skip(self, request: Request) -> bool: received_request = request return False + async def process_request(self, request: Request) -> Response | None: + return None + middleware = TestMiddleware(app=Mock()) middleware.should_skip(mock_request) @@ -350,7 +361,11 @@ def test_should_skip_method_signature(self): 2. It takes a Request parameter 3. It returns a boolean value """ - middleware = Middleware(app=Mock()) + class TestMiddleware(Middleware): + async def process_request(self, request: Request) -> Response | None: + return None + + middleware = TestMiddleware(app=Mock()) # Check that should_skip is a method assert hasattr(middleware, 'should_skip') @@ -369,44 +384,237 @@ def test_should_skip_method_signature(self): class TestMiddlewareRegistry: - """Test suite for the MiddlewareRegistry abstract class.""" + """Test suite for the MiddlewareRegistry concrete class.""" @pytest.fixture def fastapi_app(self): """Fixture that provides a fresh FastAPI application instance.""" return FastAPI() - def test_middleware_registry_is_abstract(self): + @pytest.fixture + def registry(self): + """Fixture that provides a fresh MiddlewareRegistry instance.""" + return MiddlewareRegistry() + + @pytest.fixture + def test_middleware_1(self): + """Fixture that provides a test middleware class.""" + class TestMiddleware1(Middleware): + async def process_request(self, request: Request) -> Response | None: + return None + return TestMiddleware1 + + @pytest.fixture + def test_middleware_2(self): + """Fixture that provides another test middleware class.""" + class TestMiddleware2(Middleware): + async def process_request(self, request: Request) -> Response | None: + return None + return TestMiddleware2 + + def test_middleware_registry_instantiation(self): """ - Test that MiddlewareRegistry class is abstract and cannot be instantiated directly. + Test that MiddlewareRegistry can be instantiated directly. This test verifies that: - 1. MiddlewareRegistry is an abstract base class - 2. Attempting to instantiate it directly raises an error + 1. MiddlewareRegistry is a concrete class + 2. It can be instantiated without errors + 3. Initial state is correct """ - # This test verifies that MiddlewareRegistry is abstract - # We can't test direct instantiation because it's abstract - # Instead, we test that it has the abstract method - assert hasattr(MiddlewareRegistry, "get_middleware_classes") - assert MiddlewareRegistry.get_middleware_classes.__isabstractmethod__ + registry = MiddlewareRegistry() + assert isinstance(registry, MiddlewareRegistry) + assert registry.get_middleware_count() == 0 + assert registry.get_middleware_classes() == [] - def test_get_middleware_classes_is_abstract(self): + def test_add_middleware(self, registry, test_middleware_1): """ - Test that get_middleware_classes method is abstract and must be implemented. + Test adding middleware to the registry. This test verifies that: - 1. get_middleware_classes is an abstract method - 2. Subclasses must implement this method + 1. Middleware can be added successfully + 2. Middleware count increases + 3. Middleware appears in the classes list """ + registry.add_middleware(test_middleware_1) + + assert registry.get_middleware_count() == 1 + assert registry.has_middleware(test_middleware_1) + assert test_middleware_1 in registry.get_middleware_classes() - # Create a concrete subclass without implementing get_middleware_classes - class ConcreteRegistry(MiddlewareRegistry): # type: ignore[abstract] - pass + def test_add_duplicate_middleware_raises_error(self, registry, test_middleware_1): + """ + Test that adding duplicate middleware raises an error. + + This test verifies that: + 1. Adding the same middleware twice raises ValueError + 2. The error message is descriptive + 3. The registry state remains unchanged + """ + registry.add_middleware(test_middleware_1) + + with pytest.raises(ValueError, match="Middleware TestMiddleware1 is already registered"): + registry.add_middleware(test_middleware_1) + + # Verify state hasn't changed + assert registry.get_middleware_count() == 1 + + def test_add_at_index(self, registry, test_middleware_1, test_middleware_2): + """ + Test inserting middleware at specific index. - with pytest.raises(TypeError): - ConcreteRegistry() # type: ignore[abstract] + This test verifies that: + 1. Middleware can be inserted at specific positions + 2. Order is maintained correctly + 3. Index bounds are respected + """ + registry.add_middleware(test_middleware_1) + registry.add_at_index(0, test_middleware_2) + + classes = registry.get_middleware_classes() + assert classes[0] == test_middleware_2 + assert classes[1] == test_middleware_1 + + def test_add_at_invalid_index_raises_error(self, registry, test_middleware_1): + """ + Test that adding at invalid index raises an error. + + This test verifies that: + 1. Invalid indices raise ValueError + 2. Error message includes valid range + """ + with pytest.raises(ValueError, match="Index -1 is out of range"): + registry.add_at_index(-1, test_middleware_1) + + with pytest.raises(ValueError, match="Index 1 is out of range"): + registry.add_at_index(1, test_middleware_1) + + def test_add_before(self, registry, test_middleware_1, test_middleware_2): + """ + Test inserting middleware before another middleware. + + This test verifies that: + 1. Middleware can be inserted before target middleware + 2. Order is correct after insertion + """ + registry.add_middleware(test_middleware_1) + registry.add_before(test_middleware_1, test_middleware_2) + + classes = registry.get_middleware_classes() + assert classes[0] == test_middleware_2 + assert classes[1] == test_middleware_1 + + def test_add_before_nonexistent_target_raises_error(self, registry, test_middleware_1, test_middleware_2): + """ + Test that adding before nonexistent target raises error. + + This test verifies that: + 1. Adding before non-registered middleware raises ValueError + 2. Error message is descriptive + """ + with pytest.raises(ValueError, match="Target middleware TestMiddleware1 not found"): + registry.add_before(test_middleware_1, test_middleware_2) + + def test_add_after(self, registry, test_middleware_1, test_middleware_2): + """ + Test inserting middleware after another middleware. + + This test verifies that: + 1. Middleware can be inserted after target middleware + 2. Order is correct after insertion + """ + registry.add_middleware(test_middleware_1) + registry.add_after(test_middleware_1, test_middleware_2) + + classes = registry.get_middleware_classes() + assert classes[0] == test_middleware_1 + assert classes[1] == test_middleware_2 + + def test_remove_middleware(self, registry, test_middleware_1): + """ + Test removing middleware from the registry. + + This test verifies that: + 1. Middleware can be removed successfully + 2. Middleware count decreases + 3. Middleware no longer appears in classes list + """ + registry.add_middleware(test_middleware_1) + registry.remove_middleware(test_middleware_1) + + assert registry.get_middleware_count() == 0 + assert not registry.has_middleware(test_middleware_1) + assert test_middleware_1 not in registry.get_middleware_classes() + + def test_remove_nonexistent_middleware_raises_error(self, registry, test_middleware_1): + """ + Test that removing nonexistent middleware raises error. + + This test verifies that: + 1. Removing non-registered middleware raises ValueError + 2. Error message is descriptive + """ + with pytest.raises(ValueError, match="Middleware TestMiddleware1 not found"): + registry.remove_middleware(test_middleware_1) + + def test_clear_middlewares(self, registry, test_middleware_1, test_middleware_2): + """ + Test clearing all middlewares from the registry. + + This test verifies that: + 1. All middlewares are removed + 2. Registry returns to initial state + """ + registry.add_middleware(test_middleware_1) + registry.add_middleware(test_middleware_2) + + registry.clear_middlewares() + + assert registry.get_middleware_count() == 0 + assert registry.get_middleware_classes() == [] - def test_apply_middlewares_adds_middleware_to_app(self, fastapi_app): + def test_get_middleware_index(self, registry, test_middleware_1, test_middleware_2): + """ + Test getting the index of a middleware. + + This test verifies that: + 1. Index of registered middleware is returned correctly + 2. Index reflects the actual position in the list + """ + registry.add_middleware(test_middleware_1) + registry.add_middleware(test_middleware_2) + + assert registry.get_middleware_index(test_middleware_1) == 0 + assert registry.get_middleware_index(test_middleware_2) == 1 + + def test_get_middleware_index_nonexistent_raises_error(self, registry, test_middleware_1): + """ + Test that getting index of nonexistent middleware raises error. + + This test verifies that: + 1. Getting index of non-registered middleware raises ValueError + 2. Error message is descriptive + """ + with pytest.raises(ValueError, match="Middleware TestMiddleware1 not found"): + registry.get_middleware_index(test_middleware_1) + + def test_get_middleware_classes_returns_copy(self, registry, test_middleware_1): + """ + Test that get_middleware_classes returns a copy. + + This test verifies that: + 1. Modifying returned list doesn't affect internal state + 2. A copy is returned, not the original list + """ + registry.add_middleware(test_middleware_1) + + classes = registry.get_middleware_classes() + classes.clear() + + # Original registry should be unchanged + assert registry.get_middleware_count() == 1 + assert registry.has_middleware(test_middleware_1) + + def test_apply_middlewares_adds_middleware_to_app(self, registry, fastapi_app, test_middleware_1, test_middleware_2): """ Test that apply_middlewares correctly adds middleware classes to FastAPI app. @@ -415,33 +623,22 @@ def test_apply_middlewares_adds_middleware_to_app(self, fastapi_app): 2. The add_middleware method is called for each middleware class 3. The app is returned unchanged """ - - class TestMiddleware1(Middleware): - async def process_request(self, request: Request) -> Response | None: - return None - - class TestMiddleware2(Middleware): - async def process_request(self, request: Request) -> Response | None: - return None - - class TestRegistry(MiddlewareRegistry): - def get_middleware_classes(self) -> list[type[Middleware]]: - return [TestMiddleware1, TestMiddleware2] + registry.add_middleware(test_middleware_1) + registry.add_middleware(test_middleware_2) # Mock the add_middleware method with patch.object(fastapi_app, "add_middleware") as mock_add_middleware: - registry = TestRegistry() result = registry.apply_middlewares(fastapi_app) # Verify add_middleware was called for each middleware class assert mock_add_middleware.call_count == 2 - mock_add_middleware.assert_any_call(TestMiddleware1) - mock_add_middleware.assert_any_call(TestMiddleware2) + mock_add_middleware.assert_any_call(test_middleware_1) + mock_add_middleware.assert_any_call(test_middleware_2) # Verify the app is returned assert result == fastapi_app - def test_apply_middlewares_with_empty_list(self, fastapi_app): + def test_apply_middlewares_with_empty_list(self, registry, fastapi_app): """ Test that apply_middlewares handles empty middleware list correctly. @@ -450,13 +647,7 @@ def test_apply_middlewares_with_empty_list(self, fastapi_app): 2. The app is returned unchanged 3. No errors occur with empty middleware list """ - - class EmptyRegistry(MiddlewareRegistry): - def get_middleware_classes(self) -> list[type[Middleware]]: - return [] - with patch.object(fastapi_app, "add_middleware") as mock_add_middleware: - registry = EmptyRegistry() result = registry.apply_middlewares(fastapi_app) # Verify add_middleware was not called @@ -465,7 +656,7 @@ def get_middleware_classes(self) -> list[type[Middleware]]: # Verify the app is returned assert result == fastapi_app - def test_apply_middlewares_preserves_app_state(self, fastapi_app): + def test_apply_middlewares_preserves_app_state(self, registry, fastapi_app, test_middleware_1): """ Test that apply_middlewares preserves the FastAPI app state. @@ -473,19 +664,11 @@ def test_apply_middlewares_preserves_app_state(self, fastapi_app): 1. The original app object is returned (same reference) 2. No app properties are modified during middleware application """ - - class TestMiddleware(Middleware): - async def process_request(self, request: Request) -> Response | None: - return None - - class TestRegistry(MiddlewareRegistry): - def get_middleware_classes(self) -> list[type[Middleware]]: - return [TestMiddleware] + registry.add_middleware(test_middleware_1) # Store original app state original_app_id = id(fastapi_app) - registry = TestRegistry() result = registry.apply_middlewares(fastapi_app) # Verify same app object is returned @@ -493,6 +676,48 @@ def get_middleware_classes(self) -> list[type[Middleware]]: assert result is fastapi_app +class TestMiddlewareConfiguration: + """Test suite for the MiddlewareConfiguration class.""" + + def test_middleware_configuration_inheritance(self): + """ + Test that MiddlewareConfiguration has proper inheritance. + + This test verifies that: + 1. MiddlewareConfiguration inherits from SingleInheritanceRequired + 2. It can be instantiated + """ + class TestConfig(MiddlewareConfiguration): + pass + + config = TestConfig() + assert isinstance(config, MiddlewareConfiguration) + + def test_setup_middlewares_can_be_overridden(self): + """ + Test that setup_middlewares can be overridden to configure middlewares. + + This test verifies that: + 1. setup_middlewares can be overridden + 2. The registry is properly configured when overridden + """ + class TestMiddleware(Middleware): + async def process_request(self, request: Request) -> Response | None: + return None + + class TestConfig(MiddlewareConfiguration): + def setup_middlewares(self, registry: MiddlewareRegistry) -> None: + registry.add_middleware(TestMiddleware) + + config = TestConfig() + registry = MiddlewareRegistry() + + config.setup_middlewares(registry) + + assert registry.has_middleware(TestMiddleware) + assert registry.get_middleware_count() == 1 + + class TestMiddlewareIntegration: """Integration tests for middleware functionality.""" @@ -523,11 +748,10 @@ async def process_request(self, request: Request) -> Response | None: execution_order.append("second") return None - class TestRegistry(MiddlewareRegistry): - def get_middleware_classes(self) -> list[type[Middleware]]: - return [FirstMiddleware, SecondMiddleware] - - registry = TestRegistry() + registry = MiddlewareRegistry() + registry.add_middleware(FirstMiddleware) + registry.add_middleware(SecondMiddleware) + app = registry.apply_middlewares(fastapi_app) # Create a test client to trigger middleware execution @@ -566,11 +790,10 @@ async def process_request(self, request: Request) -> Response | None: execution_order.append("second") return None - class TestRegistry(MiddlewareRegistry): - def get_middleware_classes(self) -> list[type[Middleware]]: - return [BlockingMiddleware, SecondMiddleware] - - registry = TestRegistry() + registry = MiddlewareRegistry() + registry.add_middleware(BlockingMiddleware) + registry.add_middleware(SecondMiddleware) + app = registry.apply_middlewares(fastapi_app) @app.get("/test") @@ -587,48 +810,77 @@ async def test_endpoint(): assert response.status_code == 403 assert response.text == "blocked" - def test_middleware_registry_single_inheritance(self): + def test_middleware_registry_with_configuration(self): """ - Test that MiddlewareRegistry enforces single inheritance. + Test using MiddlewareRegistry with MiddlewareConfiguration. This test verifies that: - 1. MiddlewareRegistry implements SingleInheritanceRequired - 2. Multiple inheritance is prevented + 1. MiddlewareConfiguration can configure a MiddlewareRegistry + 2. The configuration is applied correctly """ - # This test assumes SingleInheritanceRequired prevents multiple inheritance - # The actual behavior depends on the implementation of SingleInheritanceRequired + class TestMiddleware(Middleware): + async def process_request(self, request: Request) -> Response | None: + return None - class TestRegistry(MiddlewareRegistry): - def get_middleware_classes(self) -> list[type[Middleware]]: - return [] + class TestConfig(MiddlewareConfiguration): + def setup_middlewares(self, registry: MiddlewareRegistry) -> None: + registry.add_middleware(TestMiddleware) - # Should be able to create a single inheritance registry - registry = TestRegistry() - assert isinstance(registry, MiddlewareRegistry) + config = TestConfig() + registry = MiddlewareRegistry() + + config.setup_middlewares(registry) + + assert registry.has_middleware(TestMiddleware) + assert registry.get_middleware_count() == 1 - def test_middleware_type_hints(self): + def test_middleware_execution_order_with_skip_logic(self, fastapi_app): """ - Test that middleware classes have correct type hints. + Test middleware execution order with skip logic. This test verifies that: - 1. get_middleware_classes returns the correct type - 2. process_request has correct parameter and return type hints + 1. Middlewares with skip logic are handled correctly + 2. Order is maintained even when some middlewares skip """ + execution_order = [] - class TestMiddleware(Middleware): + class ConditionalMiddleware(Middleware): + def should_skip(self, request: Request) -> bool: + return "/skip" in str(request.url) + async def process_request(self, request: Request) -> Response | None: + execution_order.append("conditional") return None - class TestRegistry(MiddlewareRegistry): - def get_middleware_classes(self) -> list[type[Middleware]]: - return [TestMiddleware] + class AlwaysRunMiddleware(Middleware): + async def process_request(self, request: Request) -> Response | None: + execution_order.append("always") + return None - registry = TestRegistry() - middleware_classes = registry.get_middleware_classes() + registry = MiddlewareRegistry() + registry.add_middleware(ConditionalMiddleware) + registry.add_middleware(AlwaysRunMiddleware) + + app = registry.apply_middlewares(fastapi_app) - # Verify type hints - assert isinstance(middleware_classes, list) - assert all( - issubclass(middleware_class, Middleware) - for middleware_class in middleware_classes - ) + @app.get("/test") + async def test_endpoint(): + return {"message": "test"} + + @app.get("/skip") + async def skip_endpoint(): + return {"message": "skip"} + + client = TestClient(app) + + # Test normal endpoint + execution_order.clear() + response = client.get("/test") + assert execution_order == ["always", "conditional"] + assert response.status_code == 200 + + # Test skip endpoint + execution_order.clear() + response = client.get("/skip") + assert execution_order == ["always"] # Only AlwaysRunMiddleware should execute + assert response.status_code == 200 From c5045b8ccd25cef80e1efa5d4561b906de577310 Mon Sep 17 00:00:00 2001 From: William Chen Date: Tue, 22 Jul 2025 20:54:09 +0800 Subject: [PATCH 40/42] Update version to 0.0.25 in __init__.py --- py_spring_core/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py_spring_core/__init__.py b/py_spring_core/__init__.py index 412ce0a..1025556 100644 --- a/py_spring_core/__init__.py +++ b/py_spring_core/__init__.py @@ -23,7 +23,7 @@ from py_spring_core.event.application_event_publisher import ApplicationEventPublisher from py_spring_core.event.commons import ApplicationEvent -__version__ = "0.0.24" +__version__ = "0.0.25" __all__ = [ "PySpringApplication", From 9a6f7747fe401dae0173c1f818ccacf41acc7853 Mon Sep 17 00:00:00 2001 From: William Chen Date: Fri, 25 Jul 2025 01:21:12 +0800 Subject: [PATCH 41/42] Fix typos in TypeCheckingError class and method return type in type_checking_service.py; enhance HTTPMethod and RouteRegistration classes with detailed docstrings in route_mapping.py; correct external_dependencies spelling in entity_provider.py. --- .../commons/type_checking_service.py | 6 +- .../entities/controllers/route_mapping.py | 157 ++++++++++++++++++ .../entity_provider/entity_provider.py | 2 +- 3 files changed, 161 insertions(+), 4 deletions(-) diff --git a/py_spring_core/commons/type_checking_service.py b/py_spring_core/commons/type_checking_service.py index f9955cb..d84ca89 100644 --- a/py_spring_core/commons/type_checking_service.py +++ b/py_spring_core/commons/type_checking_service.py @@ -9,7 +9,7 @@ class MypyTypeCheckingError(str, Enum): NoUntypedDefs = "no-untyped-def" -class TypeCheckingErrorr(Exception): ... +class TypeCheckingError(Exception): ... class TypeCheckingService: @@ -20,7 +20,7 @@ def __init__(self, target_folder: str) -> None: MypyTypeCheckingError.NoUntypedDefs ] - def type_checking(self) -> Optional[TypeCheckingErrorr]: + def type_checking(self) -> Optional[TypeCheckingError]: logger.info("[MYPY TYPE CHECKING] Mypy checking types for projects...") # Run mypy and capture stdout and stderr result = subprocess.run( @@ -41,4 +41,4 @@ def type_checking(self) -> Optional[TypeCheckingErrorr]: message += error_message if len(message) == 0: return None - return TypeCheckingErrorr(message) + return TypeCheckingError(message) diff --git a/py_spring_core/core/entities/controllers/route_mapping.py b/py_spring_core/core/entities/controllers/route_mapping.py index c49eca1..0e3f050 100644 --- a/py_spring_core/core/entities/controllers/route_mapping.py +++ b/py_spring_core/core/entities/controllers/route_mapping.py @@ -6,6 +6,11 @@ class HTTPMethod(str, Enum): + """HTTP methods enumeration for route definitions. + + This enum defines the supported HTTP methods that can be used + with route decorators in the PySpring framework. + """ GET = "GET" POST = "POST" PUT = "PUT" @@ -14,6 +19,36 @@ class HTTPMethod(str, Enum): class RouteRegistration(BaseModel): + """Model representing a route registration with all FastAPI-compatible parameters. + + This class encapsulates all the information needed to register a route, + including the HTTP method, path, handler function, and various FastAPI + configuration options. + + Attributes: + class_name (str): Name of the controller class containing the route + method (HTTPMethod): HTTP method for the route + path (str): URL path pattern for the route + func (Callable): The handler function for the route + response_model (Any, optional): Pydantic model for response serialization + status_code (int, optional): Default HTTP status code for successful responses + tags (List[Union[str, Enum]], optional): OpenAPI tags for documentation + dependencies (List[Any], optional): FastAPI dependencies + summary (str, optional): Short summary for OpenAPI documentation + description (str, optional): Detailed description for OpenAPI documentation + response_description (str): Description of successful response + responses (Dict[Union[int, str], Dict[str, Any]], optional): Additional response definitions + deprecated (bool, optional): Whether the endpoint is deprecated + operation_id (str, optional): Unique operation ID for OpenAPI + response_model_include (Set[str], optional): Fields to include in response model + response_model_exclude (Set[str], optional): Fields to exclude from response model + response_model_by_alias (bool): Whether to use field aliases in response + response_model_exclude_unset (bool): Whether to exclude unset fields + response_model_exclude_defaults (bool): Whether to exclude default values + response_model_exclude_none (bool): Whether to exclude None values + include_in_schema (bool): Whether to include in OpenAPI schema + name (str, optional): Custom name for the route + """ class_name: str method: HTTPMethod path: str @@ -38,19 +73,55 @@ class RouteRegistration(BaseModel): name: Optional[str] = None def __eq__(self, other: Any) -> bool: + """Check equality based on HTTP method and path. + + Two route registrations are considered equal if they have + the same HTTP method and path, regardless of other attributes. + + Args: + other (Any): Object to compare with + + Returns: + bool: True if equal, False otherwise + """ if not isinstance(other, RouteRegistration): return False return self.method == other.method and self.path == other.path def __hash__(self) -> int: + """Generate hash based on HTTP method and path. + + This allows RouteRegistration objects to be used in sets + and as dictionary keys. + + Returns: + int: Hash value based on method and path + """ return hash((self.method, self.path)) class RouteMapping: + """Registry for storing and managing route registrations by controller class. + + This class maintains a mapping of controller class names to their + associated route registrations, enabling efficient route lookup + and management within the PySpring framework. + + Attributes: + routes (dict[str, set[RouteRegistration]]): Mapping of class names to route sets + """ routes: dict[str, set[RouteRegistration]] = {} @classmethod def register_route(cls, route_registration: RouteRegistration) -> None: + """Register a route for a specific controller class. + + Adds the given route registration to the routes mapping, + creating a new set for the class if it doesn't exist. + + Args: + route_registration (RouteRegistration): The route to register + """ optional_routes = cls.routes.get(route_registration.class_name, None) if optional_routes is None: cls.routes[route_registration.class_name] = set() @@ -58,6 +129,17 @@ def register_route(cls, route_registration: RouteRegistration) -> None: def _create_route_decorator(method: HTTPMethod): + """Create a route decorator factory for a specific HTTP method. + + This function generates decorator factories that can be used to create + route decorators with FastAPI-compatible parameters. + + Args: + method (HTTPMethod): The HTTP method for routes created by this decorator + + Returns: + Callable: A decorator factory function that accepts route parameters + """ def decorator_factory( path: str, *, @@ -80,7 +162,41 @@ def decorator_factory( include_in_schema: bool = True, name: Optional[str] = None, ): + """Create a route decorator with the specified parameters. + + Args: + path (str): URL path pattern for the route + response_model (Any, optional): Pydantic model for response serialization + status_code (int, optional): Default HTTP status code + tags (List[Union[str, Enum]], optional): OpenAPI tags + dependencies (List[Any], optional): FastAPI dependencies + summary (str, optional): Route summary for documentation + description (str, optional): Route description for documentation + response_description (str): Description of successful response + responses (Dict[Union[int, str], Dict[str, Any]], optional): Additional responses + deprecated (bool, optional): Whether the endpoint is deprecated + operation_id (str, optional): Unique operation ID + response_model_include (Set[str], optional): Fields to include in response + response_model_exclude (Set[str], optional): Fields to exclude from response + response_model_by_alias (bool): Whether to use field aliases + response_model_exclude_unset (bool): Whether to exclude unset fields + response_model_exclude_defaults (bool): Whether to exclude defaults + response_model_exclude_none (bool): Whether to exclude None values + include_in_schema (bool): Whether to include in OpenAPI schema + name (str, optional): Custom name for the route + + Returns: + Callable: A decorator function that registers the route + """ def decorator(func: Callable): + """Decorate a function to register it as a route. + + Args: + func (Callable): The handler function to decorate + + Returns: + Callable: The wrapped function + """ class_name = func.__qualname__.split(".")[0] route_registration = RouteRegistration( class_name=class_name, @@ -119,8 +235,49 @@ def wrapper(*args: Any, **kwargs: Any): return decorator_factory +# Route decorator factories for different HTTP methods GetMapping = _create_route_decorator(HTTPMethod.GET) +"""Decorator factory for creating GET route handlers. + +Example: + @GetMapping("/users") + def get_users(self): + return {"users": []} +""" + PostMapping = _create_route_decorator(HTTPMethod.POST) +"""Decorator factory for creating POST route handlers. + +Example: + @PostMapping("/users", response_model=UserResponse) + def create_user(self, user: UserCreate): + return create_new_user(user) +""" + PutMapping = _create_route_decorator(HTTPMethod.PUT) +"""Decorator factory for creating PUT route handlers. + +Example: + @PutMapping("/users/{user_id}") + def update_user(self, user_id: int, user: UserUpdate): + return update_existing_user(user_id, user) +""" + DeleteMapping = _create_route_decorator(HTTPMethod.DELETE) +"""Decorator factory for creating DELETE route handlers. + +Example: + @DeleteMapping("/users/{user_id}") + def delete_user(self, user_id: int): + delete_existing_user(user_id) + return {"message": "User deleted"} +""" + PatchMapping = _create_route_decorator(HTTPMethod.PATCH) +"""Decorator factory for creating PATCH route handlers. + +Example: + @PatchMapping("/users/{user_id}") + def partial_update_user(self, user_id: int, user: UserPatch): + return patch_existing_user(user_id, user) +""" diff --git a/py_spring_core/core/entities/entity_provider/entity_provider.py b/py_spring_core/core/entities/entity_provider/entity_provider.py index bdaa29c..25d35d5 100644 --- a/py_spring_core/core/entities/entity_provider/entity_provider.py +++ b/py_spring_core/core/entities/entity_provider/entity_provider.py @@ -22,7 +22,7 @@ class EntityProvider: properties_classes: list[Type[Properties]] = field(default_factory=list) rest_controller_classes: list[Type[RestController]] = field(default_factory=list) depends_on: list[Type[AppEntities]] = field(default_factory=list) - extneral_dependencies: list[Any] = field(default_factory=list) + external_dependencies: list[Any] = field(default_factory=list) app_context: Optional["ApplicationContext"] = None def get_entities(self) -> list[Type[AppEntities]]: From 0e74ddd1bb642ec1a06eb4558a44675c087b9505 Mon Sep 17 00:00:00 2001 From: William Chen Date: Fri, 25 Jul 2025 01:28:45 +0800 Subject: [PATCH 42/42] Refactor method names in PySpringApplication and ApplicationContext for consistency; update component and property container attributes in ContainerManager to improve clarity; adjust related tests to reflect these changes. --- .../context/application_context.py | 88 +++++++++---------- .../core/application/py_spring_application.py | 53 ++++++----- .../core/entities/component/component.py | 2 +- .../interfaces/single_inheritance_required.py | 3 +- tests/test_application_context.py | 30 +++---- tests/test_component_features.py | 24 +++-- 6 files changed, 98 insertions(+), 102 deletions(-) diff --git a/py_spring_core/core/application/context/application_context.py b/py_spring_core/core/application/context/application_context.py index 4d43f0a..938aeaf 100644 --- a/py_spring_core/core/application/context/application_context.py +++ b/py_spring_core/core/application/context/application_context.py @@ -54,32 +54,32 @@ class ApplicationContextView(BaseModel): """View model for application context state.""" config: ApplicationContextConfig - component_cls_container: list[str] - singleton_component_instance_container: list[str] + component_classes: list[str] + component_instances: list[str] class ContainerManager: """Manages containers for different types of entities in the application context.""" def __init__(self): - self.component_cls_container: dict[str, Type[Component]] = {} - self.controller_cls_container: dict[str, Type[RestController]] = {} - self.singleton_component_instance_container: dict[str, Component] = {} + self.component_classes: dict[str, Type[Component]] = {} + self.controller_classes: dict[str, Type[RestController]] = {} + self.component_instances: dict[str, Component] = {} - self.bean_collection_cls_container: dict[str, Type[BeanCollection]] = {} - self.singleton_bean_instance_container: dict[str, object] = {} + self.bean_collection_classes: dict[str, Type[BeanCollection]] = {} + self.bean_instances: dict[str, object] = {} - self.properties_cls_container: dict[str, Type[Properties]] = {} - self.singleton_properties_instance_container: dict[str, Properties] = {} + self.properties_classes: dict[str, Type[Properties]] = {} + self.properties_instances: dict[str, Properties] = {} def is_entity_in_container(self, entity_cls: Type[AppEntities]) -> bool: """Check if an entity class is registered in any container.""" cls_name = entity_cls.__name__ return ( - cls_name in self.component_cls_container - or cls_name in self.controller_cls_container - or cls_name in self.bean_collection_cls_container - or cls_name in self.properties_cls_container + cls_name in self.component_classes + or cls_name in self.controller_classes + or cls_name in self.bean_collection_classes + or cls_name in self.properties_classes ) @@ -241,7 +241,7 @@ def register_component(self, component_cls: Type[Component]) -> None: ) component_cls_name = component_cls.get_name() - existing_component = self.container_manager.component_cls_container.get( + existing_component = self.container_manager.component_classes.get( component_cls_name ) @@ -253,7 +253,7 @@ def register_component(self, component_cls: Type[Component]) -> None: ): return - self.container_manager.component_cls_container[component_cls_name] = ( + self.container_manager.component_classes[component_cls_name] = ( component_cls ) @@ -266,7 +266,7 @@ def get_component( target_cls_name = self._determine_target_cls_name(component_cls, qualifier) - if target_cls_name not in self.container_manager.component_cls_container: + if target_cls_name not in self.container_manager.component_classes: return None scope = component_cls.get_scope() @@ -274,7 +274,7 @@ def get_component( case ComponentScope.Singleton: return cast( T, - self.container_manager.singleton_component_instance_container.get( + self.container_manager.component_instances.get( target_cls_name ), ) @@ -340,7 +340,7 @@ def _init_abstract_component_subclasses(self, component_cls: Type[ABC]) -> None: subclass_component_cls, subclass_component_cls.get_name() ) if instance is not None: - self.container_manager.singleton_component_instance_container[ + self.container_manager.component_instances[ subclass_component_cls.get_name() ] = instance @@ -349,7 +349,7 @@ def init_singleton_components(self) -> None: for ( component_cls_name, component_cls, - ) in self.container_manager.component_cls_container.items(): + ) in self.container_manager.component_classes.items(): if component_cls.get_scope() != ComponentScope.Singleton: continue @@ -364,7 +364,7 @@ def init_singleton_components(self) -> None: component_cls, component_cls_name ) if instance is not None: - self.container_manager.singleton_component_instance_container[ + self.container_manager.component_instances[ component_cls_name ] = instance @@ -389,18 +389,18 @@ def register_bean_collection(self, bean_cls: Type[BeanCollection]) -> None: ) bean_name = bean_cls.get_name() - self.container_manager.bean_collection_cls_container[bean_name] = bean_cls + self.container_manager.bean_collection_classes[bean_name] = bean_cls def get_bean( self, object_cls: Type[T], qualifier: Optional[str] = None ) -> Optional[T]: """Get a bean instance by class and optional qualifier.""" bean_name = object_cls.__name__ - if bean_name not in self.container_manager.singleton_bean_instance_container: + if bean_name not in self.container_manager.bean_instances: return None return cast( - T, self.container_manager.singleton_bean_instance_container.get(bean_name) + T, self.container_manager.bean_instances.get(bean_name) ) def _inject_bean_collection_dependencies( @@ -414,7 +414,7 @@ def _inject_bean_collection_dependencies( def _validate_bean_view(self, view: BeanView, collection_name: str) -> None: """Validate a bean view before adding it to the container.""" - if view.bean_name in self.container_manager.singleton_bean_instance_container: + if view.bean_name in self.container_manager.bean_instances: raise BeanConflictError( f"[BEAN CONFLICTS] Bean: {view.bean_name} already exists under collection: {collection_name}" ) @@ -430,7 +430,7 @@ def init_singleton_beans(self) -> None: for ( bean_collection_cls_name, bean_collection_cls, - ) in self.container_manager.bean_collection_cls_container.items(): + ) in self.container_manager.bean_collection_classes.items(): logger.debug( f"[INITIALIZING SINGLETON BEAN] Init singleton bean: {bean_collection_cls_name}" ) @@ -441,7 +441,7 @@ def init_singleton_beans(self) -> None: bean_views = collection.scan_beans() for view in bean_views: self._validate_bean_view(view, collection.get_name()) - self.container_manager.singleton_bean_instance_container[ + self.container_manager.bean_instances[ view.bean_name ] = view.bean @@ -464,19 +464,19 @@ def register_properties(self, properties_cls: Type[Properties]) -> None: ) properties_name = properties_cls.get_key() - self.container_manager.properties_cls_container[properties_name] = ( + self.container_manager.properties_classes[properties_name] = ( properties_cls ) def get_properties(self, properties_cls: Type[PT]) -> Optional[PT]: """Get a properties instance by class.""" properties_cls_name = properties_cls.get_key() - if properties_cls_name not in self.container_manager.properties_cls_container: + if properties_cls_name not in self.container_manager.properties_classes: return None return cast( PT, - self.container_manager.singleton_properties_instance_container.get( + self.container_manager.properties_instances.get( properties_cls_name ), ) @@ -485,7 +485,7 @@ def _create_properties_loader(self) -> _PropertiesLoader: """Create a properties loader instance.""" return _PropertiesLoader( self.config.properties_path, - list(self.container_manager.properties_cls_container.values()), + list(self.container_manager.properties_classes.values()), ) def load_properties(self) -> None: @@ -496,10 +496,10 @@ def load_properties(self) -> None: for ( properties_key, properties_cls, - ) in self.container_manager.properties_cls_container.items(): + ) in self.container_manager.properties_classes.items(): if ( properties_key - in self.container_manager.singleton_properties_instance_container + in self.container_manager.properties_instances ): continue @@ -515,13 +515,13 @@ def load_properties(self) -> None: f"with key: {properties_cls.get_key()}" ) - self.container_manager.singleton_properties_instance_container[ + self.container_manager.properties_instances[ properties_key ] = optional_properties # Update the global properties loader reference _PropertiesLoader.optional_loaded_properties = ( - self.container_manager.singleton_properties_instance_container + self.container_manager.properties_instances ) @@ -565,11 +565,11 @@ def as_view(self) -> ApplicationContextView: """Create a view model of the application context state.""" return ApplicationContextView( config=self.config, - component_cls_container=list( - self.container_manager.component_cls_container.keys() + component_classes=list( + self.container_manager.component_classes.keys() ), - singleton_component_instance_container=list( - self.container_manager.singleton_component_instance_container.keys() + component_instances=list( + self.container_manager.component_instances.keys() ), ) @@ -618,25 +618,25 @@ def register_controller(self, controller_cls: Type[RestController]) -> None: ) controller_cls_name = controller_cls.get_name() - self.container_manager.controller_cls_container[controller_cls_name] = ( + self.container_manager.controller_classes[controller_cls_name] = ( controller_cls ) def get_controller_instances(self) -> list[RestController]: """Get all controller instances.""" return [ - cls() for cls in self.container_manager.controller_cls_container.values() + cls() for cls in self.container_manager.controller_classes.values() ] def get_singleton_component_instances(self) -> list[Component]: """Get all singleton component instances.""" return list( - self.container_manager.singleton_component_instance_container.values() + self.container_manager.component_instances.values() ) def get_singleton_bean_instances(self) -> list[object]: """Get all singleton bean instances.""" - return list(self.container_manager.singleton_bean_instance_container.values()) + return list(self.container_manager.bean_instances.values()) def is_within_context(self, entity_cls: Type[AppEntities]) -> bool: """Check if an entity class is registered in the application context.""" @@ -663,8 +663,8 @@ def inject_dependencies_for_external_object(self, object: Type[Any]) -> None: def inject_dependencies_for_app_entities(self) -> None: """Inject dependencies for all registered app entities.""" containers: list[Mapping[str, Type[AppEntities]]] = [ - self.container_manager.component_cls_container, - self.container_manager.controller_cls_container, + self.container_manager.component_classes, + self.container_manager.controller_classes, ] for container in containers: diff --git a/py_spring_core/core/application/py_spring_application.py b/py_spring_core/core/application/py_spring_application.py index cb1dd64..1d18b17 100644 --- a/py_spring_core/core/application/py_spring_application.py +++ b/py_spring_core/core/application/py_spring_application.py @@ -111,7 +111,7 @@ def __init__( ) self.shutdown_handler: Optional[GracefulShutdownHandler] = None - def __configure_logging(self): + def _configure_logging(self): """Applies the logging configuration using Loguru.""" config = self.app_config.loguru_config if not config.log_file_path: @@ -125,7 +125,7 @@ def __configure_logging(self): retention=config.log_retention, serialize=config.format == LogFormat.JSON, ) - self.__configure_uvicorn_logging() + self._configure_uvicorn_logging() def _get_system_managed_classes(self) -> Iterable[Type[Component]]: return [ApplicationEventPublisher, ApplicationEventHandlerRegistry] @@ -205,7 +205,7 @@ def _prepare_injected_classes(self) -> Iterable[Type[object]]: ] return classes_to_inject - def __init_app(self) -> None: + def _init_app(self) -> None: classes_to_inject = self._prepare_injected_classes() self._inject_application_context_to_context_required(classes_to_inject) self._register_app_entities(classes_to_inject) @@ -229,7 +229,7 @@ def _handle_singleton_components_life_cycle( case ComponentLifeCycle.Destruction: component.finish_destruction_cycle() - def __init_controllers(self) -> None: + def _init_controllers(self) -> None: controllers = self.app_context.get_controller_instances() for controller in controllers: name = controller.__class__.__name__ @@ -245,7 +245,6 @@ def _init_external_handler(self, base_class: Type[SingleInheritanceRequiredT]) - Args: base_class: The base class to get subclass from - handler_type: The type of handler for logging purposes Returns: The initialized handler class or None if no handler is found @@ -254,37 +253,35 @@ def _init_external_handler(self, base_class: Type[SingleInheritanceRequiredT]) - RuntimeError: If the handler has unimplemented abstract methods """ handler_type = base_class.__name__ - self_defined_handler_cls = base_class.get_subclass() - if self_defined_handler_cls is None: + handler_cls = base_class.get_subclass() + if handler_cls is None: logger.debug(f"[{handler_type} INIT] No self defined {handler_type.lower()} class found") return None - unimplemented_abstract_methods = framework_utils.get_unimplemented_abstract_methods(self_defined_handler_cls) + unimplemented_abstract_methods = framework_utils.get_unimplemented_abstract_methods(handler_cls) if len(unimplemented_abstract_methods) > 0: - error_message = f"[{handler_type} INIT] Self defined {handler_type.lower()} class: {self_defined_handler_cls.__name__} has unimplemented abstract methods: {unimplemented_abstract_methods}" + error_message = f"[{handler_type} INIT] Self defined {handler_type.lower()} class: {handler_cls.__name__} has unimplemented abstract methods: {unimplemented_abstract_methods}" logger.error(error_message) raise RuntimeError(error_message) logger.debug( - f"[{handler_type} INIT] Self defined {handler_type.lower()} class: {self_defined_handler_cls.__name__}" + f"[{handler_type} INIT] Self defined {handler_type.lower()} class: {handler_cls.__name__}" ) logger.debug( - f"[{handler_type} INIT] Inject dependencies for external object: {self_defined_handler_cls.__name__}" - ) - self.app_context.inject_dependencies_for_external_object( - self_defined_handler_cls + f"[{handler_type} INIT] Inject dependencies for external object: {handler_cls.__name__}" ) - return self_defined_handler_cls + self.app_context.inject_dependencies_for_external_object(handler_cls) + return handler_cls - def __init_middlewares(self) -> None: + def _init_middlewares(self) -> None: handler_type = MiddlewareRegistry.__name__ logger.debug(f"[{handler_type} INIT] Initialize middlewares...") - middeware_configuration_cls = self._init_external_handler(MiddlewareConfiguration) - if middeware_configuration_cls is None: + middleware_config_cls = self._init_external_handler(MiddlewareConfiguration) + if middleware_config_cls is None: return registry = MiddlewareRegistry() logger.info(f"[{handler_type} INIT] Setup middlewares for registry: {registry.__class__.__name__}") - middeware_configuration_cls().configure_middlewares(registry) + middleware_config_cls().configure_middlewares(registry) logger.info(f"[{handler_type} INIT] Middlewares setup for registry: {registry.__class__.__name__} completed") middleware_classes: list[Type[Middleware]] = registry.get_middleware_classes() logger.info(f"[{handler_type} INIT] Middleware classes: {', '.join([middleware_class.__name__ for middleware_class in middleware_classes])}") @@ -296,7 +293,7 @@ def __init_middlewares(self) -> None: registry.apply_middlewares(self.fastapi) logger.debug(f"[{handler_type} INIT] Middlewares initialized") - def __init_graceful_shutdown(self) -> None: + def _init_graceful_shutdown(self) -> None: handler_type = GracefulShutdownHandler.__name__ logger.debug(f"[{handler_type} INIT] Initialize graceful shutdown...") handler_cls: Optional[Type[GracefulShutdownHandler]] = self._init_external_handler(GracefulShutdownHandler) @@ -314,7 +311,7 @@ def __init_graceful_shutdown(self) -> None: ) # type: ignore logger.debug(f"[{handler_type} INIT] Graceful shutdown initialized") - def __configure_uvicorn_logging(self): + def _configure_uvicorn_logging(self): """Configure Uvicorn to use Loguru instead of default logging.""" # Configure Uvicorn to use Loguru @@ -341,7 +338,7 @@ def emit(self, record): log_level = self.app_config.loguru_config.log_level.value logging.basicConfig(handlers=[InterceptHandler()], level=log_level, force=True) - def __run_server(self) -> None: + def _run_server(self) -> None: # Run uvicorn server uvicorn.run( self.fastapi, @@ -352,13 +349,13 @@ def __run_server(self) -> None: def run(self) -> None: try: - self.__configure_logging() - self.__init_app() - self.__init_controllers() - self.__init_middlewares() - self.__init_graceful_shutdown() + self._configure_logging() + self._init_app() + self._init_controllers() + self._init_middlewares() + self._init_graceful_shutdown() if self.app_config.server_config.enabled: - self.__run_server() + self._run_server() finally: # Handle component lifecycle destruction self._handle_singleton_components_life_cycle(ComponentLifeCycle.Destruction) diff --git a/py_spring_core/core/entities/component/component.py b/py_spring_core/core/entities/component/component.py index 3a3aba0..55939f6 100644 --- a/py_spring_core/core/entities/component/component.py +++ b/py_spring_core/core/entities/component/component.py @@ -34,7 +34,7 @@ class Component: The `get_scope()` and `set_scope()` methods allow you to get and set the scope of the component. The lifecycle hooks are: - - `post_initialize()`: Called after the component is initialized. + - `post_construct()`: Called after the component is initialized. - `pre_destroy()`: Called before the component is destroyed. The `finish_initialization_cycle()` and `finish_destruction_cycle()` methods are final and call the corresponding lifecycle hooks in the correct order. diff --git a/py_spring_core/core/interfaces/single_inheritance_required.py b/py_spring_core/core/interfaces/single_inheritance_required.py index 211a39b..525ea1a 100644 --- a/py_spring_core/core/interfaces/single_inheritance_required.py +++ b/py_spring_core/core/interfaces/single_inheritance_required.py @@ -6,7 +6,8 @@ class SingleInheritanceRequired(Generic[T], ABC): """ - A singleton component is a component that only allow subclasses to be inherited. + A base class that ensures only one subclass can be inherited from it. + This enforces the single inheritance constraint for specific component types. """ @classmethod diff --git a/tests/test_application_context.py b/tests/test_application_context.py index 9da4569..ae159c6 100644 --- a/tests/test_application_context.py +++ b/tests/test_application_context.py @@ -37,26 +37,26 @@ class TestProperties(Properties): app_context.register_properties(TestProperties) assert ( - "TestComponent" in app_context.container_manager.component_cls_container - and app_context.container_manager.component_cls_container["TestComponent"] + "TestComponent" in app_context.container_manager.component_classes + and app_context.container_manager.component_classes["TestComponent"] == TestComponent ) assert ( - "TestController" in app_context.container_manager.controller_cls_container - and app_context.container_manager.controller_cls_container["TestController"] + "TestController" in app_context.container_manager.controller_classes + and app_context.container_manager.controller_classes["TestController"] == TestController ) assert ( "TestBeanCollection" - in app_context.container_manager.bean_collection_cls_container - and app_context.container_manager.bean_collection_cls_container[ + in app_context.container_manager.bean_collection_classes + and app_context.container_manager.bean_collection_classes[ "TestBeanCollection" ] == TestBeanCollection ) assert ( - "test_properties" in app_context.container_manager.properties_cls_container - and app_context.container_manager.properties_cls_container[ + "test_properties" in app_context.container_manager.properties_classes + and app_context.container_manager.properties_classes[ "test_properties" ] == TestProperties @@ -77,16 +77,16 @@ class TestProperties(Properties): app_context.register_bean_collection(TestBeanCollection) app_context.register_properties(TestProperties) - assert "TestComponent" in app_context.container_manager.component_cls_container + assert "TestComponent" in app_context.container_manager.component_classes assert ( - "TestController" in app_context.container_manager.controller_cls_container + "TestController" in app_context.container_manager.controller_classes ) assert ( "TestBeanCollection" - in app_context.container_manager.bean_collection_cls_container + in app_context.container_manager.bean_collection_classes ) assert ( - "test_properties" in app_context.container_manager.properties_cls_container + "test_properties" in app_context.container_manager.properties_classes ) def test_register_invalid_entities_raises_error( @@ -142,7 +142,7 @@ class TestProperties(Properties): # Test retrieving singleton components component_instance = TestComponent() - app_context.container_manager.singleton_component_instance_container[ + app_context.container_manager.component_instances[ "TestComponent" ] = component_instance retrieved_component = app_context.get_component(TestComponent, None) @@ -150,7 +150,7 @@ class TestProperties(Properties): # Test retrieving singleton beans bean_instance = TestBeanCollection() - app_context.container_manager.singleton_bean_instance_container[ + app_context.container_manager.bean_instances[ "TestBeanCollection" ] = bean_instance retrieved_bean = app_context.get_bean(TestBeanCollection, None) @@ -158,7 +158,7 @@ class TestProperties(Properties): # Test retrieving singleton properties properties_instance = TestProperties() - app_context.container_manager.singleton_properties_instance_container[ + app_context.container_manager.properties_instances[ "test_properties" ] = properties_instance retrieved_properties = app_context.get_properties(TestProperties) diff --git a/tests/test_component_features.py b/tests/test_component_features.py index 1aa26b1..000d7f8 100644 --- a/tests/test_component_features.py +++ b/tests/test_component_features.py @@ -107,17 +107,16 @@ class Config: def process(self) -> str: return "Test service processing" - # Register component first time + # Register the same component multiple times + app_context.register_component(TestService) + initial_count = len(app_context.container_manager.component_classes) app_context.register_component(TestService) - initial_count = len(app_context.container_manager.component_cls_container) - - # Register same component again - should be silently skipped app_context.register_component(TestService) - final_count = len(app_context.container_manager.component_cls_container) + final_count = len(app_context.container_manager.component_classes) - # Verify component count didn't change (no duplicate registration) - assert final_count == initial_count - assert "TestService" in app_context.container_manager.component_cls_container + # Verify the component is registered only once + assert initial_count == final_count == 1 + assert "TestService" in app_context.container_manager.component_classes def test_component_name_override(self, app_context: ApplicationContext): """ @@ -145,13 +144,12 @@ def process(self) -> str: app_context.register_component(TestService) app_context.init_ioc_container() - # Verify component is registered with custom name + # Check if component is registered with custom name assert ( - "CustomServiceName" in app_context.container_manager.component_cls_container + "CustomServiceName" in app_context.container_manager.component_classes ) - assert ( - app_context.container_manager.component_cls_container["CustomServiceName"] - == TestService + component_instance = ( + app_context.container_manager.component_classes["CustomServiceName"] ) def test_qualifier_with_invalid_component(self, app_context: ApplicationContext):