diff --git a/.github/workflows/pypi-deployment.yaml b/.github/workflows/pypi-deployment.yaml index 8ec2e1f..16f19ed 100644 --- a/.github/workflows/pypi-deployment.yaml +++ b/.github/workflows/pypi-deployment.yaml @@ -8,6 +8,7 @@ on: jobs: publish: runs-on: ubuntu-latest + if: ${{ !contains(github.event.head_commit.message, 'skip ci') && !contains(github.event.pull_request.title, 'skip ci') && (github.ref == 'refs/heads/main' || github.ref_type == 'tag') }} steps: - name: Checkout code @@ -25,7 +26,11 @@ jobs: - name: Install dependencies run: | - pdm install --prod + pdm install + + - name: Run tests + run: | + pdm run pytest - name: Build the package run: | diff --git a/.gitignore b/.gitignore index efa407c..280fb4c 100644 --- a/.gitignore +++ b/.gitignore @@ -159,4 +159,10 @@ cython_debug/ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. -#.idea/ \ No newline at end of file +#.idea/ + +src/ +app-config.json +application-properties.json +main.py +pdm.lock \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..c3bc92f --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025 PythonSpring + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md index 47bf0a5..aa1cbfd 100644 --- a/README.md +++ b/README.md @@ -1,11 +1,9 @@ # **PySpring** Framework -#### **PySpring** is a Python web framework inspired by Spring Boot. It combines FastAPI for the web layer, SQLModel for ORM, and Pydantic for data validation. PySpring provides a structured approach to building scalable web applications with `auto dependency injection`, `auto configuration management` and a `web server` for hosting your application. +#### **PySpring** is a Python web framework inspired by Spring Boot. It combines FastAPI for the web layer, and Pydantic for data validation. PySpring provides a structured approach to building scalable web applications with `auto dependency injection`, `auto configuration management` and a `web server` for hosting your application. -# Key Features -- Application Initialization: **PySpringApplication** class serves as the main entry point for the **PySpring** application. It initializes the application from a configuration file, scans the application source directory for Python files, and groups them into class files and model files - -- **Model Import and Table Creation**: **PySpring** dynamically imports model modules and creates SQLModel tables based on the imported models. It supports SQLAlchemy for database operations. +## Key Features +- **Application Initialization**: `PySpringApplication` class serves as the main entry point for the **PySpring** application. It initializes the application from a configuration file, scans the application source directory for Python files, and groups them into class files and model files. - **Application Context Management**: **PySpring** manages the application context and dependency injection. It registers application entities such as components, controllers, bean collections, and properties. It also initializes the application context and injects dependencies. @@ -19,36 +17,93 @@ - **Builtin FastAPI Integration**: **PySpring** integrates with `FastAPI`, a modern, fast (high-performance), web framework for building APIs with Python. It leverages FastAPI's features for routing, request handling, and server configuration. -# Getting Started -To get started with **PySpring**, follow these steps: - -### 1. Install the **PySpring** framework by running: - -`pip3 install git+https://github.com/PythonSpring/pyspring-core.git` - +## Project Structure +``` +PySpring/ +├── src/ # Source code directory +├── tests/ # Test files +├── logs/ # Application logs +├── py_spring_core/ # Core framework package +├── app-config.json # Application configuration +├── application-properties.json # Application properties +├── main.py # Application entry point +├── pyproject.toml # Project metadata and dependencies +└── README.md # Project documentation +``` -### 2. Create a new Python project and navigate to its directory +## Getting Started -- Implement your application properties, components, controllers, using **PySpring** conventions inside declared source code folder (whcih can be modified the key `app_src_target_dir` inside app-config.json), this controls what folder will be scanned by the framework. +### Prerequisites +- Python 3.10 or higher +- pip (Python package installer) -- Instantiate a `PySpringApplication` object in your main script, passing the path to your application configuration file. +### Installation +1. Install the **PySpring** framework: +```bash +pip install py-spring-core +``` -- Optionally, define and enable any framework modules you want to use. +2. Create a new Python project and navigate to its directory -- Run the application by calling the `run()` method on the `PySpringApplication` object, as shown in the example code below: +3. Set up your application: + - Implement your application properties, components, and controllers using **PySpring** conventions inside the declared source code folder (which can be modified via the `app_src_target_dir` key in `app-config.json`) + - Create an `app-config.json` file for your application configuration + - Create an `application-properties.json` file for your application properties +4. Create your main application script: ```python -from py_spring import PySpringApplication +from py_spring_core import PySpringApplication def main(): - app = PySpringApplication() + app = PySpringApplication("./app-config.json") app.run() if __name__ == "__main__": main() ``` -- For example project, please refer to this [github repo](https://github.com/NFUChen/PySpring-Example-Project). -# Contributing +5. Run your application: +```bash +python main.py +``` + +For a complete example project, please refer to the [PySpring Example Project](https://github.com/NFUChen/PySpring-Example-Project). + +## Development Setup +1. Clone the repository: +```bash +git clone https://github.com/NFUChen/PySpring.git +cd PySpring +``` + +2. Install development dependencies: +```bash +pip install -e ".[dev]" +``` + +3. Run tests: +```bash +pytest +``` + +## Dependencies +PySpring relies on several key dependencies: +- FastAPI (0.112.0) +- Pydantic (2.8.2) +- Uvicorn (0.30.5) +- Loguru (0.7.2) +- And other supporting packages + +For a complete list of dependencies, see `pyproject.toml`. + +## Contributing + +Contributions to **PySpring** are welcome! If you find any issues or have suggestions for improvements, please: +1. Fork the repository +2. Create a feature branch +3. Commit your changes +4. Push to the branch +5. Create a Pull Request -Contributions to **PySpring** are welcome! If you find any issues or have suggestions for improvements, please submit a pull request or open an issue on GitHub. \ No newline at end of file +## License +This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details. diff --git a/py_spring_core/__init__.py b/py_spring_core/__init__.py index de4cba4..1025556 100644 --- a/py_spring_core/__init__.py +++ b/py_spring_core/__init__.py @@ -1,8 +1,50 @@ from py_spring_core.core.application.py_spring_application import PySpringApplication -from py_spring_core.core.entities.bean_collection import BeanCollection -from py_spring_core.core.entities.component import Component, ComponentScope +from py_spring_core.core.entities.bean_collection.bean_collection import BeanCollection +from py_spring_core.core.entities.component.component import Component, ComponentScope from py_spring_core.core.entities.controllers.rest_controller import RestController +from py_spring_core.core.entities.controllers.route_mapping import ( + DeleteMapping, + GetMapping, + PatchMapping, + PostMapping, + PutMapping, +) +from py_spring_core.core.entities.entity_provider.entity_provider import EntityProvider +from py_spring_core.core.entities.middlewares.middleware import Middleware +from py_spring_core.core.entities.middlewares.middleware_registry import ( + MiddlewareRegistry, MiddlewareConfiguration +) from py_spring_core.core.entities.properties.properties import Properties -from py_spring_core.core.entities.entity_provider import EntityProvider +from py_spring_core.core.interfaces.application_context_required import ( + ApplicationContextRequired, +) +from py_spring_core.core.interfaces.graceful_shutdown_handler import GracefulShutdownHandler, ShutdownType +from py_spring_core.event.application_event_handler_registry import EventListener +from py_spring_core.event.application_event_publisher import ApplicationEventPublisher +from py_spring_core.event.commons import ApplicationEvent -__version__ = "0.0.4" \ No newline at end of file +__version__ = "0.0.25" + +__all__ = [ + "PySpringApplication", + "BeanCollection", + "Component", + "ComponentScope", + "RestController", + "DeleteMapping", + "GetMapping", + "PatchMapping", + "PostMapping", + "PutMapping", + "EntityProvider", + "Properties", + "ApplicationContextRequired", + "ApplicationEventPublisher", + "ApplicationEvent", + "EventListener", + "Middleware", + "MiddlewareRegistry", + "MiddlewareConfiguration", + "GracefulShutdownHandler", + "ShutdownType" +] \ No newline at end of file diff --git a/py_spring_core/commons/class_scanner.py b/py_spring_core/commons/class_scanner.py index dfd4496..178d366 100644 --- a/py_spring_core/commons/class_scanner.py +++ b/py_spring_core/commons/class_scanner.py @@ -1,9 +1,11 @@ import ast -import importlib.util -from typing import Iterable, Optional, Type +import re +from typing import Any, Iterable, Type from loguru import logger +from .module_importer import ModuleImporter + class ClassScanner: """ @@ -17,6 +19,8 @@ class ClassScanner: def __init__(self, file_paths: Iterable[str]) -> None: self.file_paths = file_paths self.scanned_classes: dict[str, dict[str, Type[object]]] = {} + # Use ModuleImporter for handling module imports + self._module_importer = ModuleImporter() def extract_classes_from_file(self, file_path: str) -> dict[str, Type[object]]: with open(file_path, "r") as file: @@ -44,8 +48,12 @@ def _extract_classes_from_file_content( return class_objects - def scan_classes_for_file_paths(self) -> None: + def scan_classes_for_file_paths(self, exclude_file_patterns: Iterable[str]) -> None: for file_path in self.file_paths: + if any(re.match(pattern, file_path) for pattern in exclude_file_patterns): + logger.debug(f"[EXCLUDED FILE] {file_path} is excluded") + continue + object_cls_dict: dict[str, Type[object]] = self.extract_classes_from_file( file_path ) @@ -53,14 +61,12 @@ def scan_classes_for_file_paths(self) -> None: def import_class_from_file( self, file_path: str, class_name: str - ) -> Optional[Type[object]]: - spec = importlib.util.spec_from_file_location(class_name, file_path) - if spec is None: - return - module = importlib.util.module_from_spec(spec) - if spec.loader is None: - return - spec.loader.exec_module(module) + ) -> Type[object] | None: + # Use ModuleImporter to handle module import + module = self._module_importer.import_module_from_path(file_path) + if module is None: + return None + cls = getattr(module, class_name, None) return cls @@ -79,3 +85,11 @@ def display_classes(self) -> None: repr += f" Class: {class_name}\n" logger.debug(repr) + + def clear_module_cache(self) -> None: + """Clear the module cache. Useful for testing or when you need to force re-import.""" + self._module_importer.clear_cache() + + def get_cache_size(self) -> int: + """Get the number of cached modules.""" + return self._module_importer.get_cache_size() diff --git a/py_spring_core/commons/config_file_template_generator/templates.py b/py_spring_core/commons/config_file_template_generator/templates.py index fdb974f..6160d15 100644 --- a/py_spring_core/commons/config_file_template_generator/templates.py +++ b/py_spring_core/commons/config_file_template_generator/templates.py @@ -1,8 +1,12 @@ +from typing import Any + app_config_template = { "app_src_target_dir": "./src", "server_config": {"host": "0.0.0.0", "port": 8080, "enabled": True}, "properties_file_path": "./application-properties.json", "loguru_config": {"log_file_path": "./logs/app.log", "log_level": "DEBUG"}, + "type_checking_mode": "strict", + "shutdown_config": {"timeout_seconds": 30.0, "enabled": True}, } -app_properties_template = {} +app_properties_template: dict[str, Any] = {} diff --git a/py_spring_core/commons/json_config_repository.py b/py_spring_core/commons/json_config_repository.py index cf5e6f3..65ef43c 100644 --- a/py_spring_core/commons/json_config_repository.py +++ b/py_spring_core/commons/json_config_repository.py @@ -24,14 +24,14 @@ class JsonConfigRepository(Generic[T]): """ def __init__(self, file_path: str, target_key: Optional[str] = None) -> None: - self.base_model_cls: BaseModel = self.__class__._get_model_cls() # type: ignore + self.base_model_cls: Type[T] = self.__class__._get_model_cls() self.file_path = file_path self.target_key = target_key self._config: T = self._load_config() @classmethod def _get_model_cls(cls) -> Type[T]: - return get_args(tp=cls.__mro__[0].__orig_bases__[0])[0] + return get_args(cls.__orig_bases__[0])[0] # type: ignore def get_config(self) -> T: return self._config @@ -41,12 +41,12 @@ def reload_config(self) -> None: def save_config(self) -> None: is_the_same_class = ( - self._config.__class__.__name__ == self.base_model_cls.__name__ # type: ignore - ) # type: ignore + self._config.__class__.__name__ == self.base_model_cls.__name__ + ) if not is_the_same_class: raise TypeError( - f"[BASE MODEL CLASS TYPE MISMATCH] Base model class of current repository: {self.base_model_cls.__name__} mismatch with config class: {self._config.__class__.__name__}" # type: ignore - ) # type: ignore + f"[BASE MODEL CLASS TYPE MISMATCH] Base model class of current repository: {self.base_model_cls.__name__} mismatch with config class: {self._config.__class__.__name__}" + ) with open(self.file_path, "w") as file: file.write(self._config.model_dump_json(indent=4)) @@ -56,16 +56,16 @@ def save_config_to_target_path(self, file_path: str) -> None: def _load_config(self) -> T: with open(self.file_path, "r") as file: - if BaseModel not in self.base_model_cls.__mro__: # type: ignore + if BaseModel not in self.base_model_cls.__mro__: raise TypeError( "[BASE MODEL INHERITANCE REQUIRED] JsonConfigRepository required model class being inherited from pydantic.BaseModel for marshalling JSON into python object." ) if self.target_key is None: - return self.base_model_cls.model_validate_json(file.read()) # type: ignore + return self.base_model_cls.model_validate_json(file.read()) target_py_object = json.loads(file.read()) if self.target_key not in target_py_object: raise ValueError( f"[TARGET KEY NOT FOUND] Target key: {self.target_key} not found" ) - return self.base_model_cls.model_validate(target_py_object[self.target_key]) # type: ignore + return self.base_model_cls.model_validate(target_py_object[self.target_key]) diff --git a/py_spring_core/commons/module_importer.py b/py_spring_core/commons/module_importer.py new file mode 100644 index 0000000..6b20108 --- /dev/null +++ b/py_spring_core/commons/module_importer.py @@ -0,0 +1,131 @@ +import importlib.util +import inspect +from pathlib import Path +from typing import Any, Iterable, Type, Optional + +from loguru import logger + + +class ModuleImporter: + """ + A class that handles dynamic module importing with caching capabilities. + Provides functionality to import modules from file paths and extract classes from them. + """ + + def __init__(self) -> None: + # Module cache to prevent duplicate imports + self._module_cache: dict[str, Any] = {} + + def import_module_from_path(self, file_path: str) -> Optional[Any]: + """ + Import a module from a file path with caching. + + Args: + file_path (str): The file path of the module to import. + + Returns: + Optional[Any]: The imported module or None if import fails. + """ + resolved_path = Path(file_path).resolve() + cache_key = str(resolved_path) + module_name = resolved_path.stem + + # Check if module is already cached + if cache_key in self._module_cache: + logger.debug(f"[MODULE CACHE] Using cached module: {module_name}") + return self._module_cache[cache_key] + + logger.info(f"[MODULE IMPORT] Import module path: {resolved_path}") + + # Create a module specification + spec = importlib.util.spec_from_file_location(module_name, resolved_path) + if spec is None: + logger.warning(f"[MODULE IMPORT] Could not create spec for {module_name}") + return None + + # Create a new module based on the specification + module = importlib.util.module_from_spec(spec) + if spec.loader is None: + logger.warning(f"[MODULE IMPORT] No loader found for {module_name}") + return None + + # Execute the module in its own namespace + logger.info(f"[MODULE IMPORT] Import module: {module_name}") + try: + spec.loader.exec_module(module) + logger.success(f"[MODULE IMPORT] Successfully imported {module_name}") + # Cache the module + self._module_cache[cache_key] = module + return module + except Exception as error: + logger.warning(f"[MODULE IMPORT] Failed to import {module_name}: {error}") + raise error + + def extract_classes_from_module(self, module: Any) -> list[Type[object]]: + """ + Extract all classes from a module. + + Args: + module (Any): The module to extract classes from. + + Returns: + list[Type[object]]: List of classes found in the module. + """ + loaded_classes = [] + for attr in dir(module): + obj = getattr(module, attr) + if attr.startswith("__"): + continue + if not inspect.isclass(obj): + continue + loaded_classes.append(obj) + return loaded_classes + + def import_classes_from_paths( + self, + file_paths: Iterable[str], + target_subclasses: Iterable[Type[object]] = [] + ) -> set[Type[object]]: + """ + Import classes from multiple file paths with optional filtering. + + Args: + file_paths (Iterable[str]): The file paths of the modules to import. + target_subclasses (Iterable[Type[object]], optional): Target subclasses to filter. Defaults to []. + + Returns: + set[Type[object]]: Set of imported classes. + """ + all_loaded_classes: list[Type[object]] = [] + + for file_path in file_paths: + module = self.import_module_from_path(file_path) + if module is None: + raise ImportError(f"Failed to import module from {file_path}") + + loaded_classes = self.extract_classes_from_module(module) + all_loaded_classes.extend(loaded_classes) + + returned_target_classes: set[Type[object]] = set() + + # If no target subclasses specified, return all loaded classes + if not target_subclasses: + returned_target_classes = set(all_loaded_classes) + else: + # Filter classes based on target subclasses + for target_cls in target_subclasses: + for loaded_class in all_loaded_classes: + if loaded_class in target_subclasses: + continue + if issubclass(loaded_class, target_cls): + returned_target_classes.add(loaded_class) + + return returned_target_classes + + def clear_cache(self) -> None: + """Clear the module cache. Useful for testing or when you need to force re-import.""" + self._module_cache.clear() + + def get_cache_size(self) -> int: + """Get the number of cached modules.""" + return len(self._module_cache) \ No newline at end of file diff --git a/py_spring_core/commons/type_checking_service.py b/py_spring_core/commons/type_checking_service.py new file mode 100644 index 0000000..d84ca89 --- /dev/null +++ b/py_spring_core/commons/type_checking_service.py @@ -0,0 +1,44 @@ +import subprocess +from enum import Enum +from typing import Optional + +from loguru import logger + + +class MypyTypeCheckingError(str, Enum): + NoUntypedDefs = "no-untyped-def" + + +class TypeCheckingError(Exception): ... + + +class TypeCheckingService: + def __init__(self, target_folder: str) -> None: + self.target_folder = target_folder + self.checking_command = ["mypy", "--disallow-untyped-defs", self.target_folder] + self.target_typing_errors: list[MypyTypeCheckingError] = [ + MypyTypeCheckingError.NoUntypedDefs + ] + + def type_checking(self) -> Optional[TypeCheckingError]: + logger.info("[MYPY TYPE CHECKING] Mypy checking types for projects...") + # Run mypy and capture stdout and stderr + result = subprocess.run( + self.checking_command, + capture_output=True, # Captures both stdout and stderr + text=True, # Ensures output is returned as a string + check=False, # Avoids raising an exception on non-zero exit code + ) + std_err_lines = result.stderr.split("\n") + std_out_lines = result.stdout.split("\n") + message: str = "" + for line in [*std_err_lines, *std_out_lines]: + if len(line) != 0: + logger.debug(f"[MYPY TYPE CHECKING] {line}") + for error in self.target_typing_errors: + if error in line: + error_message = f"\n{line}" + message += error_message + if len(message) == 0: + return None + return TypeCheckingError(message) diff --git a/py_spring_core/core/application/application_config.py b/py_spring_core/core/application/application_config.py index ba82ec1..8816185 100644 --- a/py_spring_core/core/application/application_config.py +++ b/py_spring_core/core/application/application_config.py @@ -1,8 +1,8 @@ +from enum import Enum + from pydantic import BaseModel, ConfigDict, Field -from py_spring_core.commons.json_config_repository import ( - JsonConfigRepository, -) +from py_spring_core.commons.json_config_repository import JsonConfigRepository from py_spring_core.core.application.loguru_config import LoguruConfig @@ -21,6 +21,19 @@ class ServerConfig(BaseModel): enabled: bool = Field(default=True) +class ShutdownConfig(BaseModel): + """ + Represents the configuration for graceful shutdown. + + Attributes: + timeout_seconds: The maximum time in seconds to wait for graceful shutdown before forcing termination. + enabled: A boolean flag indicating whether graceful shutdown timeout is enabled. + """ + + timeout_seconds: float = Field(default=30.0, description="Timeout in seconds for graceful shutdown") + enabled: bool = Field(default=True, description="Whether graceful shutdown timeout is enabled") + + class ApplicationConfig(BaseModel): """ Represents the configuration for the application. @@ -31,14 +44,17 @@ class ApplicationConfig(BaseModel): sqlalchemy_database_uri: The URI for the SQLAlchemy database connection. properties_file_path: The file path for the application properties. model_file_postfix_patterns: A list of file name patterns for model (for table creation) files. + shutdown_config: The configuration for graceful shutdown. """ model_config = ConfigDict(protected_namespaces=()) app_src_target_dir: str + exclude_file_patterns: list[str] = Field(default_factory=lambda: [r".*/models\.py$"]) server_config: ServerConfig properties_file_path: str loguru_config: LoguruConfig + shutdown_config: ShutdownConfig = Field(default_factory=ShutdownConfig) class ApplicationConfigRepository(JsonConfigRepository[ApplicationConfig]): diff --git a/py_spring_core/core/application/commons.py b/py_spring_core/core/application/commons.py index 60175df..3d2ebab 100644 --- a/py_spring_core/core/application/commons.py +++ b/py_spring_core/core/application/commons.py @@ -1,7 +1,6 @@ -from py_spring_core.core.entities.bean_collection import BeanCollection -from py_spring_core.core.entities.component import Component +from py_spring_core.core.entities.bean_collection.bean_collection import BeanCollection +from py_spring_core.core.entities.component.component import Component from py_spring_core.core.entities.controllers.rest_controller import RestController from py_spring_core.core.entities.properties.properties import Properties - AppEntities = Component | RestController | BeanCollection | Properties diff --git a/py_spring_core/core/application/context/application_context.py b/py_spring_core/core/application/context/application_context.py index 9ead72e..938aeaf 100644 --- a/py_spring_core/core/application/context/application_context.py +++ b/py_spring_core/core/application/context/application_context.py @@ -1,325 +1,690 @@ +from abc import ABC from inspect import isclass -from typing import Callable, Mapping, Optional, Type, TypeVar, cast +from typing import ( + Annotated, + Any, + Callable, + Mapping, + Optional, + Type, + TypeVar, + cast, + get_args, + get_origin, +) +from fastapi import FastAPI from loguru import logger from pydantic import BaseModel +import py_spring_core.core.utils as framework_utils from py_spring_core.core.application.commons import AppEntities from py_spring_core.core.application.context.application_context_config import ( ApplicationContextConfig, ) -from py_spring_core.core.entities.bean_collection import ( +from py_spring_core.core.entities.bean_collection.bean_collection import ( BeanCollection, BeanConflictError, + BeanView, InvalidBeanError, ) -from py_spring_core.core.entities.component import Component, ComponentScope +from py_spring_core.core.entities.component.component import Component, ComponentScope from py_spring_core.core.entities.controllers.rest_controller import RestController -from py_spring_core.core.entities.entity_provider import EntityProvider +from py_spring_core.core.entities.entity_provider.entity_provider import EntityProvider from py_spring_core.core.entities.properties.properties import Properties from py_spring_core.core.entities.properties.properties_loader import _PropertiesLoader - T = TypeVar("T", bound=AppEntities) PT = TypeVar("PT", bound=Properties) -class ComponentNotFoundError(Exception): ... +class ComponentNotFoundError(Exception): + """Raised when a component is not found in the application context.""" + + pass + +class InvalidDependencyError(Exception): + """Raised when a dependency is invalid or not found in the application context.""" -class InvalidDependencyError(Exception): ... + pass class ApplicationContextView(BaseModel): + """View model for application context state.""" + config: ApplicationContextConfig - component_cls_container: list[str] - singleton_component_instance_container: list[str] + component_classes: list[str] + component_instances: list[str] -class ApplicationContext: - """ - The `ApplicationContext` class is the main entry point for the application's context management. - It is responsible for: - 1. Registering and managing the lifecycle of components, controllers, bean collections, and properties. - 2. Providing methods to retrieve instances of registered components, beans, and properties. - 3. Initializing the Inversion of Control (IoC) container by creating singleton instances of registered components. - 4. Injecting dependencies for registered components and controllers. - The `ApplicationContext` class is designed to follow the Singleton design pattern, ensuring that there is a single instance of the application context throughout the application's lifetime. - """ +class ContainerManager: + """Manages containers for different types of entities in the application context.""" - def __init__(self, config: ApplicationContextConfig) -> None: - self.all_file_paths: set[str] = set() + def __init__(self): + self.component_classes: dict[str, Type[Component]] = {} + self.controller_classes: dict[str, Type[RestController]] = {} + self.component_instances: dict[str, Component] = {} + + self.bean_collection_classes: dict[str, Type[BeanCollection]] = {} + self.bean_instances: dict[str, object] = {} + + self.properties_classes: dict[str, Type[Properties]] = {} + self.properties_instances: dict[str, Properties] = {} + + def is_entity_in_container(self, entity_cls: Type[AppEntities]) -> bool: + """Check if an entity class is registered in any container.""" + cls_name = entity_cls.__name__ + return ( + cls_name in self.component_classes + or cls_name in self.controller_classes + or cls_name in self.bean_collection_classes + or cls_name in self.properties_classes + ) + + +class DependencyInjector: + """Handles dependency injection for entities in the application context.""" + + def __init__(self, container_manager: ContainerManager): + self.container_manager = container_manager self.primitive_types = (bool, str, int, float, type(None)) + self._app_context: Optional["ApplicationContext"] = None + + def _extract_qualifier_from_annotation( + self, annotated_type: Type + ) -> tuple[Type, Optional[str]]: + """Extract the actual type and qualifier from an Annotated type.""" + qualifier = None + if get_origin(annotated_type) is Annotated: + type_args = get_args(annotated_type) + annotated_type = type_args[0] + if len(type_args) > 1: + qualifier = type_args[1] + return annotated_type, qualifier + + def _inject_properties_dependency( + self, + entity: Type[AppEntities], + attr_name: str, + properties_cls: Type[Properties], + ) -> bool: + """Inject a properties dependency into an entity.""" + if self._app_context is None: + return False + + optional_properties = self._app_context.get_properties(properties_cls) + if optional_properties is None: + raise TypeError( + f"[PROPERTIES INJECTION ERROR] Properties: {properties_cls.get_name()} " + f"is not found in properties file for class: {properties_cls.get_name()} " + f"with key: {properties_cls.get_key()}" + ) + setattr(entity, attr_name, optional_properties) + return True + + def _try_inject_entity_dependency( + self, + entity: Type[AppEntities], + attr_name: str, + entity_cls: Type[AppEntities], + qualifier: Optional[str], + ) -> bool: + """Try to inject an entity dependency using available getters.""" + if self._app_context is None: + return False + + entity_getters: list[ + Callable[[Type[AppEntities], Optional[str]], Optional[AppEntities]] + ] = [self._app_context.get_component, self._app_context.get_bean] + + for getter in entity_getters: + optional_entity = getter(entity_cls, qualifier) + if optional_entity is not None: + setattr(entity, attr_name, optional_entity) + logger.success( + f"[DEPENDENCY INJECTION SUCCESS] Inject dependency for {entity_cls.__name__} " + f"in attribute: {attr_name}" + ) + return True + return False + + def inject_dependencies(self, entity: Type[AppEntities]) -> None: + """Inject dependencies for a given entity based on its annotations.""" + for attr_name, annotated_type in entity.__annotations__.items(): + entity_cls, qualifier = self._extract_qualifier_from_annotation( + annotated_type + ) - self.config = config - self.component_cls_container: dict[str, Type[Component]] = {} - self.controller_cls_container: dict[str, Type[RestController]] = {} - self.singleton_component_instance_container: dict[str, Component] = {} + # Skip primitive types + if entity_cls in self.primitive_types: + logger.warning( + f"[DEPENDENCY INJECTION SKIPPED] Skip inject dependency for attribute: {attr_name} " + f"with dependency: {entity_cls.__name__} because it is primitive type" + ) + continue - self.bean_collection_cls_container: dict[str, Type[BeanCollection]] = {} - self.singleton_bean_instance_container: dict[str, object] = {} + # Skip non-class types + if not isclass(entity_cls): + continue - self.properties_cls_container: dict[str, Type[Properties]] = {} - self.singleton_properties_instance_container: dict[str, Properties] = {} - self.providers: list[EntityProvider] = [] + # Handle Properties injection + if issubclass(entity_cls, Properties): + if self._inject_properties_dependency(entity, attr_name, entity_cls): + continue - def set_all_file_paths(self, all_file_paths: set[str]) -> None: - self.all_file_paths = all_file_paths + # Try to inject entity dependency + if self._try_inject_entity_dependency( + entity, attr_name, entity_cls, qualifier + ): + continue - def _create_properties_loader(self) -> _PropertiesLoader: - return _PropertiesLoader( - self.config.properties_path, list(self.properties_cls_container.values()) - ) + # If we get here, injection failed + error_message = ( + f"[DEPENDENCY INJECTION FAILED] Fail to inject dependency for attribute: {attr_name} " + f"with dependency: {entity_cls.__name__} with qualifier: {qualifier}, " + f"consider register such dependency with Component decorator" + ) + logger.critical(error_message) + raise ValueError(error_message) - def as_view(self) -> ApplicationContextView: - return ApplicationContextView( - config=self.config, - component_cls_container=list(self.component_cls_container.keys()), - singleton_component_instance_container=list( - self.singleton_component_instance_container.keys() - ), - ) - def get_component(self, component_cls: Type[T]) -> Optional[T]: +class ComponentManager: + """Manages component registration and instantiation.""" + + def __init__(self, container_manager: ContainerManager): + self.container_manager = container_manager + + def _determine_target_cls_name( + self, component_cls: Type[T], qualifier: Optional[str] + ) -> str: + """ + Determine the target class name for a given component class. + + Args: + component_cls: The component class to determine the name for + qualifier: Optional qualifier to use directly + + Returns: + The target class name + + Raises: + ValueError: If abstract class has no subclasses + """ + if qualifier is not None: + return qualifier + + # If it's not an ABC, return its name directly + if not issubclass(component_cls, ABC): + return component_cls.get_name() + + # If it's an ABC but has implementations, return its name directly + if not component_cls.__abstractmethods__: + return component_cls.get_name() + + # For abstract classes that need implementations + subclasses = component_cls.__subclasses__() + if len(subclasses) == 0: + raise ValueError( + f"[ABSTRACT CLASS ERROR] Abstract class {component_cls.__name__} has no subclasses" + ) + + # Fall back to first subclass if no primary component exists + return subclasses[0].get_name() + + def register_component(self, component_cls: Type[Component]) -> None: + """Register a component class in the container.""" if not issubclass(component_cls, Component): - return + raise TypeError( + f"[COMPONENT REGISTRATION ERROR] Component: {component_cls} " + f"is not a subclass of Component" + ) component_cls_name = component_cls.get_name() - if component_cls_name not in self.component_cls_container: + existing_component = self.container_manager.component_classes.get( + component_cls_name + ) + + # Check if it's the same component to avoid duplicate registration + if ( + existing_component + and existing_component.__name__ == component_cls.__name__ + and existing_component == component_cls + ): return + self.container_manager.component_classes[component_cls_name] = ( + component_cls + ) + + def get_component( + self, component_cls: Type[T], qualifier: Optional[str] + ) -> Optional[T]: + """Get a component instance by class and optional qualifier.""" + if not issubclass(component_cls, (Component, ABC)): + return None + + target_cls_name = self._determine_target_cls_name(component_cls, qualifier) + + if target_cls_name not in self.container_manager.component_classes: + return None + scope = component_cls.get_scope() match scope: case ComponentScope.Singleton: - optional_instance = self.singleton_component_instance_container.get( - component_cls_name + return cast( + T, + self.container_manager.component_instances.get( + target_cls_name + ), ) - return optional_instance # type: ignore - case ComponentScope.Prototype: - prototype_instance = component_cls() - return prototype_instance - - def is_within_context(self, _cls: Type[AppEntities]) -> bool: - cls_name = _cls.__name__ - is_within_component = cls_name in self.component_cls_container - is_within_controller = cls_name in self.controller_cls_container - is_within_bean_collection = cls_name in self.bean_collection_cls_container - is_within_properties = cls_name in self.properties_cls_container - return ( - is_within_component - or is_within_controller - or is_within_bean_collection - or is_within_properties - ) + return cast(T, component_cls()) + + def _init_singleton_component( + self, component_cls: Type[Component], component_cls_name: str + ) -> Optional[Component]: + """Initialize a singleton component instance.""" + try: + return component_cls() + except Exception as error: + if "Can't instantiate abstract class" in str(error): + logger.warning( + f"[INITIALIZING SINGLETON COMPONENT ERROR] Skip initializing singleton component: " + f"{component_cls_name} because it is an abstract class" + ) + return None + logger.error( + f"[INITIALIZING SINGLETON COMPONENT ERROR] Error initializing singleton component: " + f"{component_cls_name} with error: {error}" + ) + raise error + + def _get_abstract_class_component_subclasses( + self, component_cls: Type[ABC] + ) -> list[Type[Component]]: + """Get all Component subclasses of an abstract class.""" + return [ + subclass + for subclass in component_cls.__subclasses__() + if issubclass(subclass, Component) + ] - def get_bean(self, object_cls: Type[T]) -> Optional[T]: - bean_name = object_cls.__name__ - if bean_name not in self.singleton_bean_instance_container: - return + def _init_abstract_component_subclasses(self, component_cls: Type[ABC]) -> None: + """Initialize singleton instances for abstract component subclasses.""" + component_classes = self._get_abstract_class_component_subclasses(component_cls) - optional_instance = self.singleton_bean_instance_container.get(bean_name) - return optional_instance # type: ignore + for subclass_component_cls in component_classes: + self.register_component(subclass_component_cls) - def get_properties(self, properties_cls: Type[PT]) -> Optional[PT]: - properties_cls_name = properties_cls.get_key() - if properties_cls_name not in self.properties_cls_container: - return - optional_instance = cast( - PT, self.singleton_properties_instance_container.get(properties_cls_name) - ) - return optional_instance + # Check for unimplemented abstract methods + unimplemented_methods = framework_utils.get_unimplemented_abstract_methods( + subclass_component_cls + ) + if unimplemented_methods: + methods_str = ", ".join(unimplemented_methods) + message = ( + f"[ABSTRACT CLASS COMPONENT INITIALIZING SINGLETON COMPONENT] " + f"Unable to initialize singleton component: {subclass_component_cls.get_name()} " + f"because it has unimplemented abstract methods: {methods_str}" + ) + logger.error(message) + raise ValueError(message) - def register_component(self, component_cls: Type[Component]) -> None: - if not issubclass(component_cls, Component): - raise TypeError( - f"[COMPONENT REGISTRATION ERROR] Component: {component_cls} is not a subclass of Component" + logger.debug( + f"[ABSTRACT CLASS COMPONENT INITIALIZING SINGLETON COMPONENT] " + f"Init singleton component: {subclass_component_cls.get_name()}" ) - component_cls_name = component_cls.get_name() - self.component_cls_container[component_cls_name] = component_cls + instance = self._init_singleton_component( + subclass_component_cls, subclass_component_cls.get_name() + ) + if instance is not None: + self.container_manager.component_instances[ + subclass_component_cls.get_name() + ] = instance - def register_controller(self, controller_cls: Type[RestController]) -> None: - if not issubclass(controller_cls, RestController): - raise TypeError( - f"[CONTROLLER REGISTRATION ERROR] Controller: {controller_cls} is not a subclass of RestController" + def init_singleton_components(self) -> None: + """Initialize all singleton components in the container.""" + for ( + component_cls_name, + component_cls, + ) in self.container_manager.component_classes.items(): + if component_cls.get_scope() != ComponentScope.Singleton: + continue + + logger.debug( + f"[INITIALIZING SINGLETON COMPONENT] Init singleton component: {component_cls_name}" ) - controller_cls_name = controller_cls.get_name() - self.controller_cls_container[controller_cls_name] = controller_cls + if issubclass(component_cls, ABC): + self._init_abstract_component_subclasses(component_cls) + else: + instance = self._init_singleton_component( + component_cls, component_cls_name + ) + if instance is not None: + self.container_manager.component_instances[ + component_cls_name + ] = instance + + +class BeanManager: + """Manages bean collection registration and instantiation.""" + + def __init__( + self, + container_manager: ContainerManager, + dependency_injector: DependencyInjector, + ): + self.container_manager = container_manager + self.dependency_injector = dependency_injector def register_bean_collection(self, bean_cls: Type[BeanCollection]) -> None: + """Register a bean collection class in the container.""" if not issubclass(bean_cls, BeanCollection): raise TypeError( - f"[BEAN COLLECTION REGISTRATION ERROR] BeanCollection: {bean_cls} is not a subclass of BeanCollection" + f"[BEAN COLLECTION REGISTRATION ERROR] BeanCollection: {bean_cls} " + f"is not a subclass of BeanCollection" ) bean_name = bean_cls.get_name() - self.bean_collection_cls_container[bean_name] = bean_cls - - def register_entity_provider(self, provider: EntityProvider) -> None: - self.providers.append(provider) - for entity_cls in provider.get_entities(): - if issubclass(entity_cls, Component): - self.register_component(entity_cls) - elif issubclass(entity_cls, RestController): - self.register_controller(entity_cls) - elif issubclass(entity_cls, BeanCollection): - self.register_bean_collection(entity_cls) - elif issubclass(entity_cls, Properties): - self.register_properties(entity_cls) - - - + self.container_manager.bean_collection_classes[bean_name] = bean_cls + + def get_bean( + self, object_cls: Type[T], qualifier: Optional[str] = None + ) -> Optional[T]: + """Get a bean instance by class and optional qualifier.""" + bean_name = object_cls.__name__ + if bean_name not in self.container_manager.bean_instances: + return None + + return cast( + T, self.container_manager.bean_instances.get(bean_name) + ) + + def _inject_bean_collection_dependencies( + self, bean_collection_cls: Type[BeanCollection] + ) -> None: + """Inject dependencies for a bean collection.""" + logger.info( + f"[BEAN COLLECTION DEPENDENCY INJECTION] Injecting dependencies for {bean_collection_cls.get_name()}" + ) + self.dependency_injector.inject_dependencies(bean_collection_cls) + + def _validate_bean_view(self, view: BeanView, collection_name: str) -> None: + """Validate a bean view before adding it to the container.""" + if view.bean_name in self.container_manager.bean_instances: + raise BeanConflictError( + f"[BEAN CONFLICTS] Bean: {view.bean_name} already exists under collection: {collection_name}" + ) + + if not view.is_valid_bean(): + raise InvalidBeanError( + f"[INVALID BEAN] Bean name from bean creation func return type: {view.bean_name} " + f"does not match the bean object class name: {view.bean.__class__.__name__}" + ) + + def init_singleton_beans(self) -> None: + """Initialize all singleton beans from registered bean collections.""" + for ( + bean_collection_cls_name, + bean_collection_cls, + ) in self.container_manager.bean_collection_classes.items(): + logger.debug( + f"[INITIALIZING SINGLETON BEAN] Init singleton bean: {bean_collection_cls_name}" + ) + + collection = bean_collection_cls() + self._inject_bean_collection_dependencies(bean_collection_cls) + + bean_views = collection.scan_beans() + for view in bean_views: + self._validate_bean_view(view, collection.get_name()) + self.container_manager.bean_instances[ + view.bean_name + ] = view.bean + + +class PropertiesManager: + """Manages properties registration and loading.""" + + def __init__( + self, container_manager: ContainerManager, config: ApplicationContextConfig + ): + self.container_manager = container_manager + self.config = config + def register_properties(self, properties_cls: Type[Properties]) -> None: + """Register a properties class in the container.""" if not issubclass(properties_cls, Properties): raise TypeError( - f"[PROPERTIES REGISTRATION ERROR] Properties: {properties_cls} is not a subclass of Properties" + f"[PROPERTIES REGISTRATION ERROR] Properties: {properties_cls} " + f"is not a subclass of Properties" ) + properties_name = properties_cls.get_key() - self.properties_cls_container[properties_name] = properties_cls + self.container_manager.properties_classes[properties_name] = ( + properties_cls + ) - def get_controller_instances(self) -> list[RestController]: - return [_cls() for _cls in self.controller_cls_container.values()] + def get_properties(self, properties_cls: Type[PT]) -> Optional[PT]: + """Get a properties instance by class.""" + properties_cls_name = properties_cls.get_key() + if properties_cls_name not in self.container_manager.properties_classes: + return None - def get_singleton_component_instances(self) -> list[Component]: - return [_cls for _cls in self.singleton_component_instance_container.values()] + return cast( + PT, + self.container_manager.properties_instances.get( + properties_cls_name + ), + ) - def get_singleton_bean_instances(self) -> list[object]: - return [_cls for _cls in self.singleton_bean_instance_container.values()] + def _create_properties_loader(self) -> _PropertiesLoader: + """Create a properties loader instance.""" + return _PropertiesLoader( + self.config.properties_path, + list(self.container_manager.properties_classes.values()), + ) def load_properties(self) -> None: + """Load all registered properties from configuration files.""" properties_loader = self._create_properties_loader() properties_instance_dict = properties_loader.load_properties() - for properties_key, properties_cls in self.properties_cls_container.items(): - if properties_key in self.singleton_properties_instance_container: + + for ( + properties_key, + properties_cls, + ) in self.container_manager.properties_classes.items(): + if ( + properties_key + in self.container_manager.properties_instances + ): continue logger.debug( f"[INITIALIZING SINGLETON PROPERTIES] Init singleton properties: {properties_key}" ) + optional_properties = properties_instance_dict.get(properties_key) if optional_properties is None: raise TypeError( - f"[PROPERTIES INITIALIZATION ERROR] Properties: {properties_key} is not found in properties file for class: {properties_cls.get_name()} with key: {properties_cls.get_key()}" + f"[PROPERTIES INITIALIZATION ERROR] Properties: {properties_key} " + f"is not found in properties file for class: {properties_cls.get_name()} " + f"with key: {properties_cls.get_key()}" ) - self.singleton_properties_instance_container[properties_key] = ( - optional_properties - ) + + self.container_manager.properties_instances[ + properties_key + ] = optional_properties + + # Update the global properties loader reference _PropertiesLoader.optional_loaded_properties = ( - self.singleton_properties_instance_container + self.container_manager.properties_instances ) - def init_ioc_container(self) -> None: - """ - Initializes the IoC (Inversion of Control) container by creating singleton instances of all registered components. - This method iterates through the registered component classes in the `component_cls_container` dictionary. - For each component class with a `Singleton` scope, it creates an instance of the component and stores it in the `singleton_component_instance_container` dictionary. - This ensures that subsequent calls to `get_component()` for singleton components will return the same instance, as required by the Singleton design pattern. - """ - # for Components - for component_cls_name, component_cls in self.component_cls_container.items(): - if component_cls.get_scope() != ComponentScope.Singleton: - continue - logger.debug( - f"[INITIALIZING SINGLETON COMPONENT] Init singleton component: {component_cls_name}" - ) - instance = component_cls() - self.singleton_component_instance_container[component_cls_name] = instance +class ApplicationContext: + """ + The main entry point for the application's context management. - # for Bean - for ( - bean_collection_cls_name, - bean_collection_cls, - ) in self.bean_collection_cls_container.items(): - logger.debug( - f"[INITIALIZING SINGLETON BEAN] Init singleton bean: {bean_collection_cls_name}" - ) - collection = bean_collection_cls() - # before injecting_bean_collection deps and scanning beans, make sure properties is loaded in _PropertiesLoader by calling load_properties inside Application class - self._inject_dependencies_for_bean_collection(bean_collection_cls) - bean_views = collection.scan_beans() - for view in bean_views: - if view.bean_name in self.singleton_bean_instance_container: - raise BeanConflictError( - f"[BEAN CONFLICTS] Bean: {view.bean_name} already exists under collection: {collection.get_name()}" - ) - if not view.is_valid_bean(): - raise InvalidBeanError( - f"[INVALID BEAN] Bean name from bean creation func return type: {view.bean_name} does not match the bean object class name: {view.bean.__class__.__name__}" - ) - self.singleton_bean_instance_container[view.bean_name] = view.bean - - def _inject_entity_dependencies(self, entity: Type[AppEntities]) -> None: - for attr_name, annotated_entity_cls in entity.__annotations__.items(): - is_injected: bool = False - if annotated_entity_cls in self.primitive_types: - logger.warning( - f"[DEPENDENCY INJECTION SKIPPED] Skip inject dependency for attribute: {attr_name} with dependency: {annotated_entity_cls.__name__} because it is primitive type" - ) - continue + This class is responsible for: + 1. Registering and managing the lifecycle of components, controllers, bean collections, and properties. + 2. Providing methods to retrieve instances of registered components, beans, and properties. + 3. Initializing the Inversion of Control (IoC) container by creating singleton instances of registered components. + 4. Injecting dependencies for registered components and controllers. - if not isclass(annotated_entity_cls): - continue + The ApplicationContext class is designed to follow the Singleton design pattern, ensuring that there is + a single instance of the application context throughout the application's lifetime. + """ - if issubclass(annotated_entity_cls, Properties): - optional_properties = self.get_properties(annotated_entity_cls) - if optional_properties is None: - raise TypeError( - f"[PROPERTIES INJECTION ERROR] Properties: {annotated_entity_cls.get_name()} is not found in properties file for class: {annotated_entity_cls.get_name()} with key: {annotated_entity_cls.get_key()}" - ) - setattr(entity, attr_name, optional_properties) - continue + def __init__(self, config: ApplicationContextConfig, server: FastAPI) -> None: + self.server = server + self.config = config + self.all_file_paths: set[str] = set() + self.providers: list[EntityProvider] = [] - entity_getters: list[Callable] = [self.get_component, self.get_bean] + # Initialize managers + self.container_manager = ContainerManager() + self.dependency_injector = DependencyInjector(self.container_manager) + self.component_manager = ComponentManager(self.container_manager) + self.bean_manager = BeanManager( + self.container_manager, self.dependency_injector + ) + self.properties_manager = PropertiesManager(self.container_manager, config) - for getter in entity_getters: - optional_entity = getter(annotated_entity_cls) - if optional_entity is not None: - setattr(entity, attr_name, optional_entity) - is_injected = True - break + # Set app context reference for dependency injection + self.dependency_injector._app_context = self - if is_injected: - logger.success( - f"[DEPENDENCY INJECTION SUCCESS FROM COMPONENT CONTAINER] Inject dependency for {annotated_entity_cls.__name__} in attribute: {attr_name} with dependency: {annotated_entity_cls.__name__} singleton instance" - ) - continue + def set_all_file_paths(self, all_file_paths: set[str]) -> None: + """Set the collection of all file paths in the application.""" + self.all_file_paths = all_file_paths - error_message = f"[DEPENDENCY INJECTION FAILED] Fail to inject dependency for attribute: {attr_name} with dependency: {annotated_entity_cls.__name__}, consider register such depency with Compoent decorator" - logger.critical(error_message) - raise ValueError(error_message) + def as_view(self) -> ApplicationContextView: + """Create a view model of the application context state.""" + return ApplicationContextView( + config=self.config, + component_classes=list( + self.container_manager.component_classes.keys() + ), + component_instances=list( + self.container_manager.component_instances.keys() + ), + ) - def _inject_dependencies_for_bean_collection( - self, bean_collection_cls: Type[BeanCollection] - ) -> None: - logger.info( - f"[BEAN COLLECTION DEPENDENCY INJECTION] Injecting dependencies for {bean_collection_cls.get_name()}" + # Component management methods + def get_component( + self, component_cls: Type[T], qualifier: Optional[str] = None + ) -> Optional[T]: + """Get a component instance by class and optional qualifier.""" + return self.component_manager.get_component(component_cls, qualifier) + + def register_component(self, component_cls: Type[Component]) -> None: + """Register a component class in the application context.""" + self.component_manager.register_component(component_cls) + + # Bean management methods + def get_bean( + self, object_cls: Type[T], qualifier: Optional[str] = None + ) -> Optional[T]: + """Get a bean instance by class and optional qualifier.""" + return self.bean_manager.get_bean(object_cls, qualifier) + + def register_bean_collection(self, bean_cls: Type[BeanCollection]) -> None: + """Register a bean collection class in the application context.""" + self.bean_manager.register_bean_collection(bean_cls) + + # Properties management methods + def get_properties(self, properties_cls: Type[PT]) -> Optional[PT]: + """Get a properties instance by class.""" + return self.properties_manager.get_properties(properties_cls) + + def register_properties(self, properties_cls: Type[Properties]) -> None: + """Register a properties class in the application context.""" + self.properties_manager.register_properties(properties_cls) + + def load_properties(self) -> None: + """Load all registered properties from configuration files.""" + self.properties_manager.load_properties() + + # Controller management methods + def register_controller(self, controller_cls: Type[RestController]) -> None: + """Register a controller class in the application context.""" + if not issubclass(controller_cls, RestController): + raise TypeError( + f"[CONTROLLER REGISTRATION ERROR] Controller: {controller_cls} " + f"is not a subclass of RestController" + ) + + controller_cls_name = controller_cls.get_name() + self.container_manager.controller_classes[controller_cls_name] = ( + controller_cls + ) + + def get_controller_instances(self) -> list[RestController]: + """Get all controller instances.""" + return [ + cls() for cls in self.container_manager.controller_classes.values() + ] + + def get_singleton_component_instances(self) -> list[Component]: + """Get all singleton component instances.""" + return list( + self.container_manager.component_instances.values() ) - self._inject_entity_dependencies(bean_collection_cls) + + def get_singleton_bean_instances(self) -> list[object]: + """Get all singleton bean instances.""" + return list(self.container_manager.bean_instances.values()) + + def is_within_context(self, entity_cls: Type[AppEntities]) -> bool: + """Check if an entity class is registered in the application context.""" + return self.container_manager.is_entity_in_container(entity_cls) + + def init_ioc_container(self) -> None: + """ + Initialize the IoC (Inversion of Control) container. + + This method creates singleton instances of all registered components and beans, + ensuring that subsequent calls to get_component() for singleton components + will return the same instance, as required by the Singleton design pattern. + """ + # Initialize singleton components + self.component_manager.init_singleton_components() + + # Initialize singleton beans + self.bean_manager.init_singleton_beans() + + def inject_dependencies_for_external_object(self, object: Type[Any]) -> None: + """Inject dependencies for an external object.""" + self.dependency_injector.inject_dependencies(object) def inject_dependencies_for_app_entities(self) -> None: + """Inject dependencies for all registered app entities.""" containers: list[Mapping[str, Type[AppEntities]]] = [ - self.component_cls_container, - self.controller_cls_container, + self.container_manager.component_classes, + self.container_manager.controller_classes, ] for container in containers: - for _cls_name, _cls in container.items(): - self._inject_entity_dependencies(_cls) + for cls_name, cls in container.items(): + self.dependency_injector.inject_dependencies(cls) def _validate_entity_provider_dependencies(self, provider: EntityProvider) -> None: + """Validate dependencies for a single entity provider.""" for dependency in provider.depends_on: if not issubclass(dependency, AppEntities): error = f"[INVALID DEPENDENCY] Invalid dependency {dependency.__name__} in {provider.__class__.__name__}" logger.error(error) raise InvalidDependencyError(error) + if not self.is_within_context(dependency): error = f"[INVALID DEPENDENCY] Dependency {dependency.__name__} not found in the application context" logger.error(error) raise InvalidDependencyError(error) def validate_entity_providers(self) -> None: + """Validate all entity providers in the application context.""" for provider in self.providers: self._validate_entity_provider_dependencies(provider) diff --git a/py_spring_core/core/application/context/application_context_type_checker.py b/py_spring_core/core/application/context/application_context_type_checker.py deleted file mode 100644 index d10500e..0000000 --- a/py_spring_core/core/application/context/application_context_type_checker.py +++ /dev/null @@ -1,40 +0,0 @@ -from typing import Any, Iterable, Mapping, Type - -from loguru import logger -from py_spring_core.core.application.commons import AppEntities -from py_spring_core.core.application.context.application_context import ApplicationContext -from py_spring_core.core.utils import TypeHintError, check_type_hints_for_class - - -class ApplicationContextTypeChecker: - def __init__( - self, app_context: ApplicationContext, - skip_class_attrs: list[str], - target_classes: Iterable[Type[Any]], - skipped_classes: Iterable[Type[Any]] - ) -> None: - self.app_context = app_context - self.skip_class_attrs = skip_class_attrs - self.target_classes = target_classes - self.skipped_classes = skipped_classes - - - def check_type_hints_for_context(self, ctx: ApplicationContext) -> None: - containers: list[Mapping[str, Type[AppEntities]]] = [ - ctx.component_cls_container, - ctx.controller_cls_container, - ctx.bean_collection_cls_container, - ctx.properties_cls_container, - ] - for container in containers: - for _cls in container.values(): - if issubclass(_cls, tuple(self.skipped_classes)): - continue - if _cls not in self.target_classes: - try: - check_type_hints_for_class(_cls, skip_attrs=self.skip_class_attrs) - except TypeHintError as error: - logger.warning(f"Type hint error for class {_cls.__name__}: {error}") - raise error - except NameError as error: - ... \ No newline at end of file diff --git a/py_spring_core/core/application/loguru_config.py b/py_spring_core/core/application/loguru_config.py index f291ff2..3efc541 100644 --- a/py_spring_core/core/application/loguru_config.py +++ b/py_spring_core/core/application/loguru_config.py @@ -14,6 +14,11 @@ class LogLevel(str, Enum): CRITICAL = "CRITICAL" +class LogFormat(str, Enum): + TEXT = "text" + JSON = "json" + + class LoguruConfig(BaseModel): log_format: str = ( "{time:YYYY-MM-DD HH:mm:ss.SSS} | " @@ -27,3 +32,4 @@ class LoguruConfig(BaseModel): log_file_path: Optional[str] = "./logs/app.log" enable_backtrace: bool = True enable_diagnose: bool = True + format: LogFormat = LogFormat.TEXT diff --git a/py_spring_core/core/application/py_spring_application.py b/py_spring_core/core/application/py_spring_application.py index aa82780..1d18b17 100644 --- a/py_spring_core/core/application/py_spring_application.py +++ b/py_spring_core/core/application/py_spring_application.py @@ -1,40 +1,53 @@ +import logging import os -from typing import Any, Callable, Iterable, Type +from typing import Any, Callable, Iterable, Optional, Type, TypeVar import uvicorn from fastapi import APIRouter, FastAPI from loguru import logger -from pydantic import BaseModel, ConfigDict -from py_spring_core.core.application.commons import AppEntities -from py_spring_core.core.application.context.application_context_type_checker import ( - ApplicationContextTypeChecker, -) -from py_spring_core.core.entities.entity_provider import EntityProvider from py_spring_core.commons.class_scanner import ClassScanner from py_spring_core.commons.config_file_template_generator.config_file_template_generator import ( ConfigFileTemplateGenerator, ) from py_spring_core.commons.file_path_scanner import FilePathScanner -from py_spring_core.core.application.application_config import ApplicationConfigRepository +from py_spring_core.commons.type_checking_service import TypeCheckingService +from py_spring_core.core.application.application_config import ( + ApplicationConfigRepository, +) +from py_spring_core.core.application.commons import AppEntities from py_spring_core.core.application.context.application_context import ( ApplicationContext, ) from py_spring_core.core.application.context.application_context_config import ( ApplicationContextConfig, ) -from py_spring_core.core.entities.bean_collection import BeanCollection -from py_spring_core.core.entities.component import Component, ComponentLifeCycle +from py_spring_core.core.application.loguru_config import LogFormat +from py_spring_core.core.entities.bean_collection.bean_collection import BeanCollection +from py_spring_core.core.entities.component.component import Component, ComponentLifeCycle from py_spring_core.core.entities.controllers.rest_controller import RestController +from py_spring_core.core.entities.controllers.route_mapping import RouteMapping +from py_spring_core.core.entities.entity_provider.entity_provider import EntityProvider +from py_spring_core.core.entities.middlewares.middleware import Middleware +from py_spring_core.core.entities.middlewares.middleware_registry import ( + MiddlewareConfiguration, + MiddlewareRegistry, +) from py_spring_core.core.entities.properties.properties import Properties +from py_spring_core.core.interfaces.application_context_required import ( + ApplicationContextRequired, +) +from py_spring_core.core.interfaces.graceful_shutdown_handler import GracefulShutdownHandler +from py_spring_core.core.interfaces.single_inheritance_required import SingleInheritanceRequired +from py_spring_core.event.application_event_handler_registry import ( + ApplicationEventHandlerRegistry, +) +from py_spring_core.event.application_event_publisher import ApplicationEventPublisher - -class ApplicationFileGroups(BaseModel): - model_config = ConfigDict(protected_namespaces=()) - class_files: set[str] - model_files: set[str] +import py_spring_core.core.utils as framework_utils +SingleInheritanceRequiredT = TypeVar("SingleInheritanceRequiredT", bound=SingleInheritanceRequired) class PySpringApplication: """ The PySpringApplication class is the main entry point for the PySpring application. @@ -42,8 +55,6 @@ class PySpringApplication: The class performs the following key tasks: - Initializes the application from a configuration file path - - Scans the application source directory for Python files and groups them into class files and model files - - Dynamically imports the model modules and creates SQLModel tables - Registers application entities (components, controllers, bean collections, properties) with the application context - Initializes the application context and injects dependencies - Handles the lifecycle of singleton components @@ -82,8 +93,10 @@ def __init__( self.app_context_config = ApplicationContextConfig( properties_path=self.app_config.properties_file_path ) - self.app_context = ApplicationContext(config=self.app_context_config) self.fastapi = FastAPI() + self.app_context = ApplicationContext( + config=self.app_context_config, server=self.fastapi + ) self.classes_with_handlers: dict[ Type[AppEntities], Callable[[Type[Any]], None] @@ -93,34 +106,35 @@ def __init__( BeanCollection: self._handle_register_bean_collection, Properties: self._handle_register_properties, } - self.skip_class_attrs = ["Config", "model_config"] - self.app_context_typer_checker = ApplicationContextTypeChecker( - self.app_context, self.skip_class_attrs, self.classes_with_handlers.keys(), - [Properties, BaseModel] + self.type_checking_service = TypeCheckingService( + self.app_config.app_src_target_dir ) + self.shutdown_handler: Optional[GracefulShutdownHandler] = None - def __configure_logging(self): + def _configure_logging(self): """Applies the logging configuration using Loguru.""" config = self.app_config.loguru_config if not config.log_file_path: return + # Use the format field from config which contains the actual format string logger.add( config.log_file_path, - format=config.log_format, level=config.log_level, rotation=config.log_rotation, retention=config.log_retention, + serialize=config.format == LogFormat.JSON, ) + self._configure_uvicorn_logging() - def _scan_classes_for_project(self) -> None: - self.app_class_scanner.scan_classes_for_file_paths() - self.scanned_classes = self.app_class_scanner.get_classes() + def _get_system_managed_classes(self) -> Iterable[Type[Component]]: + return [ApplicationEventPublisher, ApplicationEventHandlerRegistry] - def _register_all_entities_from_providers(self) -> None: - for provider in self.entity_providers: - entities = provider.get_entities() - self._register_app_entities(entities) + def _scan_classes_for_project(self) -> Iterable[Type[object]]: + self.app_class_scanner.scan_classes_for_file_paths( + self.app_config.exclude_file_patterns + ) + return self.app_class_scanner.get_classes() def _register_app_entities(self, classes: Iterable[Type[object]]) -> None: for _cls in classes: @@ -129,12 +143,14 @@ def _register_app_entities(self, classes: Iterable[Type[object]]) -> None: continue handler(_cls) - def _register_entity_providers( + def _get_all_entities_from_entity_providers( self, entity_providers: Iterable[EntityProvider] - ) -> None: + ) -> Iterable[Type[AppEntities]]: + entities: list[Type[AppEntities]] = [] for provider in entity_providers: - self.app_context.register_entity_provider(provider) - provider.set_context(self.app_context) + entities.extend(provider.get_entities()) + + return entities def _handle_register_component(self, _cls: Type[Component]) -> None: self.app_context.register_component(_cls) @@ -165,25 +181,43 @@ def _init_providers(self, providers: Iterable[EntityProvider]) -> None: for provider in providers: provider.provider_init() - def __init_app(self) -> None: - self._scan_classes_for_project() - self._register_all_entities_from_providers() - self._register_app_entities(self.scanned_classes) - self._register_entity_providers(self.entity_providers) - self._check_type_hints() + def _inject_application_context_to_context_required( + self, classes: Iterable[Type[object]] + ) -> None: + for cls in classes: + if not issubclass(cls, ApplicationContextRequired): + continue + cls.set_application_context(self.app_context) + + def _prepare_injected_classes(self) -> Iterable[Type[object]]: + scanned_classes = self._scan_classes_for_project() + system_managed_classes = self._get_system_managed_classes() + provider_entities = self._get_all_entities_from_entity_providers( + self.entity_providers + ) + provider_classes = [provider.__class__ for provider in self.entity_providers] + # providers typically requires app context, so add to classess to inject + classes_to_inject = [ + *scanned_classes, + *system_managed_classes, + *provider_entities, + *provider_classes, + ] + return classes_to_inject + + def _init_app(self) -> None: + classes_to_inject = self._prepare_injected_classes() + self._inject_application_context_to_context_required(classes_to_inject) + self._register_app_entities(classes_to_inject) self.app_context.load_properties() self.app_context.init_ioc_container() self.app_context.inject_dependencies_for_app_entities() self.app_context.set_all_file_paths(self.target_dir_absolute_file_paths) self.app_context.validate_entity_providers() # after injecting all deps, lifecycle (init) can be called - self._init_providers(self.entity_providers) self._handle_singleton_components_life_cycle(ComponentLifeCycle.Init) - def _check_type_hints(self) -> None: - self.app_context_typer_checker.check_type_hints_for_context(self.app_context) - def _handle_singleton_components_life_cycle( self, life_cycle: ComponentLifeCycle ) -> None: @@ -195,27 +229,138 @@ def _handle_singleton_components_life_cycle( case ComponentLifeCycle.Destruction: component.finish_destruction_cycle() - def __init_controllers(self) -> None: + def _init_controllers(self) -> None: controllers = self.app_context.get_controller_instances() for controller in controllers: - controller.register_routes() + name = controller.__class__.__name__ + routes = RouteMapping.routes.get(name, set()) + controller.post_construct() + controller._register_decorated_routes(routes) router = controller.get_router() self.fastapi.include_router(router) - controller.register_middlewares() + logger.debug(f"[CONTROLLER INIT] Controller {name} initialized") + + def _init_external_handler(self, base_class: Type[SingleInheritanceRequiredT]) -> Type[SingleInheritanceRequiredT] | None: + """Initialize an external handler (middleware registry or graceful shutdown handler). + + Args: + base_class: The base class to get subclass from + + Returns: + The initialized handler class or None if no handler is found + + Raises: + RuntimeError: If the handler has unimplemented abstract methods + """ + handler_type = base_class.__name__ + handler_cls = base_class.get_subclass() + if handler_cls is None: + logger.debug(f"[{handler_type} INIT] No self defined {handler_type.lower()} class found") + return None + + unimplemented_abstract_methods = framework_utils.get_unimplemented_abstract_methods(handler_cls) + if len(unimplemented_abstract_methods) > 0: + error_message = f"[{handler_type} INIT] Self defined {handler_type.lower()} class: {handler_cls.__name__} has unimplemented abstract methods: {unimplemented_abstract_methods}" + logger.error(error_message) + raise RuntimeError(error_message) + + logger.debug( + f"[{handler_type} INIT] Self defined {handler_type.lower()} class: {handler_cls.__name__}" + ) + logger.debug( + f"[{handler_type} INIT] Inject dependencies for external object: {handler_cls.__name__}" + ) + self.app_context.inject_dependencies_for_external_object(handler_cls) + return handler_cls + + def _init_middlewares(self) -> None: + handler_type = MiddlewareRegistry.__name__ + logger.debug(f"[{handler_type} INIT] Initialize middlewares...") + middleware_config_cls = self._init_external_handler(MiddlewareConfiguration) + if middleware_config_cls is None: + return + registry = MiddlewareRegistry() + logger.info(f"[{handler_type} INIT] Setup middlewares for registry: {registry.__class__.__name__}") + middleware_config_cls().configure_middlewares(registry) + logger.info(f"[{handler_type} INIT] Middlewares setup for registry: {registry.__class__.__name__} completed") + middleware_classes: list[Type[Middleware]] = registry.get_middleware_classes() + logger.info(f"[{handler_type} INIT] Middleware classes: {', '.join([middleware_class.__name__ for middleware_class in middleware_classes])}") + for middleware_class in middleware_classes: + logger.debug( + f"[{handler_type} INIT] Inject dependencies for middleware: {middleware_class.__name__}" + ) + self.app_context.inject_dependencies_for_external_object(middleware_class) + registry.apply_middlewares(self.fastapi) + logger.debug(f"[{handler_type} INIT] Middlewares initialized") + + def _init_graceful_shutdown(self) -> None: + handler_type = GracefulShutdownHandler.__name__ + logger.debug(f"[{handler_type} INIT] Initialize graceful shutdown...") + handler_cls: Optional[Type[GracefulShutdownHandler]] = self._init_external_handler(GracefulShutdownHandler) + if handler_cls is None: + return + + # Get shutdown configuration + shutdown_config = self.app_config.shutdown_config + logger.debug(f"[{handler_type} INIT] Shutdown timeout: {shutdown_config.timeout_seconds}s, enabled: {shutdown_config.enabled}") + + # Initialize handler with timeout configuration + self.shutdown_handler = handler_cls( + timeout_seconds=shutdown_config.timeout_seconds, + timeout_enabled=shutdown_config.enabled + ) # type: ignore + logger.debug(f"[{handler_type} INIT] Graceful shutdown initialized") + + def _configure_uvicorn_logging(self): + """Configure Uvicorn to use Loguru instead of default logging.""" + + # Configure Uvicorn to use Loguru + # Intercept standard logging and redirect to loguru + class InterceptHandler(logging.Handler): + def emit(self, record): + # Get corresponding Loguru level if it exists + try: + level = logger.level(record.levelname).name + except ValueError: + level = record.levelno + + # Find caller from where originated the logged message + frame, depth = logging.currentframe(), 2 + while frame and frame.f_code.co_filename == logging.__file__: + frame = frame.f_back + depth += 1 + + logger.opt(depth=depth, exception=record.exc_info).log( + level, record.getMessage() + ) + + # Remove default uvicorn logger and add intercept handler + log_level = self.app_config.loguru_config.log_level.value + logging.basicConfig(handlers=[InterceptHandler()], level=log_level, force=True) - def __run_server(self) -> None: + def _run_server(self) -> None: + # Run uvicorn server uvicorn.run( self.fastapi, host=self.app_config.server_config.host, port=self.app_config.server_config.port, + log_config=None, # Disable uvicorn's default logging ) def run(self) -> None: try: - self.__configure_logging() - self.__init_app() - self.__init_controllers() + self._configure_logging() + self._init_app() + self._init_controllers() + self._init_middlewares() + self._init_graceful_shutdown() if self.app_config.server_config.enabled: - self.__run_server() + self._run_server() finally: + # Handle component lifecycle destruction self._handle_singleton_components_life_cycle(ComponentLifeCycle.Destruction) + # Handle graceful shutdown completion + if self.shutdown_handler: + self.shutdown_handler.complete_shutdown() + + \ No newline at end of file diff --git a/py_spring_core/core/entities/bean_collection.py b/py_spring_core/core/entities/bean_collection/bean_collection.py similarity index 100% rename from py_spring_core/core/entities/bean_collection.py rename to py_spring_core/core/entities/bean_collection/bean_collection.py diff --git a/py_spring_core/core/entities/component.py b/py_spring_core/core/entities/component/component.py similarity index 93% rename from py_spring_core/core/entities/component.py rename to py_spring_core/core/entities/component/component.py index d9b7e8f..55939f6 100644 --- a/py_spring_core/core/entities/component.py +++ b/py_spring_core/core/entities/component/component.py @@ -34,17 +34,20 @@ class Component: The `get_scope()` and `set_scope()` methods allow you to get and set the scope of the component. The lifecycle hooks are: - - `post_initialize()`: Called after the component is initialized. + - `post_construct()`: Called after the component is initialized. - `pre_destroy()`: Called before the component is destroyed. The `finish_initialization_cycle()` and `finish_destruction_cycle()` methods are final and call the corresponding lifecycle hooks in the correct order. """ class Config: + name: str = "" scope: ComponentScope = ComponentScope.Singleton @classmethod def get_name(cls) -> str: + if hasattr(cls.Config, "name") and cls.Config.name: + return cls.Config.name return cls.__name__ @classmethod diff --git a/py_spring_core/core/entities/controllers/rest_controller.py b/py_spring_core/core/entities/controllers/rest_controller.py index 20d81f8..93f6c6f 100644 --- a/py_spring_core/core/entities/controllers/rest_controller.py +++ b/py_spring_core/core/entities/controllers/rest_controller.py @@ -1,5 +1,11 @@ +from functools import partial +from typing import Iterable + from fastapi import APIRouter, FastAPI +from py_spring_core.core.entities.controllers.route_mapping import RouteRegistration +from py_spring_core.core.entities.middlewares.middleware import Middleware + class RestController: """ @@ -11,7 +17,7 @@ class RestController: - Providing access to the FastAPI `APIRouter` and `FastAPI` app instances - Exposing the controller's configuration, including the URL prefix - Subclasses of `RestController` should override the `register_routes` and `register_middlewares` methods to add their own routes and middleware to the controller. + Subclasses of `RestController` should override the `register_routes` methods to add their own routes and middleware to the controller. """ app: FastAPI @@ -20,9 +26,34 @@ class RestController: class Config: prefix: str = "" - def register_routes(self) -> None: ... + def post_construct(self) -> None: ... - def register_middlewares(self) -> None: ... + def _register_decorated_routes(self, routes: Iterable[RouteRegistration]) -> None: + for route in routes: + bound_method = partial(route.func, self) + self.router.add_api_route( + path=route.path, + endpoint=bound_method, + methods=[route.method.value], + response_model=route.response_model, + status_code=route.status_code, + tags=route.tags, + dependencies=route.dependencies, + summary=route.summary, + description=route.description, + response_description=route.response_description, + responses=route.responses, + deprecated=route.deprecated, + operation_id=route.operation_id, + response_model_include=route.response_model_include, + response_model_exclude=route.response_model_exclude, + response_model_by_alias=route.response_model_by_alias, + response_model_exclude_unset=route.response_model_exclude_unset, + response_model_exclude_defaults=route.response_model_exclude_defaults, + response_model_exclude_none=route.response_model_exclude_none, + include_in_schema=route.include_in_schema, + name=route.name, + ) def get_router(self) -> APIRouter: return self.router diff --git a/py_spring_core/core/entities/controllers/route_mapping.py b/py_spring_core/core/entities/controllers/route_mapping.py new file mode 100644 index 0000000..0e3f050 --- /dev/null +++ b/py_spring_core/core/entities/controllers/route_mapping.py @@ -0,0 +1,283 @@ +from enum import Enum +from functools import wraps +from typing import Any, Callable, Dict, List, Optional, Set, Union + +from pydantic import BaseModel + + +class HTTPMethod(str, Enum): + """HTTP methods enumeration for route definitions. + + This enum defines the supported HTTP methods that can be used + with route decorators in the PySpring framework. + """ + GET = "GET" + POST = "POST" + PUT = "PUT" + DELETE = "DELETE" + PATCH = "PATCH" + + +class RouteRegistration(BaseModel): + """Model representing a route registration with all FastAPI-compatible parameters. + + This class encapsulates all the information needed to register a route, + including the HTTP method, path, handler function, and various FastAPI + configuration options. + + Attributes: + class_name (str): Name of the controller class containing the route + method (HTTPMethod): HTTP method for the route + path (str): URL path pattern for the route + func (Callable): The handler function for the route + response_model (Any, optional): Pydantic model for response serialization + status_code (int, optional): Default HTTP status code for successful responses + tags (List[Union[str, Enum]], optional): OpenAPI tags for documentation + dependencies (List[Any], optional): FastAPI dependencies + summary (str, optional): Short summary for OpenAPI documentation + description (str, optional): Detailed description for OpenAPI documentation + response_description (str): Description of successful response + responses (Dict[Union[int, str], Dict[str, Any]], optional): Additional response definitions + deprecated (bool, optional): Whether the endpoint is deprecated + operation_id (str, optional): Unique operation ID for OpenAPI + response_model_include (Set[str], optional): Fields to include in response model + response_model_exclude (Set[str], optional): Fields to exclude from response model + response_model_by_alias (bool): Whether to use field aliases in response + response_model_exclude_unset (bool): Whether to exclude unset fields + response_model_exclude_defaults (bool): Whether to exclude default values + response_model_exclude_none (bool): Whether to exclude None values + include_in_schema (bool): Whether to include in OpenAPI schema + name (str, optional): Custom name for the route + """ + class_name: str + method: HTTPMethod + path: str + func: Callable + response_model: Any = None + status_code: Optional[int] = None + tags: Optional[List[Union[str, Enum]]] = None + dependencies: Optional[List[Any]] = None + summary: Optional[str] = None + description: Optional[str] = None + response_description: str = "Successful Response" + responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None + deprecated: Optional[bool] = None + operation_id: Optional[str] = None + response_model_include: Optional[Set[str]] = None + response_model_exclude: Optional[Set[str]] = None + response_model_by_alias: bool = True + response_model_exclude_unset: bool = False + response_model_exclude_defaults: bool = False + response_model_exclude_none: bool = False + include_in_schema: bool = True + name: Optional[str] = None + + def __eq__(self, other: Any) -> bool: + """Check equality based on HTTP method and path. + + Two route registrations are considered equal if they have + the same HTTP method and path, regardless of other attributes. + + Args: + other (Any): Object to compare with + + Returns: + bool: True if equal, False otherwise + """ + if not isinstance(other, RouteRegistration): + return False + return self.method == other.method and self.path == other.path + + def __hash__(self) -> int: + """Generate hash based on HTTP method and path. + + This allows RouteRegistration objects to be used in sets + and as dictionary keys. + + Returns: + int: Hash value based on method and path + """ + return hash((self.method, self.path)) + + +class RouteMapping: + """Registry for storing and managing route registrations by controller class. + + This class maintains a mapping of controller class names to their + associated route registrations, enabling efficient route lookup + and management within the PySpring framework. + + Attributes: + routes (dict[str, set[RouteRegistration]]): Mapping of class names to route sets + """ + routes: dict[str, set[RouteRegistration]] = {} + + @classmethod + def register_route(cls, route_registration: RouteRegistration) -> None: + """Register a route for a specific controller class. + + Adds the given route registration to the routes mapping, + creating a new set for the class if it doesn't exist. + + Args: + route_registration (RouteRegistration): The route to register + """ + optional_routes = cls.routes.get(route_registration.class_name, None) + if optional_routes is None: + cls.routes[route_registration.class_name] = set() + cls.routes[route_registration.class_name].add(route_registration) + + +def _create_route_decorator(method: HTTPMethod): + """Create a route decorator factory for a specific HTTP method. + + This function generates decorator factories that can be used to create + route decorators with FastAPI-compatible parameters. + + Args: + method (HTTPMethod): The HTTP method for routes created by this decorator + + Returns: + Callable: A decorator factory function that accepts route parameters + """ + def decorator_factory( + path: str, + *, + response_model: Any = None, + status_code: Optional[int] = None, + tags: Optional[List[Union[str, Enum]]] = None, + dependencies: Optional[List[Any]] = None, + summary: Optional[str] = None, + description: Optional[str] = None, + response_description: str = "Successful Response", + responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None, + deprecated: Optional[bool] = None, + operation_id: Optional[str] = None, + response_model_include: Optional[Set[str]] = None, + response_model_exclude: Optional[Set[str]] = None, + response_model_by_alias: bool = True, + response_model_exclude_unset: bool = False, + response_model_exclude_defaults: bool = False, + response_model_exclude_none: bool = False, + include_in_schema: bool = True, + name: Optional[str] = None, + ): + """Create a route decorator with the specified parameters. + + Args: + path (str): URL path pattern for the route + response_model (Any, optional): Pydantic model for response serialization + status_code (int, optional): Default HTTP status code + tags (List[Union[str, Enum]], optional): OpenAPI tags + dependencies (List[Any], optional): FastAPI dependencies + summary (str, optional): Route summary for documentation + description (str, optional): Route description for documentation + response_description (str): Description of successful response + responses (Dict[Union[int, str], Dict[str, Any]], optional): Additional responses + deprecated (bool, optional): Whether the endpoint is deprecated + operation_id (str, optional): Unique operation ID + response_model_include (Set[str], optional): Fields to include in response + response_model_exclude (Set[str], optional): Fields to exclude from response + response_model_by_alias (bool): Whether to use field aliases + response_model_exclude_unset (bool): Whether to exclude unset fields + response_model_exclude_defaults (bool): Whether to exclude defaults + response_model_exclude_none (bool): Whether to exclude None values + include_in_schema (bool): Whether to include in OpenAPI schema + name (str, optional): Custom name for the route + + Returns: + Callable: A decorator function that registers the route + """ + def decorator(func: Callable): + """Decorate a function to register it as a route. + + Args: + func (Callable): The handler function to decorate + + Returns: + Callable: The wrapped function + """ + class_name = func.__qualname__.split(".")[0] + route_registration = RouteRegistration( + class_name=class_name, + method=method, + path=path, + func=func, + response_model=response_model, + status_code=status_code, + tags=tags, + dependencies=dependencies, + summary=summary or func.__name__, + description=description or func.__doc__, + response_description=response_description, + responses=responses, + deprecated=deprecated, + operation_id=operation_id, + response_model_include=response_model_include, + response_model_exclude=response_model_exclude, + response_model_by_alias=response_model_by_alias, + response_model_exclude_unset=response_model_exclude_unset, + response_model_exclude_defaults=response_model_exclude_defaults, + response_model_exclude_none=response_model_exclude_none, + include_in_schema=include_in_schema, + name=name, + ) + RouteMapping.register_route(route_registration) + + @wraps(func) + def wrapper(*args: Any, **kwargs: Any): + return func(*args, **kwargs) + + return wrapper + + return decorator + + return decorator_factory + + +# Route decorator factories for different HTTP methods +GetMapping = _create_route_decorator(HTTPMethod.GET) +"""Decorator factory for creating GET route handlers. + +Example: + @GetMapping("/users") + def get_users(self): + return {"users": []} +""" + +PostMapping = _create_route_decorator(HTTPMethod.POST) +"""Decorator factory for creating POST route handlers. + +Example: + @PostMapping("/users", response_model=UserResponse) + def create_user(self, user: UserCreate): + return create_new_user(user) +""" + +PutMapping = _create_route_decorator(HTTPMethod.PUT) +"""Decorator factory for creating PUT route handlers. + +Example: + @PutMapping("/users/{user_id}") + def update_user(self, user_id: int, user: UserUpdate): + return update_existing_user(user_id, user) +""" + +DeleteMapping = _create_route_decorator(HTTPMethod.DELETE) +"""Decorator factory for creating DELETE route handlers. + +Example: + @DeleteMapping("/users/{user_id}") + def delete_user(self, user_id: int): + delete_existing_user(user_id) + return {"message": "User deleted"} +""" + +PatchMapping = _create_route_decorator(HTTPMethod.PATCH) +"""Decorator factory for creating PATCH route handlers. + +Example: + @PatchMapping("/users/{user_id}") + def partial_update_user(self, user_id: int, user: UserPatch): + return patch_existing_user(user_id, user) +""" diff --git a/py_spring_core/core/entities/entity_provider.py b/py_spring_core/core/entities/entity_provider/entity_provider.py similarity index 79% rename from py_spring_core/core/entities/entity_provider.py rename to py_spring_core/core/entities/entity_provider/entity_provider.py index 5481add..25d35d5 100644 --- a/py_spring_core/core/entities/entity_provider.py +++ b/py_spring_core/core/entities/entity_provider/entity_provider.py @@ -1,9 +1,9 @@ -from dataclasses import field, dataclass +from dataclasses import dataclass, field from typing import Any, Optional, Type from py_spring_core.core.application.commons import AppEntities -from py_spring_core.core.entities.bean_collection import BeanCollection -from py_spring_core.core.entities.component import Component +from py_spring_core.core.entities.bean_collection.bean_collection import BeanCollection +from py_spring_core.core.entities.component.component import Component from py_spring_core.core.entities.controllers.rest_controller import RestController from py_spring_core.core.entities.properties.properties import Properties @@ -22,10 +22,10 @@ class EntityProvider: properties_classes: list[Type[Properties]] = field(default_factory=list) rest_controller_classes: list[Type[RestController]] = field(default_factory=list) depends_on: list[Type[AppEntities]] = field(default_factory=list) - extneral_dependencies: list[Any] = field(default_factory=list) + external_dependencies: list[Any] = field(default_factory=list) app_context: Optional["ApplicationContext"] = None - def get_entities(self) -> list[Type[object]]: + def get_entities(self) -> list[Type[AppEntities]]: return [ *self.component_classes, *self.bean_collection_classes, diff --git a/py_spring_core/core/entities/middlewares/middleware.py b/py_spring_core/core/entities/middlewares/middleware.py new file mode 100644 index 0000000..d861309 --- /dev/null +++ b/py_spring_core/core/entities/middlewares/middleware.py @@ -0,0 +1,58 @@ +from abc import abstractmethod +from typing import Awaitable, Callable + +from fastapi import Request, Response +from starlette.middleware.base import BaseHTTPMiddleware + + +class Middleware(BaseHTTPMiddleware): + """ + Middleware base class, inherits from FastAPI's BaseHTTPMiddleware + Simpler to use, only need to implement the process_request method + """ + + def should_skip(self, request: Request) -> bool: + """ + Method to determine if the middleware should be skipped + + Args: + request: FastAPI request object + + Returns: + bool: True if the middleware should be skipped, False otherwise, default is False + """ + return False + + @abstractmethod + async def process_request(self, request: Request) -> Response | None: + """ + Method to process requests + + Args: + request: FastAPI request object + + Returns: + Response | None: If Response is returned, it will be directly returned to the client + If None is returned, continue to execute the next middleware or route handler + """ + pass + + async def dispatch( + self, request: Request, call_next: Callable[[Request], Awaitable[Response]] + ) -> Response: + """ + Middleware dispatch method, automatically called by FastAPI + If should_skip returns True, the middleware will be skipped + """ + if self.should_skip(request): + return await call_next(request) + + # First execute custom request processing logic + response = await self.process_request(request) + + # If a response is returned, return it directly + if response is not None: + return response + + # Otherwise continue to execute the next middleware or route handler + return await call_next(request) diff --git a/py_spring_core/core/entities/middlewares/middleware_registry.py b/py_spring_core/core/entities/middlewares/middleware_registry.py new file mode 100644 index 0000000..6fd7d93 --- /dev/null +++ b/py_spring_core/core/entities/middlewares/middleware_registry.py @@ -0,0 +1,203 @@ +from abc import ABC, abstractmethod +from typing import Type + +from fastapi import FastAPI + +from py_spring_core.core.entities.middlewares.middleware import Middleware +from py_spring_core.core.interfaces.single_inheritance_required import ( + SingleInheritanceRequired, +) + + +class MiddlewareRegistry: + """ + Middleware registry for managing all middlewares + + This registry pattern eliminates the need for manual middleware registration. + The framework automatically handles middleware registration and execution order. + + Multiple middleware execution order: + When multiple middlewares are registered through this registry, they are automatically + applied to the FastAPI application in the order they are returned by get_middleware_classes(). + Each middleware wraps the application, forming a stack. The last middleware added is the outermost, + and the first is the innermost. + + On the request path, the outermost middleware runs first. + On the response path, it runs last. + + For example, if get_middleware_classes() returns [MiddlewareA, MiddlewareB] + This results in the following execution order: + Request: MiddlewareB → MiddlewareA → route + Response: route → MiddlewareA → MiddlewareB + This stacking behavior ensures that middlewares are executed in a predictable and controllable order. + """ + + def __init__(self): + """ + Initialize the middleware registry. + """ + self._middlewares: list[Type[Middleware]] = [] + + def add_middleware(self, middleware_class: Type[Middleware]) -> None: + """ + Add middleware to the end of the list. + + Args: + middleware_class: The middleware class to add + + Raises: + ValueError: If middleware is already registered + """ + if middleware_class in self._middlewares: + raise ValueError(f"Middleware {middleware_class.__name__} is already registered") + self._middlewares.append(middleware_class) + + def add_at_index(self, index: int, middleware_class: Type[Middleware]) -> None: + """ + Insert middleware at a specific index position. + + Args: + index: The position to insert at (0-based) + middleware_class: The middleware class to add + + Raises: + ValueError: If middleware is already registered or index is invalid + """ + if middleware_class in self._middlewares: + raise ValueError(f"Middleware {middleware_class.__name__} is already registered") + if index < 0 or index > len(self._middlewares): + raise ValueError(f"Index {index} is out of range (0-{len(self._middlewares)})") + self._middlewares.insert(index, middleware_class) + + def add_before(self, target_middleware: Type[Middleware], middleware_class: Type[Middleware]) -> None: + """ + Insert middleware before the target middleware. + + Args: + target_middleware: The middleware to insert before + middleware_class: The middleware class to add + + Raises: + ValueError: If middleware is already registered or target not found + """ + if middleware_class in self._middlewares: + raise ValueError(f"Middleware {middleware_class.__name__} is already registered") + if target_middleware not in self._middlewares: + raise ValueError(f"Target middleware {target_middleware.__name__} not found") + index = self._middlewares.index(target_middleware) + self._middlewares.insert(index, middleware_class) + + def add_after(self, target_middleware: Type[Middleware], middleware_class: Type[Middleware]) -> None: + """ + Insert middleware after the target middleware. + + Args: + target_middleware: The middleware to insert after + middleware_class: The middleware class to add + + Raises: + ValueError: If middleware is already registered or target not found + """ + if middleware_class in self._middlewares: + raise ValueError(f"Middleware {middleware_class.__name__} is already registered") + if target_middleware not in self._middlewares: + raise ValueError(f"Target middleware {target_middleware.__name__} not found") + index = self._middlewares.index(target_middleware) + self._middlewares.insert(index + 1, middleware_class) + + def remove_middleware(self, middleware_class: Type[Middleware]) -> None: + """ + Remove a middleware from the registry. + + Args: + middleware_class: The middleware class to remove + + Raises: + ValueError: If middleware is not found + """ + if middleware_class not in self._middlewares: + raise ValueError(f"Middleware {middleware_class.__name__} not found") + self._middlewares.remove(middleware_class) + + def clear_middlewares(self) -> None: + """Remove all middlewares from the registry.""" + self._middlewares.clear() + + def has_middleware(self, middleware_class: Type[Middleware]) -> bool: + """ + Check if a middleware is registered. + + Args: + middleware_class: The middleware class to check + + Returns: + bool: True if middleware is registered, False otherwise + """ + return middleware_class in self._middlewares + + def get_middleware_count(self) -> int: + """ + Get the number of registered middlewares. + + Returns: + int: Number of registered middlewares + """ + return len(self._middlewares) + + def get_middleware_index(self, middleware_class: Type[Middleware]) -> int: + """ + Get the index of a middleware in the registry. + + Args: + middleware_class: The middleware class to find + + Returns: + int: The index of the middleware + + Raises: + ValueError: If middleware is not found + """ + if middleware_class not in self._middlewares: + raise ValueError(f"Middleware {middleware_class.__name__} not found") + return self._middlewares.index(middleware_class) + + + + def get_middleware_classes(self) -> list[Type[Middleware]]: + """ + Get all registered middleware classes. + + Returns: + List[Type[Middleware]]: List of middleware classes in registration order + """ + return self._middlewares.copy() + + def apply_middlewares(self, app: FastAPI) -> FastAPI: + """ + Apply middlewares to FastAPI application. + + Iterates through all registered middlewares and applies them to the FastAPI + application instance in the order they were registered. + + Args: + app: FastAPI application instance + + Returns: + FastAPI: FastAPI instance with applied middlewares + """ + for middleware_class in self.get_middleware_classes(): + app.add_middleware(middleware_class) + return app + + + +class MiddlewareConfiguration(SingleInheritanceRequired["MiddlewareConfiguration"]): + """ + Middleware configuration for managing middleware registration and execution order. + """ + + def configure_middlewares(self, registry: MiddlewareRegistry) -> None: + """ + Setup middlewares for the registry. + """ + pass \ No newline at end of file diff --git a/py_spring_core/core/entities/properties/properties_loader.py b/py_spring_core/core/entities/properties/properties_loader.py index 08f40a6..983d803 100644 --- a/py_spring_core/core/entities/properties/properties_loader.py +++ b/py_spring_core/core/entities/properties/properties_loader.py @@ -1,5 +1,5 @@ import json -from typing import Optional, Type +from typing import Any, Callable, Optional, Type import cachetools import yaml @@ -32,7 +32,7 @@ def __init__( self.properties_classes = properties_classes self.properties_class_map = self._load_classes_as_map() - self.extension_loader_lookup = { + self.extension_loader_lookup: dict[str, Callable[[str], dict[str, Any]]] = { "json": json.loads, "yaml": yaml.safe_load, "yml": yaml.safe_load, diff --git a/py_spring_core/core/interfaces/application_context_required.py b/py_spring_core/core/interfaces/application_context_required.py new file mode 100644 index 0000000..83d7729 --- /dev/null +++ b/py_spring_core/core/interfaces/application_context_required.py @@ -0,0 +1,36 @@ +from typing import Optional + +from py_spring_core.core.application.context.application_context import ( + ApplicationContext, +) + + +class ApplicationContextRequired: + """ + A mixin class that provides access to the ApplicationContext for classes that need it. + + This class serves as a base for components that require access to the ApplicationContext. + It provides class-level methods to set and retrieve the ApplicationContext instance. + + Usage: + class MyComponent(ApplicationContextRequired): + def some_method(self): + context = self.get_application_context() + # Use the context... + + Note: + The ApplicationContext must be set before attempting to retrieve it, + otherwise a RuntimeError will be raised. + """ + + _app_context: Optional[ApplicationContext] = None + + @classmethod + def set_application_context(cls, application_context: ApplicationContext) -> None: + cls._app_context = application_context + + @classmethod + def get_application_context(cls) -> ApplicationContext: + if cls._app_context is None: + raise RuntimeError("ApplicationContext is not set") + return cls._app_context diff --git a/py_spring_core/core/interfaces/graceful_shutdown_handler.py b/py_spring_core/core/interfaces/graceful_shutdown_handler.py new file mode 100644 index 0000000..504026f --- /dev/null +++ b/py_spring_core/core/interfaces/graceful_shutdown_handler.py @@ -0,0 +1,143 @@ + + +from abc import ABC, abstractmethod +from enum import Enum, auto +import os +import signal +import threading +import time +from types import FrameType +from typing import Optional + +from loguru import logger + +from py_spring_core.core.interfaces.single_inheritance_required import SingleInheritanceRequired + +class ShutdownType(Enum): + MANUAL = auto() # e.g., Ctrl+C + SIGTERM = auto() # e.g., docker stop, systemctl stop + TIMEOUT = auto() # e.g., shutdown triggered by time constraint + ERROR = auto() # e.g., unrecoverable fault + UNKNOWN = auto() + + +class GracefulShutdownHandler(SingleInheritanceRequired, ABC): + """ + A mixin class that provides a method to handle graceful shutdown. + """ + + def __init__(self, timeout_seconds: float, timeout_enabled: bool) -> None: + self._shutdown_event = threading.Event() + self._shutdown_type: Optional[ShutdownType] = None + self._timeout_seconds = timeout_seconds + self._timeout_enabled = timeout_enabled + self._timeout_timer: Optional[threading.Timer] = None + self._shutdown_start_time: Optional[float] = None + + signal.signal(signal.SIGINT, self._handle_sigint) + signal.signal(signal.SIGTERM, self._handle_sigterm) + + def _handle_sigint(self, signum: int, frame: Optional[FrameType]) -> None: + try: + # Check if shutdown is already in progress to prevent duplicate execution + if self._shutdown_event.is_set(): + logger.debug("[Signal] SIGINT ignored - shutdown already in progress") + return + + logger.info("[Signal] SIGINT received") + self._shutdown_type = ShutdownType.MANUAL + self._shutdown_event.set() + self._start_shutdown_timer() + self.on_shutdown(ShutdownType.MANUAL) + except Exception as error: + self.on_error(error) + + def _handle_sigterm(self, signum: int, frame: Optional[FrameType]) -> None: + try: + # Check if shutdown is already in progress to prevent duplicate execution + if self._shutdown_event.is_set(): + logger.debug("[Signal] SIGTERM ignored - shutdown already in progress") + return + + logger.info("[Signal] SIGTERM received") + self._shutdown_type = ShutdownType.SIGTERM + self._shutdown_event.set() + self._start_shutdown_timer() + self.on_shutdown(ShutdownType.SIGTERM) + except Exception as error: + self.on_error(error) + + def _start_shutdown_timer(self) -> None: + """Start the shutdown timeout timer if enabled.""" + if not self._timeout_enabled: + return + + self._shutdown_start_time = time.time() + logger.info(f"[Shutdown Timer] Starting shutdown timer for {self._timeout_seconds} seconds") + + self._timeout_timer = threading.Timer(self._timeout_seconds, self._handle_timeout) + self._timeout_timer.daemon = True + self._timeout_timer.start() + + def _handle_timeout(self) -> None: + """Handle shutdown timeout.""" + try: + if not self._shutdown_event.is_set(): + return # Shutdown was not initiated, ignore timeout + + logger.info(f"[Shutdown Timer] Shutdown timeout reached after {self._timeout_seconds} seconds") + self._shutdown_type = ShutdownType.TIMEOUT + self.on_timeout() + except Exception as error: + self.on_error(error) + finally: + logger.critical(f"[Shutdown Timer] Timer exited with grace period of {self._timeout_seconds} seconds, exiting application") + os._exit(0) + + def complete_shutdown(self) -> None: + """Mark shutdown as complete and cancel timeout timer.""" + if self._timeout_timer and self._timeout_timer.is_alive(): + self._timeout_timer.cancel() + + if self._shutdown_start_time: + elapsed = time.time() - self._shutdown_start_time + logger.success(f"[Shutdown Timer] Shutdown completed successfully in {elapsed:.2f} seconds") + + def is_shutdown(self) -> bool: + return self._shutdown_event.is_set() + + def get_type(self) -> Optional[ShutdownType]: + return self._shutdown_type + + def get_timeout_seconds(self) -> float: + return self._timeout_seconds + + def is_timeout_enabled(self) -> bool: + return self._timeout_enabled + + def get_shutdown_elapsed_time(self) -> Optional[float]: + """Get the elapsed time since shutdown started.""" + if self._shutdown_start_time is None: + return None + return time.time() - self._shutdown_start_time + + @abstractmethod + def on_shutdown(self, shutdown_type: ShutdownType) -> None: + """ + Handle shutdown. + """ + ... + + @abstractmethod + def on_timeout(self) -> None: + """ + Handle timeout. + """ + ... + + @abstractmethod + def on_error(self, error: Exception) -> None: + """ + Handle error. + """ + ... \ No newline at end of file diff --git a/py_spring_core/core/interfaces/single_inheritance_required.py b/py_spring_core/core/interfaces/single_inheritance_required.py new file mode 100644 index 0000000..525ea1a --- /dev/null +++ b/py_spring_core/core/interfaces/single_inheritance_required.py @@ -0,0 +1,36 @@ +from abc import ABC +from typing import Generic, Optional, Type, TypeVar, cast + +T = TypeVar("T") + + +class SingleInheritanceRequired(Generic[T], ABC): + """ + A base class that ensures only one subclass can be inherited from it. + This enforces the single inheritance constraint for specific component types. + """ + + @classmethod + def check_only_one_subclass_allowed(cls) -> None: + """ + Check if the subclass is allowed to be inherited. + """ + class_dict: dict[str, Type[SingleInheritanceRequired[T]]] = {} + for subclass in cls.__subclasses__(): + if subclass.__name__ in class_dict: + continue + class_dict[subclass.__name__] = subclass + if len(class_dict) > 1: + raise ValueError( + f"Only one subclass is allowed for {cls.__name__}, but {len(class_dict)} subclasses: {[subclass.__name__ for subclass in class_dict.values()]} found" + ) + + @classmethod + def get_subclass(cls) -> Optional[Type[T]]: + """ + Get the subclass of the component. + """ + cls.check_only_one_subclass_allowed() + if len(cls.__subclasses__()) == 0: + return + return cast(Type[T], cls.__subclasses__()[0]) diff --git a/py_spring_core/core/utils.py b/py_spring_core/core/utils.py index 629cea1..16da747 100644 --- a/py_spring_core/core/utils.py +++ b/py_spring_core/core/utils.py @@ -1,14 +1,17 @@ -import importlib.util import inspect -from pathlib import Path -from typing import Any, Callable, Iterable, Type, get_type_hints +from abc import ABC +from typing import Any, Iterable, Type from loguru import logger +from ..commons.module_importer import ModuleImporter + +# Global module importer instance +_module_importer = ModuleImporter() + def dynamically_import_modules( module_paths: Iterable[str], - is_ignore_error: bool = True, target_subclasses: Iterable[Type[object]] = [], ) -> set[Type[object]]: """ @@ -17,115 +20,47 @@ def dynamically_import_modules( Args: module_paths (Iterable[str]): The file paths of the modules to import. is_ignore_error (bool, optional): Whether to ignore any errors that occur during the import process. Defaults to True. + target_subclasses (Iterable[Type[object]], optional): Target subclasses to filter. Defaults to []. Raises: Exception: If an error occurs during the import process and `is_ignore_error` is False. """ - all_loaded_classes: list[Type[object]] = [] - - for module_path in module_paths: - file_path = Path(module_path).resolve() - module_name = file_path.stem - logger.info(f"[MODULE IMPORT] Import module path: {file_path}") - # Create a module specification - spec = importlib.util.spec_from_file_location(module_name, file_path) - if spec is None: - logger.warning( - f"[DYNAMICALLY MODULE IMPORT] Could not create spec for {module_name}" - ) - continue - - # Create a new module based on the specification - module = importlib.util.module_from_spec(spec) - if spec.loader is None: - logger.warning( - f"[DYNAMICALLY MODULE IMPORT] No loader found for {module_name}" - ) - continue - - # Execute the module in its own namespace - - logger.info(f"[DYNAMICALLY MODULE IMPORT] Import module: {module_name}") - try: - spec.loader.exec_module(module) - logger.success( - f"[DYNAMICALLY MODULE IMPORT] Successfully imported {module_name}" - ) - except Exception as error: - logger.warning(error) - if not is_ignore_error: - raise error - - loaded_classes = [] - for attr in dir(module): - obj = getattr(module, attr) - if attr.startswith("__"): - continue - if not inspect.isclass(obj): - continue - loaded_classes.append(obj) - all_loaded_classes.extend(loaded_classes) - - returned_target_classes: set[Type[object]] = set() - for target_cls in target_subclasses: - for loaded_class in all_loaded_classes: - if loaded_class in target_subclasses: - continue - if issubclass(loaded_class, target_cls): - returned_target_classes.add(loaded_class) - - return returned_target_classes - - -class TypeHintError(Exception): ... - - -def check_type_hints_for_callable(func: Callable[..., Any]) -> None: - RETURN_ID = "return" - func_qualname_list = func.__qualname__.split(".") - is_class_callable = True if len(func_qualname_list) == 2 else False - class_name = func_qualname_list[0] if is_class_callable else "" - - func_name = func.__name__ - args_type_hints = get_type_hints(func) - arg_type = inspect.getargs(func.__code__) - code_path = inspect.getsourcefile(func) - if RETURN_ID not in args_type_hints: - raise TypeHintError( - f"Type hints for 'return type' not provided for the function: {class_name}.{func_name}, path: {code_path}" - ) - - # plue one is for return type, return type is not included in co_argcount if it is a simple function, - # for member functions, self is included in co_varnames, but not in type hints, so plus 0 - arguments = arg_type.args - argument_count = len(arguments) - if argument_count == 0: - return - if len(args_type_hints) == 0: - raise TypeHintError( - f"Type hints not provided for the function: {class_name}.{func_name}, arguments: {arguments}, current type hints: {args_type_hints}, path: {code_path}" - ) - - if len(args_type_hints) != argument_count: - missing_type_hints_args = [ - arg for arg in arguments if arg not in args_type_hints and arg != "self" - ] - if len(missing_type_hints_args) == 0: - return - raise TypeHintError( - f"Type hints not fully provided: {class_name}.{func_name}, arguments: {arguments}, current type hints: {args_type_hints}, missingg type hints: {','.join(missing_type_hints_args)}, path: {code_path}" - ) - - -def check_type_hints_for_class(_cls: Type[Any], skip_attrs: list[str] = list()) -> None: - for attr in dir(_cls): - if attr in skip_attrs: - continue - if attr.startswith("__"): - continue - attr_obj = getattr(_cls, attr) - if not callable(attr_obj): - continue - if not hasattr(attr_obj, "__annotations__"): - continue - check_type_hints_for_callable(attr_obj) + return _module_importer.import_classes_from_paths( + file_paths=module_paths, + target_subclasses=target_subclasses, + ) + + +def clear_module_cache() -> None: + """Clear the global module cache. Useful for testing or when you need to force re-import.""" + _module_importer.clear_cache() + + +def get_unimplemented_abstract_methods(cls: Type[Any]) -> set[str]: + """ + Returns a set of abstract method names not implemented in the given class. + Args: + cls (Type[Any]): A subclass of abc.ABC + + Returns: + set[str]: A set of method names that are abstract but not yet implemented + """ + if not isinstance(cls, type): + raise TypeError("Expected a class type.") + + if not issubclass(cls, ABC): + raise TypeError("Expected a subclass of abc.ABC.") + + abstract_methods: set[str] = set() + for base in cls.__mro__: + base_abstracts = getattr(base, "__abstractmethods__", set()) + abstract_methods = abstract_methods.union(base_abstracts) + + implemented_methods: set[str] = { + attr + for attr in dir(cls) + if callable(getattr(cls, attr)) + and not getattr(getattr(cls, attr), "__isabstractmethod__", False) + } + + return abstract_methods.difference(implemented_methods) diff --git a/py_spring_core/event/application_event_handler_registry.py b/py_spring_core/event/application_event_handler_registry.py new file mode 100644 index 0000000..14aeaee --- /dev/null +++ b/py_spring_core/event/application_event_handler_registry.py @@ -0,0 +1,125 @@ +from threading import Thread +from typing import Callable, ClassVar, Type + +from loguru import logger +from pydantic import BaseModel + +from py_spring_core.core.entities.component.component import Component +from py_spring_core.core.interfaces.application_context_required import ( + ApplicationContextRequired, +) +from py_spring_core.event.commons import ApplicationEvent, EventQueue + +EventHandlerT = Callable[[Component, ApplicationEvent], None] + + +def EventListener(event_type: Type[ApplicationEvent]) -> Callable: + """ + The EventListener decorator is used to register an event handler for an application event. + It is responsible for binding an event handler to a component and a function. + """ + + def decorator(func: EventHandlerT) -> None: + if not issubclass(event_type, ApplicationEvent): + raise ValueError(f"Event type must be a subclass of ApplicationEvent") + + ApplicationEventHandlerRegistry.register_event_handler(event_type, func) + + return decorator + + +class EventHandler(BaseModel): + """ + The EventHandler class is a model that represents an event handler for an application event. + It is responsible for binding an event handler to a component and a function. + """ + + class_name: str + func_name: str + event_type: Type[ApplicationEvent] + func: EventHandlerT + + def __eq__(self, other: object) -> bool: + if not isinstance(other, EventHandler): + return False + return self.class_name == other.class_name and self.func_name == other.func_name + + def __hash__(self) -> int: + return hash((self.class_name, self.func_name)) + + +class ApplicationEventHandlerRegistry(Component, ApplicationContextRequired): + """ + The ApplicationEventHandlerRegistry is a component that registers event handlers for application events. + It is responsible for binding event handlers to their corresponding components and handling event messages. + + The class performs the following key tasks: + - Registers event handlers for application events + - Binds event handlers to their corresponding components + """ + + _class_event_handlers: ClassVar[dict[str, list[EventHandler]]] = {} + + def __init__(self) -> None: + self._event_handlers: dict[str, list[EventHandler]] = {} + self._event_message_queue = EventQueue.queue + + def post_construct(self) -> None: + logger.info("Initializing event handlers...") + self._init_event_handlers() + logger.info("Starting event message handler thread...") + Thread(target=self._handle_messages, daemon=True).start() + + def _init_event_handlers(self) -> None: + app_context = self.get_application_context() + # get_name might be different from the class name, so we use the class name for function binding + self.component_instance_map = { + component.__class__.__name__: component + for component in app_context.get_singleton_component_instances() + } + self._event_handlers = self._class_event_handlers + + @classmethod + def register_event_handler( + cls, event_type: Type[ApplicationEvent], handler: EventHandlerT + ): + event_name = event_type.__name__ + func_name_parts = handler.__qualname__.split(".") + if len(func_name_parts) != 2: + raise ValueError(f"Handler must be a member function of a class") + class_name, func_name = func_name_parts + if event_name not in cls._class_event_handlers: + cls._class_event_handlers[event_name] = [] + event_handler = EventHandler( + class_name=class_name, + func_name=func_name, + event_type=event_type, + func=handler, + ) + if event_handler not in cls._class_event_handlers[event_name]: + cls._class_event_handlers[event_name].append(event_handler) + + def get_event_handlers( + self, event_type: Type[ApplicationEvent] + ) -> list[EventHandler]: + event_name = event_type.__name__ + handlers = self._event_handlers.get(event_name, []) + return handlers + + def _handle_messages(self) -> None: + logger.info("Event message handler thread started...") + while True: + message = self._event_message_queue.get() + for handler in self.get_event_handlers(message.__class__): + try: + optional_instance = self.component_instance_map.get( + handler.class_name, None + ) + if optional_instance is None: + logger.error( + f"Component instance not found for handler: {handler.class_name}" + ) + continue + handler.func(optional_instance, message) + except Exception as error: + logger.error(f"Error handling event: {error}") diff --git a/py_spring_core/event/application_event_publisher.py b/py_spring_core/event/application_event_publisher.py new file mode 100644 index 0000000..dccbe88 --- /dev/null +++ b/py_spring_core/event/application_event_publisher.py @@ -0,0 +1,26 @@ +from typing import TypeVar + +from py_spring_core.core.entities.component.component import Component +from py_spring_core.event.application_event_handler_registry import ( + ApplicationEvent, + ApplicationEventHandlerRegistry, +) +from py_spring_core.event.commons import EventQueue + +T = TypeVar("T", bound=ApplicationEvent) + + +class ApplicationEventPublisher(Component): + """ + The ApplicationEventPublisher is a component that publishes application events. + It is responsible for publishing application events to the event message queue. + + The class performs the following key tasks: + - Publishes application events to the event message queue + """ + + def __init__(self): + self.event_message_queue = EventQueue.queue + + def publish(self, event: ApplicationEvent) -> None: + self.event_message_queue.put(event) diff --git a/py_spring_core/event/commons.py b/py_spring_core/event/commons.py new file mode 100644 index 0000000..3ba14a6 --- /dev/null +++ b/py_spring_core/event/commons.py @@ -0,0 +1,10 @@ +from queue import Queue + +from pydantic import BaseModel + + +class ApplicationEvent(BaseModel): ... + + +class EventQueue: + queue: Queue[ApplicationEvent] = Queue() diff --git a/py_spring_core/exception_handler/decorator.py b/py_spring_core/exception_handler/decorator.py new file mode 100644 index 0000000..3d60b4e --- /dev/null +++ b/py_spring_core/exception_handler/decorator.py @@ -0,0 +1,18 @@ + + +from functools import wraps +from typing import Any, Callable, Type + + +from py_spring_core.exception_handler.exception_handler_registry import ExceptionHandlerRegistry + + +def ExceptionHandler(exception_cls: Type[Exception]) -> Callable[[Callable[[Exception], Any]], Callable]: + def decorator(func: Callable[[Exception], Any]) -> Callable: + @wraps(func) + def wrapper(*args, **kwargs) -> Any: + return func(*args, **kwargs) + ExceptionHandlerRegistry.register(exception_cls, wrapper) + return wrapper + + return decorator \ No newline at end of file diff --git a/py_spring_core/exception_handler/exception_handler_registry.py b/py_spring_core/exception_handler/exception_handler_registry.py new file mode 100644 index 0000000..2861975 --- /dev/null +++ b/py_spring_core/exception_handler/exception_handler_registry.py @@ -0,0 +1,25 @@ + + +from typing import Any, Callable, Type, TypeVar + +from loguru import logger + +E = TypeVar('E', bound=Exception) + +class ExceptionHandlerRegistry: + _handlers: dict[str, Callable[[Any], Any]] = {} + + @classmethod + def register(cls, exception_cls: Type[E], handler: Callable[[E], Any]): + key = exception_cls.__name__ + logger.debug(f"Registering exception handler for {key}: {handler.__name__}") + if key in cls._handlers: + error_message = f"Exception handler for {exception_cls} already registered" + logger.error(error_message) + raise RuntimeError(error_message) + + cls._handlers[exception_cls.__name__] = handler + + @classmethod + def get_handler(cls, exception_cls: Type[E]) -> Callable[[E], Any]: + return cls._handlers[exception_cls.__name__] \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 45bf6c5..f896924 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,55 +1,52 @@ [project] name = "py_spring_core" dynamic = ["version"] -description = "Default template for PDM package" +description = "PySpring is a Python web framework inspired by Spring Boot, combining FastAPI, and Pydantic for building scalable web applications with auto dependency injection, configuration management, and a web server." authors = [ {name = "William Chen", email = "OW6201231@gmail.com"}, ] dependencies = [ - "annotated-types==0.7.0", - "anyio==4.4.0", - "certifi==2024.7.4", - "click==8.1.7", - "dnspython==2.6.1", - "email-validator==2.2.0", - "exceptiongroup==1.2.2", - "fastapi==0.112.0", - "fastapi-cli==0.0.5", - "greenlet==3.0.3", - "h11==0.14.0", - "httpcore==1.0.5", - "httptools==0.6.1", - "httpx==0.27.0", - "idna==3.7", - "itsdangerous==2.2.0", - "Jinja2==3.1.4", - "loguru==0.7.2", - "markdown-it-py==3.0.0", - "MarkupSafe==2.1.5", - "mdurl==0.1.2", - "orjson==3.10.7", - "pydantic==2.8.2", - "pydantic-extra-types==2.9.0", - "pydantic-settings==2.4.0", - "pydantic-core==2.20.1", - "Pygments==2.18.0", - "python-dotenv==1.0.1", - "python-multipart==0.0.9", - "PyYAML==6.0.2", - "rich==13.7.1", - "shellingham==1.5.4", - "sniffio==1.3.1", - "starlette==0.37.2", + "annotated-types>=0.7.0", + "anyio>=4.4.0", + "certifi>=2023.7.22", + "click>=8.1.7", + "dnspython>=2.6.1", + "email-validator>=2.2.0", + "exceptiongroup>=1.2.2", + "fastapi>=0.115.12", + "fastapi-cli>=0.0.5", + "greenlet>=3.0.3", + "h11>=0.16.0", + "httpcore>=1.0.9", + "httptools>=0.6.1", + "httpx>=0.27.0", + "idna>=3.7", + "itsdangerous>=2.2.0", + "Jinja2>=3.1.6", + "loguru>=0.7.2", + "markdown-it-py>=3.0.0", + "MarkupSafe>=2.1.5", + "mdurl>=0.1.2", + "orjson>=3.10.7", + "Pygments>=2.18.0", + "python-dotenv>=1.0.1", + "python-multipart>=0.0.18", + "PyYAML>=6.0.2", + "rich>=13.7.1", + "shellingham>=1.5.4", + "sniffio>=1.3.1", + "starlette>=0.40.0,<0.47.0", "typer>=0.12.5", - "typing-extensions==4.12.2", - "ujson==5.10.0", - "uvicorn==0.30.5", - "uvloop==0.19.0", - "watchfiles==0.23.0", - "websockets==12.0", + "typing-extensions>=4.12.2", + "ujson>=5.10.0", + "uvicorn>=0.30.5", + "uvloop>=0.19.0", + "watchfiles>=0.23.0", + "websockets>=12.0", "cachetools>=5.5.0", + "pydantic>=2.11.7" ] -requires-python = ">=3.10" +requires-python = ">=3.10,<3.13" readme = "README.md" license = {text = "MIT"} @@ -76,4 +73,8 @@ dev = [ "isort>=5.13.2", "pytest>=8.3.2", "pytest-mock>=3.14.0", + "pytest-asyncio>=1.1.0", + "types-PyYAML>=6.0.12.20240917", + "types-cachetools>=5.5.0.20240820", + "mypy>=1.11.2" ] diff --git a/tests/test_application_config.py b/tests/test_application_config.py new file mode 100644 index 0000000..248f883 --- /dev/null +++ b/tests/test_application_config.py @@ -0,0 +1,146 @@ +import pytest +from pydantic import ValidationError + +from py_spring_core.core.application.application_config import ( + ApplicationConfig, + ShutdownConfig, + ServerConfig, + LoguruConfig, +) + + +class TestShutdownConfig: + """Test suite for the ShutdownConfig class.""" + + def test_shutdown_config_defaults(self): + """Test that ShutdownConfig has correct default values.""" + config = ShutdownConfig() + + assert config.timeout_seconds == 30.0 + assert config.enabled is True + + def test_shutdown_config_custom_values(self): + """Test ShutdownConfig with custom values.""" + config = ShutdownConfig(timeout_seconds=60.0, enabled=False) + + assert config.timeout_seconds == 60.0 + assert config.enabled is False + + def test_shutdown_config_validation(self): + """Test ShutdownConfig validation.""" + # Valid timeout values + config = ShutdownConfig(timeout_seconds=0.1) + assert config.timeout_seconds == 0.1 + + config = ShutdownConfig(timeout_seconds=1000.0) + assert config.timeout_seconds == 1000.0 + + def test_shutdown_config_type_validation(self): + """Test that ShutdownConfig validates types correctly.""" + # Should accept float or int for timeout_seconds + config = ShutdownConfig(timeout_seconds=30) + assert config.timeout_seconds == 30.0 + + config = ShutdownConfig(timeout_seconds=30.5) + assert config.timeout_seconds == 30.5 + + # Should accept boolean for enabled + config = ShutdownConfig(enabled=True) + assert config.enabled is True + + config = ShutdownConfig(enabled=False) + assert config.enabled is False + + def test_shutdown_config_serialization(self): + """Test that ShutdownConfig can be serialized/deserialized.""" + config = ShutdownConfig(timeout_seconds=45.0, enabled=True) + + # Test model dump + config_dict = config.model_dump() + expected = {"timeout_seconds": 45.0, "enabled": True} + assert config_dict == expected + + # Test reconstruction from dict + new_config = ShutdownConfig(**config_dict) + assert new_config.timeout_seconds == 45.0 + assert new_config.enabled is True + + +class TestApplicationConfigWithShutdown: + """Test suite for ApplicationConfig with shutdown configuration.""" + + def test_application_config_with_default_shutdown(self): + """Test that ApplicationConfig includes default shutdown config.""" + config = ApplicationConfig( + app_src_target_dir="./src", + server_config=ServerConfig(host="localhost", port=8000), + properties_file_path="./app.properties", + loguru_config=LoguruConfig(log_file_path="./logs/app.log") + ) + + # Should have default shutdown config + assert config.shutdown_config is not None + assert config.shutdown_config.timeout_seconds == 30.0 + assert config.shutdown_config.enabled is True + + def test_application_config_with_custom_shutdown(self): + """Test ApplicationConfig with custom shutdown configuration.""" + custom_shutdown = ShutdownConfig(timeout_seconds=60.0, enabled=False) + + config = ApplicationConfig( + app_src_target_dir="./src", + server_config=ServerConfig(host="localhost", port=8000), + properties_file_path="./app.properties", + loguru_config=LoguruConfig(log_file_path="./logs/app.log"), + shutdown_config=custom_shutdown + ) + + assert config.shutdown_config.timeout_seconds == 60.0 + assert config.shutdown_config.enabled is False + + def test_application_config_serialization_with_shutdown(self): + """Test ApplicationConfig serialization includes shutdown config.""" + config = ApplicationConfig( + app_src_target_dir="./src", + server_config=ServerConfig(host="localhost", port=8000), + properties_file_path="./app.properties", + loguru_config=LoguruConfig(log_file_path="./logs/app.log"), + shutdown_config=ShutdownConfig(timeout_seconds=45.0, enabled=True) + ) + + config_dict = config.model_dump() + + assert "shutdown_config" in config_dict + assert config_dict["shutdown_config"]["timeout_seconds"] == 45.0 + assert config_dict["shutdown_config"]["enabled"] is True + + def test_application_config_from_dict_with_shutdown(self): + """Test ApplicationConfig reconstruction from dict with shutdown config.""" + config_dict = { + "app_src_target_dir": "./src", + "server_config": {"host": "localhost", "port": 8000}, + "properties_file_path": "./app.properties", + "loguru_config": {"log_file_path": "./logs/app.log", "log_level": "INFO"}, + "shutdown_config": {"timeout_seconds": 25.0, "enabled": False} + } + + config = ApplicationConfig(**config_dict) + + assert config.app_src_target_dir == "./src" + assert config.shutdown_config.timeout_seconds == 25.0 + assert config.shutdown_config.enabled is False + + def test_application_config_without_shutdown_config_in_dict(self): + """Test that ApplicationConfig uses default shutdown when not provided.""" + config_dict = { + "app_src_target_dir": "./src", + "server_config": {"host": "localhost", "port": 8000}, + "properties_file_path": "./app.properties", + "loguru_config": {"log_file_path": "./logs/app.log", "log_level": "INFO"} + } + + config = ApplicationConfig(**config_dict) + + # Should use default shutdown config + assert config.shutdown_config.timeout_seconds == 30.0 + assert config.shutdown_config.enabled is True \ No newline at end of file diff --git a/tests/test_application_context.py b/tests/test_application_context.py index 1442199..ae159c6 100644 --- a/tests/test_application_context.py +++ b/tests/test_application_context.py @@ -1,20 +1,25 @@ import pytest +from fastapi import FastAPI from py_spring_core.core.application.context.application_context import ( ApplicationContext, ApplicationContextConfig, ) -from py_spring_core.core.entities.bean_collection import BeanCollection -from py_spring_core.core.entities.component import Component +from py_spring_core.core.entities.bean_collection.bean_collection import BeanCollection +from py_spring_core.core.entities.component.component import Component from py_spring_core.core.entities.controllers.rest_controller import RestController from py_spring_core.core.entities.properties.properties import Properties class TestApplicationContext: @pytest.fixture - def app_context(self): + def server(self) -> FastAPI: + return FastAPI() + + @pytest.fixture + def app_context(self, server: FastAPI): config = ApplicationContextConfig(properties_path="") - return ApplicationContext(config) + return ApplicationContext(config, server=server) def test_register_entities_correctly(self, app_context: ApplicationContext): class TestComponent(Component): ... @@ -32,21 +37,28 @@ class TestProperties(Properties): app_context.register_properties(TestProperties) assert ( - "TestComponent" in app_context.component_cls_container - and app_context.component_cls_container["TestComponent"] == TestComponent + "TestComponent" in app_context.container_manager.component_classes + and app_context.container_manager.component_classes["TestComponent"] + == TestComponent ) assert ( - "TestController" in app_context.controller_cls_container - and app_context.controller_cls_container["TestController"] == TestController + "TestController" in app_context.container_manager.controller_classes + and app_context.container_manager.controller_classes["TestController"] + == TestController ) assert ( - "TestBeanCollection" in app_context.bean_collection_cls_container - and app_context.bean_collection_cls_container["TestBeanCollection"] + "TestBeanCollection" + in app_context.container_manager.bean_collection_classes + and app_context.container_manager.bean_collection_classes[ + "TestBeanCollection" + ] == TestBeanCollection ) assert ( - "test_properties" in app_context.properties_cls_container - and app_context.properties_cls_container["test_properties"] + "test_properties" in app_context.container_manager.properties_classes + and app_context.container_manager.properties_classes[ + "test_properties" + ] == TestProperties ) @@ -65,10 +77,17 @@ class TestProperties(Properties): app_context.register_bean_collection(TestBeanCollection) app_context.register_properties(TestProperties) - assert "TestComponent" in app_context.component_cls_container - assert "TestController" in app_context.controller_cls_container - assert "TestBeanCollection" in app_context.bean_collection_cls_container - assert "test_properties" in app_context.properties_cls_container + assert "TestComponent" in app_context.container_manager.component_classes + assert ( + "TestController" in app_context.container_manager.controller_classes + ) + assert ( + "TestBeanCollection" + in app_context.container_manager.bean_collection_classes + ) + assert ( + "test_properties" in app_context.container_manager.properties_classes + ) def test_register_invalid_entities_raises_error( self, app_context: ApplicationContext @@ -123,25 +142,25 @@ class TestProperties(Properties): # Test retrieving singleton components component_instance = TestComponent() - app_context.singleton_component_instance_container["TestComponent"] = ( - component_instance - ) - retrieved_component = app_context.get_component(TestComponent) + app_context.container_manager.component_instances[ + "TestComponent" + ] = component_instance + retrieved_component = app_context.get_component(TestComponent, None) assert retrieved_component is component_instance # Test retrieving singleton beans bean_instance = TestBeanCollection() - app_context.singleton_bean_instance_container["TestBeanCollection"] = ( - bean_instance - ) - retrieved_bean = app_context.get_bean(TestBeanCollection) + app_context.container_manager.bean_instances[ + "TestBeanCollection" + ] = bean_instance + retrieved_bean = app_context.get_bean(TestBeanCollection, None) assert retrieved_bean is bean_instance # Test retrieving singleton properties properties_instance = TestProperties() - app_context.singleton_properties_instance_container["test_properties"] = ( - properties_instance - ) + app_context.container_manager.properties_instances[ + "test_properties" + ] = properties_instance retrieved_properties = app_context.get_properties(TestProperties) assert retrieved_properties is properties_instance diff --git a/tests/test_bean_collection.py b/tests/test_bean_collection.py index 6a35c21..d8f5e9e 100644 --- a/tests/test_bean_collection.py +++ b/tests/test_bean_collection.py @@ -1,7 +1,7 @@ import pytest -from py_spring_core.core.entities.bean_collection import BeanCollection, BeanView -from py_spring_core.core.entities.component import Component +from py_spring_core.core.entities.bean_collection.bean_collection import BeanCollection, BeanView +from py_spring_core.core.entities.component.component import Component class TestBeanView: diff --git a/tests/test_component_features.py b/tests/test_component_features.py new file mode 100644 index 0000000..000d7f8 --- /dev/null +++ b/tests/test_component_features.py @@ -0,0 +1,206 @@ +from typing import Annotated + +import pytest +from fastapi import FastAPI + +from py_spring_core.core.application.context.application_context import ( + ApplicationContext, + ApplicationContextConfig, +) +from py_spring_core.core.entities.component.component import Component, ComponentScope + + +class TestComponentFeatures: + """Test suite for component features including primary components, qualifiers, and registration validation.""" + + @pytest.fixture + def server(self) -> FastAPI: + return FastAPI() + + @pytest.fixture + def app_context(self, server: FastAPI): + """Fixture that provides a fresh ApplicationContext instance for each test.""" + config = ApplicationContextConfig(properties_path="") + return ApplicationContext(config, server=server) + + def test_qualifier_based_injection(self, app_context: ApplicationContext): + """ + Test the qualifier-based dependency injection mechanism. + + This test verifies that: + 1. Multiple implementations of an abstract component can coexist + 2. Specific implementations can be injected using qualifiers + 3. Both primary and non-primary components can be injected using qualifiers + 4. The correct implementation is injected for each qualifier + + The test creates an abstract service with two implementations and verifies + that each can be injected into a consumer using appropriate qualifiers. + """ + + # Define abstract base class + class AbstractService(Component): + class Config: + is_primary = True + scope = ComponentScope.Singleton + + def process(self) -> str: + raise NotImplementedError() + + # Define implementations + class ServiceA(AbstractService): + class Config: + is_primary = True + name = "ServiceA" + scope = ComponentScope.Singleton + + def process(self) -> str: + return "Service A processing" + + class ServiceB(AbstractService): + class Config: + is_primary = False + name = "ServiceB" + scope = ComponentScope.Singleton + + def process(self) -> str: + return "Service B processing" + + # Register implementations + app_context.register_component(ServiceA) + app_context.register_component(ServiceB) + app_context.init_ioc_container() + + # Test qualifier-based injection + class ServiceConsumer(Component): + service_a: Annotated[AbstractService, "ServiceA"] + service_b: Annotated[AbstractService, "ServiceB"] + + def post_construct(self) -> None: + assert isinstance(self.service_a, ServiceA) + assert isinstance(self.service_b, ServiceB) + assert self.service_a.process() == "Service A processing" + assert self.service_b.process() == "Service B processing" + + app_context.register_component(ServiceConsumer) + app_context.init_ioc_container() # Initialize the consumer component + app_context.inject_dependencies_for_app_entities() + + def test_duplicate_component_registration(self, app_context: ApplicationContext): + """ + Test the prevention of duplicate component registration. + + This test verifies that: + 1. A component can only be registered once + 2. Attempting to register the same component again doesn't raise an error (silent skip) + 3. The component is only registered once in the container + + The test attempts to register the same component twice and verifies + that it's handled gracefully without errors. + """ + + # Define a component + class TestService(Component): + class Config: + name = "TestService" + scope = ComponentScope.Singleton + + def process(self) -> str: + return "Test service processing" + + # Register the same component multiple times + app_context.register_component(TestService) + initial_count = len(app_context.container_manager.component_classes) + app_context.register_component(TestService) + app_context.register_component(TestService) + final_count = len(app_context.container_manager.component_classes) + + # Verify the component is registered only once + assert initial_count == final_count == 1 + assert "TestService" in app_context.container_manager.component_classes + + def test_component_name_override(self, app_context: ApplicationContext): + """ + Test the ability to override component names during registration. + + This test verifies that: + 1. Components can be registered with custom names + 2. The custom name is correctly stored in the component container + 3. The component can be retrieved using the custom name + + The test registers a component with a custom name and verifies + that it is correctly stored in the container. + """ + + # Define component with custom name + class TestService(Component): + class Config: + name = "CustomServiceName" + scope = ComponentScope.Singleton + + def process(self) -> str: + return "Test service processing" + + # Register component + app_context.register_component(TestService) + app_context.init_ioc_container() + + # Check if component is registered with custom name + assert ( + "CustomServiceName" in app_context.container_manager.component_classes + ) + component_instance = ( + app_context.container_manager.component_classes["CustomServiceName"] + ) + + def test_qualifier_with_invalid_component(self, app_context: ApplicationContext): + """ + Test error handling for invalid qualifier usage. + + This test verifies that: + 1. Attempting to inject a component with an invalid qualifier raises an error + 2. The error message clearly indicates the invalid qualifier + 3. The error occurs during dependency injection + + The test attempts to inject a component using a non-existent qualifier + and verifies that an appropriate error is raised. + """ + + # Define abstract base class + class AbstractService(Component): + class Config: + is_primary = True + scope = ComponentScope.Singleton + + def process(self) -> str: + raise NotImplementedError() + + # Define implementation + class TestService(AbstractService): + class Config: + is_primary = True + name = "TestService" + scope = ComponentScope.Singleton + + def process(self) -> str: + return "Test service processing" + + # Register implementation + app_context.register_component(TestService) + app_context.init_ioc_container() + + # Test injection with invalid qualifier + class ServiceConsumer(Component): + service: Annotated[AbstractService, "NonExistentService"] + + def post_construct(self) -> None: + pass + + app_context.register_component(ServiceConsumer) + app_context.init_ioc_container() # Initialize the consumer component + + # Attempting to inject with invalid qualifier should raise error + with pytest.raises( + ValueError, + match="\\[DEPENDENCY INJECTION FAILED\\] Fail to inject dependency for attribute: service with dependency: AbstractService with qualifier: NonExistentService", + ): + app_context.inject_dependencies_for_app_entities() diff --git a/tests/test_entity_provider.py b/tests/test_entity_provider.py index fad6f24..3279053 100644 --- a/tests/test_entity_provider.py +++ b/tests/test_entity_provider.py @@ -1,4 +1,5 @@ import pytest +from fastapi import FastAPI from py_spring_core.core.application.context.application_context import ( ApplicationContext, @@ -7,8 +8,8 @@ from py_spring_core.core.application.context.application_context_config import ( ApplicationContextConfig, ) -from py_spring_core.core.entities.component import Component -from py_spring_core.core.entities.entity_provider import EntityProvider +from py_spring_core.core.entities.component.component import Component +from py_spring_core.core.entities.entity_provider.entity_provider import EntityProvider class TestComponent(Component): ... @@ -19,12 +20,18 @@ class TestEntityProvider: def test_entity_provider(self): return EntityProvider(depends_on=[TestComponent]) + @pytest.fixture + def server(self) -> FastAPI: + return FastAPI() + @pytest.fixture def test_app_context( - self, test_entity_provider: EntityProvider + self, test_entity_provider: EntityProvider, server: FastAPI ) -> ApplicationContext: - app_context = ApplicationContext(ApplicationContextConfig(properties_path="")) - app_context.register_entity_provider(test_entity_provider) + app_context = ApplicationContext( + ApplicationContextConfig(properties_path=""), server=server + ) + app_context.providers.append(test_entity_provider) return app_context def test_did_raise_error_when_no_depends_on_is_provided( diff --git a/tests/test_framework_utils.py b/tests/test_framework_utils.py new file mode 100644 index 0000000..8c83ce6 --- /dev/null +++ b/tests/test_framework_utils.py @@ -0,0 +1,226 @@ +import pytest +from abc import ABC, abstractmethod +from typing import Type, Any + +from py_spring_core.core.utils import get_unimplemented_abstract_methods + + +class TestFrameworkUtils: + """Test suite for framework utility functions.""" + + def test_get_unimplemented_abstract_methods_with_concrete_class(self): + """Test that fully implemented classes return empty set.""" + + class AbstractBase(ABC): + @abstractmethod + def method_a(self) -> None: + pass + + @abstractmethod + def method_b(self) -> str: + pass + + class ConcreteImpl(AbstractBase): + def method_a(self) -> None: + pass + + def method_b(self) -> str: + return "implemented" + + unimplemented = get_unimplemented_abstract_methods(ConcreteImpl) + assert unimplemented == set() + + def test_get_unimplemented_abstract_methods_with_partial_implementation(self): + """Test that partially implemented classes return missing methods.""" + + class AbstractBase(ABC): + @abstractmethod + def method_a(self) -> None: + pass + + @abstractmethod + def method_b(self) -> str: + pass + + @abstractmethod + def method_c(self) -> int: + pass + + class PartialImpl(AbstractBase): + def method_a(self) -> None: + pass + + # method_b and method_c are not implemented + + unimplemented = get_unimplemented_abstract_methods(PartialImpl) + assert unimplemented == {"method_b", "method_c"} + + def test_get_unimplemented_abstract_methods_with_no_implementation(self): + """Test that classes with no implementations return all abstract methods.""" + + class AbstractBase(ABC): + @abstractmethod + def method_a(self) -> None: + pass + + @abstractmethod + def method_b(self) -> str: + pass + + class NoImpl(AbstractBase): + # No methods implemented + pass + + unimplemented = get_unimplemented_abstract_methods(NoImpl) + assert unimplemented == {"method_a", "method_b"} + + def test_get_unimplemented_abstract_methods_with_multiple_inheritance(self): + """Test with multiple inheritance from abstract classes.""" + + class AbstractA(ABC): + @abstractmethod + def method_a(self) -> None: + pass + + class AbstractB(ABC): + @abstractmethod + def method_b(self) -> str: + pass + + class MultipleInheritance(AbstractA, AbstractB): + def method_a(self) -> None: + pass + # method_b is not implemented + + unimplemented = get_unimplemented_abstract_methods(MultipleInheritance) + assert unimplemented == {"method_b"} + + def test_get_unimplemented_abstract_methods_with_inheritance_chain(self): + """Test with inheritance chain where parent implements some methods.""" + + class AbstractBase(ABC): + @abstractmethod + def method_a(self) -> None: + pass + + @abstractmethod + def method_b(self) -> str: + pass + + @abstractmethod + def method_c(self) -> int: + pass + + class PartialParent(AbstractBase): + def method_a(self) -> None: + pass + # method_b and method_c still abstract + + class ChildImpl(PartialParent): + def method_b(self) -> str: + return "implemented" + # method_c still not implemented + + unimplemented = get_unimplemented_abstract_methods(ChildImpl) + assert unimplemented == {"method_c"} + + def test_get_unimplemented_abstract_methods_with_no_abstract_methods(self): + """Test with class that has no abstract methods.""" + + class NonAbstractBase(ABC): + def regular_method(self) -> None: + pass + + class RegularClass(NonAbstractBase): + def another_method(self) -> str: + return "normal" + + unimplemented = get_unimplemented_abstract_methods(RegularClass) + assert unimplemented == set() + + def test_get_unimplemented_abstract_methods_type_error_non_class(self): + """Test that function raises TypeError for non-class types.""" + + with pytest.raises(TypeError, match="Expected a class type"): + get_unimplemented_abstract_methods("not a class") # type: ignore + + with pytest.raises(TypeError, match="Expected a class type"): + get_unimplemented_abstract_methods(42) # type: ignore + + def test_get_unimplemented_abstract_methods_type_error_non_abc(self): + """Test that function raises TypeError for non-ABC classes.""" + + class RegularClass: + def some_method(self) -> None: + pass + + with pytest.raises(TypeError, match="Expected a subclass of abc.ABC"): + get_unimplemented_abstract_methods(RegularClass) + + def test_get_unimplemented_abstract_methods_with_property_abstracts(self): + """Test with abstract properties.""" + + class AbstractWithProperty(ABC): + @property + @abstractmethod + def abstract_property(self) -> str: + pass + + @abstractmethod + def abstract_method(self) -> None: + pass + + class PartialPropertyImpl(AbstractWithProperty): + @property + def abstract_property(self) -> str: + return "implemented" + # abstract_method not implemented + + unimplemented = get_unimplemented_abstract_methods(PartialPropertyImpl) + # Properties might be included in the abstract methods set, so we check that abstract_method is there + # and abstract_property is not (since it's implemented) + assert "abstract_method" in unimplemented + assert len([method for method in unimplemented if "abstract_property" not in method]) >= 1 + + def test_get_unimplemented_abstract_methods_with_staticmethod_classmethod(self): + """Test with abstract static and class methods.""" + + class AbstractWithMethods(ABC): + @staticmethod + @abstractmethod + def abstract_static() -> str: + pass + + @classmethod + @abstractmethod + def abstract_class(cls) -> str: + pass + + @abstractmethod + def abstract_instance(self) -> None: + pass + + class PartialMethodImpl(AbstractWithMethods): + @staticmethod + def abstract_static() -> str: + return "static implemented" + + # abstract_class and abstract_instance not implemented + + unimplemented = get_unimplemented_abstract_methods(PartialMethodImpl) + assert unimplemented == {"abstract_class", "abstract_instance"} + + def test_get_unimplemented_abstract_methods_real_world_example(self): + """Test with a real-world like example similar to GracefulShutdownHandler.""" + + from py_spring_core.core.interfaces.graceful_shutdown_handler import GracefulShutdownHandler + + class IncompleteShutdownHandler(GracefulShutdownHandler): + def on_shutdown(self, shutdown_type) -> None: + pass + # on_timeout and on_error not implemented + + unimplemented = get_unimplemented_abstract_methods(IncompleteShutdownHandler) + assert "on_timeout" in unimplemented + assert "on_error" in unimplemented + assert "on_shutdown" not in unimplemented \ No newline at end of file diff --git a/tests/test_graceful_shutdown_handler.py b/tests/test_graceful_shutdown_handler.py new file mode 100644 index 0000000..76fb29e --- /dev/null +++ b/tests/test_graceful_shutdown_handler.py @@ -0,0 +1,340 @@ +import signal +import threading +import time +from typing import Optional +from unittest.mock import MagicMock, patch +import pytest + +from py_spring_core.core.interfaces.graceful_shutdown_handler import ( + GracefulShutdownHandler, + ShutdownType, +) + + +class TestGracefulShutdownHandler: + """Test suite for the graceful shutdown handler functionality.""" + + def setup_method(self): + """Reset any signal handlers before each test.""" + signal.signal(signal.SIGINT, signal.SIG_DFL) + signal.signal(signal.SIGTERM, signal.SIG_DFL) + + def test_graceful_shutdown_handler_interface(self): + """Test that GracefulShutdownHandler enforces abstract methods.""" + with pytest.raises(TypeError, match="Can't instantiate abstract class"): + # Should not be able to instantiate abstract class directly + GracefulShutdownHandler(timeout_seconds=30.0, timeout_enabled=True) # type: ignore + + def test_concrete_implementation_creation(self): + """Test that a concrete implementation can be created successfully.""" + + class TestShutdownHandler(GracefulShutdownHandler): + def __init__(self, timeout_seconds: float, timeout_enabled: bool): + self.shutdown_calls = [] + self.timeout_calls = [] + self.error_calls = [] + super().__init__(timeout_seconds, timeout_enabled) + + def on_shutdown(self, shutdown_type: ShutdownType) -> None: + self.shutdown_calls.append(shutdown_type) + + def on_timeout(self) -> None: + self.timeout_calls.append("timeout") + + def on_error(self, error: Exception) -> None: + self.error_calls.append(error) + + # Should be able to create concrete implementation + handler = TestShutdownHandler(timeout_seconds=5.0, timeout_enabled=True) + assert handler.get_timeout_seconds() == 5.0 + assert handler.is_timeout_enabled() is True + assert handler.is_shutdown() is False + assert handler.get_type() is None + + def test_sigint_handling(self): + """Test SIGINT signal handling.""" + + class TestShutdownHandler(GracefulShutdownHandler): + def __init__(self, timeout_seconds: float, timeout_enabled: bool): + self.shutdown_calls = [] + self.timeout_calls = [] + self.error_calls = [] + super().__init__(timeout_seconds, timeout_enabled) + + def on_shutdown(self, shutdown_type: ShutdownType) -> None: + self.shutdown_calls.append(shutdown_type) + + def on_timeout(self) -> None: + self.timeout_calls.append("timeout") + + def on_error(self, error: Exception) -> None: + self.error_calls.append(error) + + handler = TestShutdownHandler(timeout_seconds=1.0, timeout_enabled=False) + + # Simulate SIGINT + handler._handle_sigint(signal.SIGINT, None) + + assert handler.is_shutdown() is True + assert handler.get_type() == ShutdownType.MANUAL + assert len(handler.shutdown_calls) == 1 + assert handler.shutdown_calls[0] == ShutdownType.MANUAL + + def test_sigterm_handling(self): + """Test SIGTERM signal handling.""" + + class TestShutdownHandler(GracefulShutdownHandler): + def __init__(self, timeout_seconds: float, timeout_enabled: bool): + self.shutdown_calls = [] + self.timeout_calls = [] + self.error_calls = [] + super().__init__(timeout_seconds, timeout_enabled) + + def on_shutdown(self, shutdown_type: ShutdownType) -> None: + self.shutdown_calls.append(shutdown_type) + + def on_timeout(self) -> None: + self.timeout_calls.append("timeout") + + def on_error(self, error: Exception) -> None: + self.error_calls.append(error) + + handler = TestShutdownHandler(timeout_seconds=1.0, timeout_enabled=False) + + # Simulate SIGTERM + handler._handle_sigterm(signal.SIGTERM, None) + + assert handler.is_shutdown() is True + assert handler.get_type() == ShutdownType.SIGTERM + assert len(handler.shutdown_calls) == 1 + assert handler.shutdown_calls[0] == ShutdownType.SIGTERM + + def test_duplicate_signal_handling(self): + """Test that duplicate signals are ignored.""" + + class TestShutdownHandler(GracefulShutdownHandler): + def __init__(self, timeout_seconds: float, timeout_enabled: bool): + self.shutdown_calls = [] + self.timeout_calls = [] + self.error_calls = [] + super().__init__(timeout_seconds, timeout_enabled) + + def on_shutdown(self, shutdown_type: ShutdownType) -> None: + self.shutdown_calls.append(shutdown_type) + + def on_timeout(self) -> None: + self.timeout_calls.append("timeout") + + def on_error(self, error: Exception) -> None: + self.error_calls.append(error) + + handler = TestShutdownHandler(timeout_seconds=1.0, timeout_enabled=False) + + # First signal + handler._handle_sigint(signal.SIGINT, None) + assert len(handler.shutdown_calls) == 1 + + # Second signal should be ignored + handler._handle_sigint(signal.SIGINT, None) + assert len(handler.shutdown_calls) == 1 # Still only one call + + @patch('os._exit') + def test_timeout_functionality(self, mock_exit): + """Test shutdown timeout functionality.""" + + class TestShutdownHandler(GracefulShutdownHandler): + def __init__(self, timeout_seconds: float, timeout_enabled: bool): + self.shutdown_calls = [] + self.timeout_calls = [] + self.error_calls = [] + super().__init__(timeout_seconds, timeout_enabled) + + def on_shutdown(self, shutdown_type: ShutdownType) -> None: + self.shutdown_calls.append(shutdown_type) + + def on_timeout(self) -> None: + self.timeout_calls.append("timeout") + + def on_error(self, error: Exception) -> None: + self.error_calls.append(error) + + # Test with timeout enabled + handler = TestShutdownHandler(timeout_seconds=0.1, timeout_enabled=True) + + # Trigger shutdown + handler._handle_sigint(signal.SIGINT, None) + + # Wait for timeout to trigger + time.sleep(0.2) + + # Should have called timeout and os._exit + assert len(handler.timeout_calls) == 1 + mock_exit.assert_called_once_with(0) + + def test_timeout_disabled(self): + """Test that timeout doesn't trigger when disabled.""" + + class TestShutdownHandler(GracefulShutdownHandler): + def __init__(self, timeout_seconds: float, timeout_enabled: bool): + self.shutdown_calls = [] + self.timeout_calls = [] + self.error_calls = [] + super().__init__(timeout_seconds, timeout_enabled) + + def on_shutdown(self, shutdown_type: ShutdownType) -> None: + self.shutdown_calls.append(shutdown_type) + + def on_timeout(self) -> None: + self.timeout_calls.append("timeout") + + def on_error(self, error: Exception) -> None: + self.error_calls.append(error) + + # Test with timeout disabled + handler = TestShutdownHandler(timeout_seconds=0.1, timeout_enabled=False) + + # Trigger shutdown + handler._handle_sigint(signal.SIGINT, None) + + # Wait for potential timeout + time.sleep(0.2) + + # Should not have called timeout + assert len(handler.timeout_calls) == 0 + + def test_complete_shutdown(self): + """Test shutdown completion functionality.""" + + class TestShutdownHandler(GracefulShutdownHandler): + def __init__(self, timeout_seconds: float, timeout_enabled: bool): + self.shutdown_calls = [] + self.timeout_calls = [] + self.error_calls = [] + super().__init__(timeout_seconds, timeout_enabled) + + def on_shutdown(self, shutdown_type: ShutdownType) -> None: + self.shutdown_calls.append(shutdown_type) + + def on_timeout(self) -> None: + self.timeout_calls.append("timeout") + + def on_error(self, error: Exception) -> None: + self.error_calls.append(error) + + handler = TestShutdownHandler(timeout_seconds=10.0, timeout_enabled=True) + + # Trigger shutdown + handler._handle_sigint(signal.SIGINT, None) + + # Complete shutdown before timeout + handler.complete_shutdown() + + # Wait to ensure timeout doesn't trigger + time.sleep(0.1) + + # Should not have called timeout since shutdown was completed + assert len(handler.timeout_calls) == 0 + + def test_error_handling_in_signals(self): + """Test error handling in signal handlers.""" + + class TestShutdownHandler(GracefulShutdownHandler): + def __init__(self, timeout_seconds: float, timeout_enabled: bool): + self.shutdown_calls = [] + self.timeout_calls = [] + self.error_calls = [] + super().__init__(timeout_seconds, timeout_enabled) + + def on_shutdown(self, shutdown_type: ShutdownType) -> None: + raise RuntimeError("Test error in shutdown") + + def on_timeout(self) -> None: + self.timeout_calls.append("timeout") + + def on_error(self, error: Exception) -> None: + self.error_calls.append(error) + + handler = TestShutdownHandler(timeout_seconds=1.0, timeout_enabled=False) + + # Signal should trigger error handling + handler._handle_sigint(signal.SIGINT, None) + + assert len(handler.error_calls) == 1 + assert isinstance(handler.error_calls[0], RuntimeError) + + def test_shutdown_elapsed_time(self): + """Test shutdown elapsed time tracking.""" + + class TestShutdownHandler(GracefulShutdownHandler): + def __init__(self, timeout_seconds: float, timeout_enabled: bool): + super().__init__(timeout_seconds, timeout_enabled) + + def on_shutdown(self, shutdown_type: ShutdownType) -> None: + pass + + def on_timeout(self) -> None: + pass + + def on_error(self, error: Exception) -> None: + pass + + # Use a longer timeout to ensure timer starts + handler = TestShutdownHandler(timeout_seconds=30.0, timeout_enabled=True) + + # Before shutdown + assert handler.get_shutdown_elapsed_time() is None + + # Trigger shutdown - this should start the timer and set _shutdown_start_time + handler._handle_sigint(signal.SIGINT, None) + + # Small delay + time.sleep(0.1) + + # Should have elapsed time + elapsed = handler.get_shutdown_elapsed_time() + assert elapsed is not None + assert elapsed >= 0.1 + + def test_shutdown_types_enum(self): + """Test ShutdownType enum values.""" + assert ShutdownType.MANUAL is not None + assert ShutdownType.SIGTERM is not None + assert ShutdownType.TIMEOUT is not None + assert ShutdownType.ERROR is not None + assert ShutdownType.UNKNOWN is not None + + # Ensure all types are unique + types = [ShutdownType.MANUAL, ShutdownType.SIGTERM, ShutdownType.TIMEOUT, + ShutdownType.ERROR, ShutdownType.UNKNOWN] + assert len(set(types)) == len(types) + + @patch('os._exit') + def test_timeout_force_exit(self, mock_exit): + """Test that timeout eventually forces exit.""" + + class TestShutdownHandler(GracefulShutdownHandler): + def __init__(self, timeout_seconds: float, timeout_enabled: bool): + self.shutdown_calls = [] + self.timeout_calls = [] + self.error_calls = [] + super().__init__(timeout_seconds, timeout_enabled) + + def on_shutdown(self, shutdown_type: ShutdownType) -> None: + self.shutdown_calls.append(shutdown_type) + + def on_timeout(self) -> None: + self.timeout_calls.append("timeout") + + def on_error(self, error: Exception) -> None: + self.error_calls.append(error) + + handler = TestShutdownHandler(timeout_seconds=0.1, timeout_enabled=True) + + # Trigger shutdown + handler._handle_sigint(signal.SIGINT, None) + + # Wait for timeout to trigger + time.sleep(0.2) + + # Should have called os._exit + mock_exit.assert_called_once_with(0) \ No newline at end of file diff --git a/tests/test_middleware.py b/tests/test_middleware.py new file mode 100644 index 0000000..979a0d1 --- /dev/null +++ b/tests/test_middleware.py @@ -0,0 +1,886 @@ +from unittest.mock import AsyncMock, Mock, patch + +import pytest +from fastapi import FastAPI, Request, Response +from fastapi.testclient import TestClient +from starlette.middleware.base import BaseHTTPMiddleware + +from py_spring_core.core.entities.middlewares.middleware import Middleware +from py_spring_core.core.entities.middlewares.middleware_registry import ( + MiddlewareRegistry, + MiddlewareConfiguration, +) + + +class TestMiddleware: + """Test suite for the Middleware base class.""" + + @pytest.fixture + def mock_request(self): + """Fixture that provides a mock FastAPI request.""" + request = Mock(spec=Request) + request.method = "GET" + request.url = "http://test.com/api" + return request + + @pytest.fixture + def mock_call_next(self): + """Fixture that provides a mock call_next function.""" + return AsyncMock() + + def test_middleware_inherits_from_base_http_middleware(self): + """ + Test that Middleware class inherits from BaseHTTPMiddleware. + + This test verifies that: + 1. Middleware is a subclass of BaseHTTPMiddleware + 2. The inheritance relationship is correctly established + """ + assert issubclass(Middleware, BaseHTTPMiddleware) + + def test_middleware_is_abstract(self): + """ + Test that Middleware class is abstract and cannot be instantiated directly. + + This test verifies that: + 1. Middleware is an abstract base class + 2. Attempting to instantiate it directly raises an error + """ + # Test that Middleware is abstract by checking it has abstract methods + assert hasattr(Middleware, "process_request") + assert Middleware.process_request.__isabstractmethod__ + + def test_process_request_is_abstract(self): + """ + Test that process_request method is abstract and must be implemented. + + This test verifies that: + 1. process_request is an abstract method + 2. Subclasses must implement this method + """ + + # Create a concrete subclass without implementing process_request + class ConcreteMiddleware(Middleware): + pass + + # Test that the class is abstract by checking it has abstract methods + assert hasattr(ConcreteMiddleware, "process_request") + # The method should still be abstract since it wasn't implemented + assert ConcreteMiddleware.process_request.__isabstractmethod__ + + @pytest.mark.asyncio + async def test_dispatch_continues_when_process_request_returns_none( + self, mock_request, mock_call_next + ): + """ + Test that dispatch continues to next middleware when process_request returns None. + + This test verifies that: + 1. When process_request returns None, dispatch continues to call_next + 2. The call_next function is called with the correct request + 3. The response from call_next is returned + """ + expected_response = Response(content="test response", status_code=200) + mock_call_next.return_value = expected_response + + class TestMiddleware(Middleware): + async def process_request(self, request: Request) -> Response | None: + return None + + middleware = TestMiddleware(app=Mock()) + result = await middleware.dispatch(mock_request, mock_call_next) + + mock_call_next.assert_called_once_with(mock_request) + assert result == expected_response + + @pytest.mark.asyncio + async def test_dispatch_returns_response_when_process_request_returns_response( + self, mock_request, mock_call_next + ): + """ + Test that dispatch returns response directly when process_request returns a response. + + This test verifies that: + 1. When process_request returns a Response, dispatch returns it directly + 2. call_next is not called when process_request returns a response + 3. The response from process_request is returned unchanged + """ + middleware_response = Response(content="middleware response", status_code=403) + + class TestMiddleware(Middleware): + async def process_request(self, request: Request) -> Response | None: + return middleware_response + + middleware = TestMiddleware(app=Mock()) + result = await middleware.dispatch(mock_request, mock_call_next) + + mock_call_next.assert_not_called() + assert result == middleware_response + + @pytest.mark.asyncio + async def test_dispatch_passes_request_to_process_request( + self, mock_request, mock_call_next + ): + """ + Test that dispatch passes the request to process_request method. + + This test verifies that: + 1. The request object is correctly passed to process_request + 2. The process_request method receives the exact same request object + """ + received_request = None + + class TestMiddleware(Middleware): + async def process_request(self, request: Request) -> Response | None: + nonlocal received_request + received_request = request + return None + + middleware = TestMiddleware(app=Mock()) + await middleware.dispatch(mock_request, mock_call_next) + + assert received_request == mock_request + + def test_should_skip_default_returns_false(self, mock_request): + """ + Test that should_skip method returns False by default. + + This test verifies that: + 1. The default implementation of should_skip returns False + 2. This allows the middleware to process all requests by default + """ + class TestMiddleware(Middleware): + async def process_request(self, request: Request) -> Response | None: + return None + + middleware = TestMiddleware(app=Mock()) + result = middleware.should_skip(mock_request) + assert result is False + + def test_should_skip_can_be_overridden(self, mock_request): + """ + Test that should_skip method can be overridden in subclasses. + + This test verifies that: + 1. Subclasses can override should_skip to provide custom skip logic + 2. The overridden method is called with the correct request parameter + """ + class SkippingMiddleware(Middleware): + def should_skip(self, request: Request) -> bool: + return request.method == "GET" + + async def process_request(self, request: Request) -> Response | None: + return None + + middleware = SkippingMiddleware(app=Mock()) + result = middleware.should_skip(mock_request) + assert result is True + + @pytest.mark.asyncio + async def test_dispatch_skips_middleware_when_should_skip_returns_true( + self, mock_request, mock_call_next + ): + """ + Test that dispatch skips middleware processing when should_skip returns True. + + This test verifies that: + 1. When should_skip returns True, process_request is not called + 2. The request is passed directly to call_next + 3. The response from call_next is returned + """ + expected_response = Response(content="skipped response", status_code=200) + mock_call_next.return_value = expected_response + + class SkippingMiddleware(Middleware): + def should_skip(self, request: Request) -> bool: + return True + + async def process_request(self, request: Request) -> Response | None: + # This should never be called when should_skip returns True + raise AssertionError("process_request should not be called") + + middleware = SkippingMiddleware(app=Mock()) + result = await middleware.dispatch(mock_request, mock_call_next) + + # Verify call_next was called with the request + mock_call_next.assert_called_once_with(mock_request) + # Verify the response from call_next is returned + assert result == expected_response + + @pytest.mark.asyncio + async def test_dispatch_processes_middleware_when_should_skip_returns_false( + self, mock_request, mock_call_next + ): + """ + Test that dispatch processes middleware when should_skip returns False. + + This test verifies that: + 1. When should_skip returns False, process_request is called + 2. The middleware processing logic is executed + 3. The normal dispatch flow continues + """ + expected_response = Response(content="processed response", status_code=200) + mock_call_next.return_value = expected_response + + process_request_called = False + + class ProcessingMiddleware(Middleware): + def should_skip(self, request: Request) -> bool: + return False + + async def process_request(self, request: Request) -> Response | None: + nonlocal process_request_called + process_request_called = True + return None + + middleware = ProcessingMiddleware(app=Mock()) + result = await middleware.dispatch(mock_request, mock_call_next) + + # Verify process_request was called + assert process_request_called is True + # Verify call_next was called + mock_call_next.assert_called_once_with(mock_request) + # Verify the response from call_next is returned + assert result == expected_response + + def test_should_skip_receives_correct_request_parameter(self, mock_request): + """ + Test that should_skip method receives the correct request parameter. + + This test verifies that: + 1. The should_skip method receives the exact same request object + 2. The request parameter is passed correctly + """ + received_request = None + + class TestMiddleware(Middleware): + def should_skip(self, request: Request) -> bool: + nonlocal received_request + received_request = request + return False + + async def process_request(self, request: Request) -> Response | None: + return None + + middleware = TestMiddleware(app=Mock()) + middleware.should_skip(mock_request) + + assert received_request == mock_request + + @pytest.mark.asyncio + async def test_dispatch_with_conditional_skip_logic(self, mock_request, mock_call_next): + """ + Test dispatch with conditional skip logic based on request properties. + + This test verifies that: + 1. should_skip can use request properties to make skip decisions + 2. The skip logic works correctly in the dispatch flow + 3. Both skip and process paths work as expected + """ + expected_response = Response(content="test response", status_code=200) + mock_call_next.return_value = expected_response + + class ConditionalMiddleware(Middleware): + def should_skip(self, request: Request) -> bool: + # Skip GET requests, process others + return request.method == "GET" + + async def process_request(self, request: Request) -> Response | None: + # This should only be called for non-GET requests + return Response(content="processed", status_code=202) + + middleware = ConditionalMiddleware(app=Mock()) + + # Test with GET request (should skip) + mock_request.method = "GET" + result = await middleware.dispatch(mock_request, mock_call_next) + + assert result == expected_response + mock_call_next.assert_called_once_with(mock_request) + + # Reset mock for next test + mock_call_next.reset_mock() + mock_call_next.return_value = expected_response + + # Test with POST request (should process) + mock_request.method = "POST" + result = await middleware.dispatch(mock_request, mock_call_next) + + assert result.body == b"processed" + assert result.status_code == 202 + mock_call_next.assert_not_called() + + @pytest.mark.asyncio + async def test_dispatch_with_url_based_skip_logic(self, mock_request, mock_call_next): + """ + Test dispatch with URL-based skip logic. + + This test verifies that: + 1. should_skip can use request URL to make skip decisions + 2. URL-based filtering works correctly + 3. The middleware processes only relevant requests + """ + expected_response = Response(content="test response", status_code=200) + mock_call_next.return_value = expected_response + + class URLBasedMiddleware(Middleware): + def should_skip(self, request: Request) -> bool: + # Skip requests to /health endpoint + return str(request.url).endswith("/health") + + async def process_request(self, request: Request) -> Response | None: + return Response(content="processed", status_code=202) + + middleware = URLBasedMiddleware(app=Mock()) + + # Test with health endpoint (should skip) + mock_request.url = "http://test.com/health" + result = await middleware.dispatch(mock_request, mock_call_next) + + assert result == expected_response + mock_call_next.assert_called_once_with(mock_request) + + # Reset mock for next test + mock_call_next.reset_mock() + mock_call_next.return_value = expected_response + + # Test with other endpoint (should process) + mock_request.url = "http://test.com/api/users" + result = await middleware.dispatch(mock_request, mock_call_next) + + assert result.body == b"processed" + assert result.status_code == 202 + mock_call_next.assert_not_called() + + def test_should_skip_method_signature(self): + """ + Test that should_skip method has the correct signature. + + This test verifies that: + 1. should_skip is an instance method + 2. It takes a Request parameter + 3. It returns a boolean value + """ + class TestMiddleware(Middleware): + async def process_request(self, request: Request) -> Response | None: + return None + + middleware = TestMiddleware(app=Mock()) + + # Check that should_skip is a method + assert hasattr(middleware, 'should_skip') + assert callable(middleware.should_skip) + + # Check that it's an instance method (not a class method or static method) + import inspect + sig = inspect.signature(middleware.should_skip) + params = list(sig.parameters.keys()) + + # Should have 'request' parameter (self is automatically handled by Python) + assert params == ['request'] + + # Check return type annotation + assert sig.return_annotation == bool + + +class TestMiddlewareRegistry: + """Test suite for the MiddlewareRegistry concrete class.""" + + @pytest.fixture + def fastapi_app(self): + """Fixture that provides a fresh FastAPI application instance.""" + return FastAPI() + + @pytest.fixture + def registry(self): + """Fixture that provides a fresh MiddlewareRegistry instance.""" + return MiddlewareRegistry() + + @pytest.fixture + def test_middleware_1(self): + """Fixture that provides a test middleware class.""" + class TestMiddleware1(Middleware): + async def process_request(self, request: Request) -> Response | None: + return None + return TestMiddleware1 + + @pytest.fixture + def test_middleware_2(self): + """Fixture that provides another test middleware class.""" + class TestMiddleware2(Middleware): + async def process_request(self, request: Request) -> Response | None: + return None + return TestMiddleware2 + + def test_middleware_registry_instantiation(self): + """ + Test that MiddlewareRegistry can be instantiated directly. + + This test verifies that: + 1. MiddlewareRegistry is a concrete class + 2. It can be instantiated without errors + 3. Initial state is correct + """ + registry = MiddlewareRegistry() + assert isinstance(registry, MiddlewareRegistry) + assert registry.get_middleware_count() == 0 + assert registry.get_middleware_classes() == [] + + def test_add_middleware(self, registry, test_middleware_1): + """ + Test adding middleware to the registry. + + This test verifies that: + 1. Middleware can be added successfully + 2. Middleware count increases + 3. Middleware appears in the classes list + """ + registry.add_middleware(test_middleware_1) + + assert registry.get_middleware_count() == 1 + assert registry.has_middleware(test_middleware_1) + assert test_middleware_1 in registry.get_middleware_classes() + + def test_add_duplicate_middleware_raises_error(self, registry, test_middleware_1): + """ + Test that adding duplicate middleware raises an error. + + This test verifies that: + 1. Adding the same middleware twice raises ValueError + 2. The error message is descriptive + 3. The registry state remains unchanged + """ + registry.add_middleware(test_middleware_1) + + with pytest.raises(ValueError, match="Middleware TestMiddleware1 is already registered"): + registry.add_middleware(test_middleware_1) + + # Verify state hasn't changed + assert registry.get_middleware_count() == 1 + + def test_add_at_index(self, registry, test_middleware_1, test_middleware_2): + """ + Test inserting middleware at specific index. + + This test verifies that: + 1. Middleware can be inserted at specific positions + 2. Order is maintained correctly + 3. Index bounds are respected + """ + registry.add_middleware(test_middleware_1) + registry.add_at_index(0, test_middleware_2) + + classes = registry.get_middleware_classes() + assert classes[0] == test_middleware_2 + assert classes[1] == test_middleware_1 + + def test_add_at_invalid_index_raises_error(self, registry, test_middleware_1): + """ + Test that adding at invalid index raises an error. + + This test verifies that: + 1. Invalid indices raise ValueError + 2. Error message includes valid range + """ + with pytest.raises(ValueError, match="Index -1 is out of range"): + registry.add_at_index(-1, test_middleware_1) + + with pytest.raises(ValueError, match="Index 1 is out of range"): + registry.add_at_index(1, test_middleware_1) + + def test_add_before(self, registry, test_middleware_1, test_middleware_2): + """ + Test inserting middleware before another middleware. + + This test verifies that: + 1. Middleware can be inserted before target middleware + 2. Order is correct after insertion + """ + registry.add_middleware(test_middleware_1) + registry.add_before(test_middleware_1, test_middleware_2) + + classes = registry.get_middleware_classes() + assert classes[0] == test_middleware_2 + assert classes[1] == test_middleware_1 + + def test_add_before_nonexistent_target_raises_error(self, registry, test_middleware_1, test_middleware_2): + """ + Test that adding before nonexistent target raises error. + + This test verifies that: + 1. Adding before non-registered middleware raises ValueError + 2. Error message is descriptive + """ + with pytest.raises(ValueError, match="Target middleware TestMiddleware1 not found"): + registry.add_before(test_middleware_1, test_middleware_2) + + def test_add_after(self, registry, test_middleware_1, test_middleware_2): + """ + Test inserting middleware after another middleware. + + This test verifies that: + 1. Middleware can be inserted after target middleware + 2. Order is correct after insertion + """ + registry.add_middleware(test_middleware_1) + registry.add_after(test_middleware_1, test_middleware_2) + + classes = registry.get_middleware_classes() + assert classes[0] == test_middleware_1 + assert classes[1] == test_middleware_2 + + def test_remove_middleware(self, registry, test_middleware_1): + """ + Test removing middleware from the registry. + + This test verifies that: + 1. Middleware can be removed successfully + 2. Middleware count decreases + 3. Middleware no longer appears in classes list + """ + registry.add_middleware(test_middleware_1) + registry.remove_middleware(test_middleware_1) + + assert registry.get_middleware_count() == 0 + assert not registry.has_middleware(test_middleware_1) + assert test_middleware_1 not in registry.get_middleware_classes() + + def test_remove_nonexistent_middleware_raises_error(self, registry, test_middleware_1): + """ + Test that removing nonexistent middleware raises error. + + This test verifies that: + 1. Removing non-registered middleware raises ValueError + 2. Error message is descriptive + """ + with pytest.raises(ValueError, match="Middleware TestMiddleware1 not found"): + registry.remove_middleware(test_middleware_1) + + def test_clear_middlewares(self, registry, test_middleware_1, test_middleware_2): + """ + Test clearing all middlewares from the registry. + + This test verifies that: + 1. All middlewares are removed + 2. Registry returns to initial state + """ + registry.add_middleware(test_middleware_1) + registry.add_middleware(test_middleware_2) + + registry.clear_middlewares() + + assert registry.get_middleware_count() == 0 + assert registry.get_middleware_classes() == [] + + def test_get_middleware_index(self, registry, test_middleware_1, test_middleware_2): + """ + Test getting the index of a middleware. + + This test verifies that: + 1. Index of registered middleware is returned correctly + 2. Index reflects the actual position in the list + """ + registry.add_middleware(test_middleware_1) + registry.add_middleware(test_middleware_2) + + assert registry.get_middleware_index(test_middleware_1) == 0 + assert registry.get_middleware_index(test_middleware_2) == 1 + + def test_get_middleware_index_nonexistent_raises_error(self, registry, test_middleware_1): + """ + Test that getting index of nonexistent middleware raises error. + + This test verifies that: + 1. Getting index of non-registered middleware raises ValueError + 2. Error message is descriptive + """ + with pytest.raises(ValueError, match="Middleware TestMiddleware1 not found"): + registry.get_middleware_index(test_middleware_1) + + def test_get_middleware_classes_returns_copy(self, registry, test_middleware_1): + """ + Test that get_middleware_classes returns a copy. + + This test verifies that: + 1. Modifying returned list doesn't affect internal state + 2. A copy is returned, not the original list + """ + registry.add_middleware(test_middleware_1) + + classes = registry.get_middleware_classes() + classes.clear() + + # Original registry should be unchanged + assert registry.get_middleware_count() == 1 + assert registry.has_middleware(test_middleware_1) + + def test_apply_middlewares_adds_middleware_to_app(self, registry, fastapi_app, test_middleware_1, test_middleware_2): + """ + Test that apply_middlewares correctly adds middleware classes to FastAPI app. + + This test verifies that: + 1. Middleware classes are added to the FastAPI application + 2. The add_middleware method is called for each middleware class + 3. The app is returned unchanged + """ + registry.add_middleware(test_middleware_1) + registry.add_middleware(test_middleware_2) + + # Mock the add_middleware method + with patch.object(fastapi_app, "add_middleware") as mock_add_middleware: + result = registry.apply_middlewares(fastapi_app) + + # Verify add_middleware was called for each middleware class + assert mock_add_middleware.call_count == 2 + mock_add_middleware.assert_any_call(test_middleware_1) + mock_add_middleware.assert_any_call(test_middleware_2) + + # Verify the app is returned + assert result == fastapi_app + + def test_apply_middlewares_with_empty_list(self, registry, fastapi_app): + """ + Test that apply_middlewares handles empty middleware list correctly. + + This test verifies that: + 1. When no middlewares are registered, no middleware is added + 2. The app is returned unchanged + 3. No errors occur with empty middleware list + """ + with patch.object(fastapi_app, "add_middleware") as mock_add_middleware: + result = registry.apply_middlewares(fastapi_app) + + # Verify add_middleware was not called + mock_add_middleware.assert_not_called() + + # Verify the app is returned + assert result == fastapi_app + + def test_apply_middlewares_preserves_app_state(self, registry, fastapi_app, test_middleware_1): + """ + Test that apply_middlewares preserves the FastAPI app state. + + This test verifies that: + 1. The original app object is returned (same reference) + 2. No app properties are modified during middleware application + """ + registry.add_middleware(test_middleware_1) + + # Store original app state + original_app_id = id(fastapi_app) + + result = registry.apply_middlewares(fastapi_app) + + # Verify same app object is returned + assert id(result) == original_app_id + assert result is fastapi_app + + +class TestMiddlewareConfiguration: + """Test suite for the MiddlewareConfiguration class.""" + + def test_middleware_configuration_inheritance(self): + """ + Test that MiddlewareConfiguration has proper inheritance. + + This test verifies that: + 1. MiddlewareConfiguration inherits from SingleInheritanceRequired + 2. It can be instantiated + """ + class TestConfig(MiddlewareConfiguration): + pass + + config = TestConfig() + assert isinstance(config, MiddlewareConfiguration) + + def test_setup_middlewares_can_be_overridden(self): + """ + Test that setup_middlewares can be overridden to configure middlewares. + + This test verifies that: + 1. setup_middlewares can be overridden + 2. The registry is properly configured when overridden + """ + class TestMiddleware(Middleware): + async def process_request(self, request: Request) -> Response | None: + return None + + class TestConfig(MiddlewareConfiguration): + def setup_middlewares(self, registry: MiddlewareRegistry) -> None: + registry.add_middleware(TestMiddleware) + + config = TestConfig() + registry = MiddlewareRegistry() + + config.setup_middlewares(registry) + + assert registry.has_middleware(TestMiddleware) + assert registry.get_middleware_count() == 1 + + +class TestMiddlewareIntegration: + """Integration tests for middleware functionality.""" + + @pytest.fixture + def fastapi_app(self): + """Fixture that provides a fresh FastAPI application instance.""" + return FastAPI() + + @pytest.mark.asyncio + async def test_middleware_chain_execution(self, fastapi_app): + """ + Test that multiple middlewares execute in the correct order. + + This test verifies that: + 1. Middlewares are executed in the order they are added + 2. Each middleware can process the request + 3. The chain continues correctly when middlewares return None + """ + execution_order = [] + + class FirstMiddleware(Middleware): + async def process_request(self, request: Request) -> Response | None: + execution_order.append("first") + return None + + class SecondMiddleware(Middleware): + async def process_request(self, request: Request) -> Response | None: + execution_order.append("second") + return None + + registry = MiddlewareRegistry() + registry.add_middleware(FirstMiddleware) + registry.add_middleware(SecondMiddleware) + + app = registry.apply_middlewares(fastapi_app) + + # Create a test client to trigger middleware execution + from fastapi.testclient import TestClient + + @app.get("/test") + async def test_endpoint(): + return {"message": "test"} + + client = TestClient(app) + response = client.get("/test") + + # Verify middlewares were executed in order (FastAPI uses LIFO - Last In, First Out) + assert execution_order == ["second", "first"] + assert response.status_code == 200 + + @pytest.mark.asyncio + async def test_middleware_early_return(self, fastapi_app): + """ + Test that middleware can return early and prevent further execution. + + This test verifies that: + 1. When a middleware returns a response, subsequent middlewares are not executed + 2. The route handler is not called when middleware returns early + 3. The response from the middleware is returned to the client + """ + execution_order = [] + + class BlockingMiddleware(Middleware): + async def process_request(self, request: Request) -> Response | None: + execution_order.append("blocking") + return Response(content="blocked", status_code=403) + + class SecondMiddleware(Middleware): + async def process_request(self, request: Request) -> Response | None: + execution_order.append("second") + return None + + registry = MiddlewareRegistry() + registry.add_middleware(BlockingMiddleware) + registry.add_middleware(SecondMiddleware) + + app = registry.apply_middlewares(fastapi_app) + + @app.get("/test") + async def test_endpoint(): + execution_order.append("handler") + return {"message": "test"} + + client = TestClient(app) + response = client.get("/test") + + # Verify only blocking middleware executed (FastAPI uses LIFO - Last In, First Out) + # SecondMiddleware executes first, then BlockingMiddleware returns early + assert execution_order == ["second", "blocking"] + assert response.status_code == 403 + assert response.text == "blocked" + + def test_middleware_registry_with_configuration(self): + """ + Test using MiddlewareRegistry with MiddlewareConfiguration. + + This test verifies that: + 1. MiddlewareConfiguration can configure a MiddlewareRegistry + 2. The configuration is applied correctly + """ + class TestMiddleware(Middleware): + async def process_request(self, request: Request) -> Response | None: + return None + + class TestConfig(MiddlewareConfiguration): + def setup_middlewares(self, registry: MiddlewareRegistry) -> None: + registry.add_middleware(TestMiddleware) + + config = TestConfig() + registry = MiddlewareRegistry() + + config.setup_middlewares(registry) + + assert registry.has_middleware(TestMiddleware) + assert registry.get_middleware_count() == 1 + + def test_middleware_execution_order_with_skip_logic(self, fastapi_app): + """ + Test middleware execution order with skip logic. + + This test verifies that: + 1. Middlewares with skip logic are handled correctly + 2. Order is maintained even when some middlewares skip + """ + execution_order = [] + + class ConditionalMiddleware(Middleware): + def should_skip(self, request: Request) -> bool: + return "/skip" in str(request.url) + + async def process_request(self, request: Request) -> Response | None: + execution_order.append("conditional") + return None + + class AlwaysRunMiddleware(Middleware): + async def process_request(self, request: Request) -> Response | None: + execution_order.append("always") + return None + + registry = MiddlewareRegistry() + registry.add_middleware(ConditionalMiddleware) + registry.add_middleware(AlwaysRunMiddleware) + + app = registry.apply_middlewares(fastapi_app) + + @app.get("/test") + async def test_endpoint(): + return {"message": "test"} + + @app.get("/skip") + async def skip_endpoint(): + return {"message": "skip"} + + client = TestClient(app) + + # Test normal endpoint + execution_order.clear() + response = client.get("/test") + assert execution_order == ["always", "conditional"] + assert response.status_code == 200 + + # Test skip endpoint + execution_order.clear() + response = client.get("/skip") + assert execution_order == ["always"] # Only AlwaysRunMiddleware should execute + assert response.status_code == 200 diff --git a/tests/test_middleware_single_execution.py b/tests/test_middleware_single_execution.py new file mode 100644 index 0000000..5ab7697 --- /dev/null +++ b/tests/test_middleware_single_execution.py @@ -0,0 +1,291 @@ +import pytest +from unittest.mock import AsyncMock, Mock +from fastapi import FastAPI, Request, Response +from fastapi.testclient import TestClient +from collections import defaultdict +import threading +import time + +from py_spring_core.core.entities.middlewares.middleware import Middleware +from py_spring_core.core.entities.middlewares.middleware_registry import MiddlewareRegistry + + +class SingleExecutionMiddleware(Middleware): + """ + Test middleware that ensures it only processes each request once + by tracking processed requests and preventing duplicate processing + """ + + def __init__(self, app): + super().__init__(app) + self.processed_requests = set() + self.execution_count = defaultdict(int) + self._lock = threading.Lock() + + async def process_request(self, request: Request) -> Response | None: + # Create a unique identifier for this request + request_id = id(request) + + # Use thread lock to ensure thread safety + with self._lock: + # Check if this request has already been processed + if request_id in self.processed_requests: + # This request has already been processed, skip it + return None + + # Mark this request as processed + self.processed_requests.add(request_id) + self.execution_count[request_id] += 1 + + # Return None to continue processing + return None + + +class TestMiddlewareSingleExecution: + """Test suite for ensuring middleware executes only once per request.""" + + @pytest.fixture + def mock_request(self): + """Fixture that provides a mock FastAPI request.""" + request = Mock(spec=Request) + request.method = "GET" + request.url = "http://test.com/api" + return request + + @pytest.fixture + def mock_call_next(self): + """Fixture that provides a mock call_next function.""" + return AsyncMock() + + @pytest.mark.asyncio + async def test_middleware_executes_only_once_per_request(self, mock_request, mock_call_next): + """ + Test that middleware executes only once per request. + + This test verifies that: + 1. When dispatch is called multiple times with the same request, + process_request logic is only executed once + 2. The execution count for the request remains at 1 + 3. The request is only processed once + """ + expected_response = Response(content="test response", status_code=200) + mock_call_next.return_value = expected_response + + middleware = SingleExecutionMiddleware(app=Mock()) + + # Call dispatch multiple times with the same request + for _ in range(3): + result = await middleware.dispatch(mock_request, mock_call_next) + + # Verify that process_request logic was only executed once for this request + request_id = id(mock_request) + assert middleware.execution_count[request_id] == 1 + + # Verify that the request was added to processed set + assert request_id in middleware.processed_requests + + # Verify call_next was called the expected number of times + assert mock_call_next.call_count == 3 + + # Verify the response is correct + assert result == expected_response + + @pytest.mark.asyncio + async def test_middleware_executes_once_per_different_request(self, mock_call_next): + """ + Test that middleware executes once for each different request. + + This test verifies that: + 1. Each unique request is processed exactly once + 2. Different requests are tracked separately + 3. The execution count is correct for each request + """ + expected_response = Response(content="test response", status_code=200) + mock_call_next.return_value = expected_response + + middleware = SingleExecutionMiddleware(app=Mock()) + + # Create multiple different requests + requests = [] + for i in range(3): + request = Mock(spec=Request) + request.method = "GET" + request.url = f"http://test.com/api/{i}" + requests.append(request) + + # Process each request once + for request in requests: + await middleware.dispatch(request, mock_call_next) + + # Verify each request was processed exactly once + for request in requests: + request_id = id(request) + assert middleware.execution_count[request_id] == 1 + assert request_id in middleware.processed_requests + + # Verify total number of processed requests + assert len(middleware.processed_requests) == 3 + assert sum(middleware.execution_count.values()) == 3 + + @pytest.mark.asyncio + async def test_middleware_handles_early_return_correctly(self, mock_request, mock_call_next): + """ + Test that middleware handles early return correctly while maintaining single execution. + + This test verifies that: + 1. When middleware returns a response early, it still counts as executed once + 2. The request is tracked even when middleware returns early + 3. call_next is not called when middleware returns early + """ + middleware_response = Response(content="middleware response", status_code=403) + + class EarlyReturnMiddleware(SingleExecutionMiddleware): + async def process_request(self, request: Request) -> Response | None: + # Call parent to track execution + await super().process_request(request) + # Return early response + return middleware_response + + middleware = EarlyReturnMiddleware(app=Mock()) + + # Call dispatch + result = await middleware.dispatch(mock_request, mock_call_next) + + # Verify that process_request was called once + request_id = id(mock_request) + assert middleware.execution_count[request_id] == 1 + assert request_id in middleware.processed_requests + + # Verify call_next was not called (early return) + mock_call_next.assert_not_called() + + # Verify the early response is returned + assert result == middleware_response + + @pytest.mark.asyncio + async def test_middleware_with_skip_logic_maintains_single_execution(self, mock_request, mock_call_next): + """ + Test that middleware with skip logic maintains single execution tracking. + + This test verifies that: + 1. When should_skip returns True, the request is still tracked + 2. The execution count remains at 0 for skipped requests + 3. The request is still added to the processed set + """ + expected_response = Response(content="skipped response", status_code=200) + mock_call_next.return_value = expected_response + + class SkippingMiddleware(SingleExecutionMiddleware): + def should_skip(self, request: Request) -> bool: + return True + + async def process_request(self, request: Request) -> Response | None: + # This should never be called when should_skip returns True + raise AssertionError("process_request should not be called when should_skip returns True") + + middleware = SkippingMiddleware(app=Mock()) + + # Call dispatch + result = await middleware.dispatch(mock_request, mock_call_next) + + # Verify that process_request was not called (due to skip) + request_id = id(mock_request) + assert request_id not in middleware.processed_requests + + # Verify call_next was called + mock_call_next.assert_called_once_with(mock_request) + assert result == expected_response + + @pytest.mark.asyncio + async def test_middleware_thread_safety(self, mock_call_next): + """ + Test that middleware is thread-safe when processing multiple requests concurrently. + + This test verifies that: + 1. Multiple threads can safely access the middleware + 2. Each request is processed exactly once even under concurrent access + 3. No race conditions occur + """ + expected_response = Response(content="test response", status_code=200) + mock_call_next.return_value = expected_response + + middleware = SingleExecutionMiddleware(app=Mock()) + + # Create multiple requests + requests = [] + for i in range(10): + request = Mock(spec=Request) + request.method = "GET" + request.url = f"http://test.com/api/{i}" + requests.append(request) + + # Process requests concurrently + import asyncio + + async def process_request(request): + return await middleware.dispatch(request, mock_call_next) + + # Process all requests concurrently + results = await asyncio.gather(*[process_request(req) for req in requests]) + + # Verify each request was processed exactly once + for request in requests: + request_id = id(request) + assert middleware.execution_count[request_id] == 1 + assert request_id in middleware.processed_requests + + # Verify all responses are correct + for result in results: + assert result == expected_response + + # Verify total number of processed requests + assert len(middleware.processed_requests) == 10 + assert sum(middleware.execution_count.values()) == 10 + + @pytest.mark.asyncio + async def test_middleware_integration_with_fastapi(self): + """ + Test middleware single execution in a real FastAPI application. + + This test verifies that: + 1. Middleware executes only once per request in a real FastAPI app + 2. Multiple requests are handled correctly + 3. The tracking works across different endpoints + """ + app = FastAPI() + + class TestMiddleware(SingleExecutionMiddleware): + pass + + class TestRegistry(MiddlewareRegistry): + def get_middleware_classes(self) -> list[type[Middleware]]: + return [TestMiddleware] + + # Apply middleware to app + registry = TestRegistry() + app = registry.apply_middlewares(app) + + # Create test endpoints + @app.get("/test1") + async def test_endpoint1(): + return {"message": "test1"} + + @app.get("/test2") + async def test_endpoint2(): + return {"message": "test2"} + + # Create test client + client = TestClient(app) + + # Make multiple requests + response1 = client.get("/test1") + response2 = client.get("/test2") + response3 = client.get("/test1") # Same endpoint as first request + + # Verify responses + assert response1.status_code == 200 + assert response2.status_code == 200 + assert response3.status_code == 200 + + # The key point is that the middleware doesn't interfere with normal operation + # and each request is processed exactly once through the middleware chain \ No newline at end of file diff --git a/tests/test_module_import_cache.py b/tests/test_module_import_cache.py new file mode 100644 index 0000000..7a10127 --- /dev/null +++ b/tests/test_module_import_cache.py @@ -0,0 +1,166 @@ +import tempfile +import os +from pathlib import Path +from unittest.mock import patch + + +from py_spring_core.core.utils import dynamically_import_modules, clear_module_cache +from py_spring_core.commons.class_scanner import ClassScanner + + +class TestModuleImportCache: + """Test module import caching to prevent duplicate imports.""" + + def setup_method(self): + """Clear module cache before each test.""" + clear_module_cache() + + def test_dynamically_import_modules_cache(self): + """Test that dynamically_import_modules uses cache to prevent duplicate imports.""" + # Create a temporary Python file with a class + with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f: + f.write(""" +class TestClass: + def __init__(self): + self.value = "test" +""") + temp_file_path = f.name + + try: + # First import - get all classes without filtering + classes1 = dynamically_import_modules([temp_file_path], target_subclasses=[]) + + # Second import of the same file + classes2 = dynamically_import_modules([temp_file_path], target_subclasses=[]) + + # Both should return the same classes + assert len(classes1) == 1 + assert len(classes2) == 1 + assert classes1 == classes2 + + # The class should be the same object (not a duplicate) + class1 = list(classes1)[0] + class2 = list(classes2)[0] + assert class1 is class2 + + finally: + os.unlink(temp_file_path) + + def test_class_scanner_cache(self): + """Test that ClassScanner uses cache to prevent duplicate imports.""" + # Create a temporary Python file with a class + with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f: + f.write(""" +class TestClass: + def __init__(self): + self.value = "test" +""") + temp_file_path = f.name + + try: + scanner = ClassScanner([temp_file_path]) + + # First scan + scanner.scan_classes_for_file_paths([]) + classes1 = list(scanner.get_classes()) + + # Second scan + scanner.scan_classes_for_file_paths([]) + classes2 = list(scanner.get_classes()) + + # Both should return the same classes + assert len(classes1) == 1 + assert len(classes2) == 1 + assert classes1 == classes2 + + # The class should be the same object (not a duplicate) + assert classes1[0] is classes2[0] + + finally: + os.unlink(temp_file_path) + + def test_clear_module_cache(self): + """Test that clear_module_cache works correctly.""" + # Create a temporary Python file + with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f: + f.write(""" +class TestClass: + pass +""") + temp_file_path = f.name + + try: + # First import + classes1 = dynamically_import_modules([temp_file_path], target_subclasses=[]) + + # Clear cache + clear_module_cache() + + # Second import after clearing cache + classes2 = dynamically_import_modules([temp_file_path], target_subclasses=[]) + + # Classes should be different objects after clearing cache + class1 = list(classes1)[0] + class2 = list(classes2)[0] + assert class1 is not class2 + + finally: + os.unlink(temp_file_path) + + def test_class_scanner_clear_cache(self): + """Test that ClassScanner clear_module_cache works correctly.""" + # Create a temporary Python file + with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f: + f.write(""" +class TestClass: + pass +""") + temp_file_path = f.name + + try: + scanner = ClassScanner([temp_file_path]) + + # First scan + scanner.scan_classes_for_file_paths([]) + classes1 = list(scanner.get_classes()) + + # Clear cache + scanner.clear_module_cache() + + # Second scan after clearing cache + scanner.scan_classes_for_file_paths([]) + classes2 = list(scanner.get_classes()) + + # Classes should be different objects after clearing cache + assert classes1[0] is not classes2[0] + + finally: + os.unlink(temp_file_path) + + @patch('py_spring_core.commons.module_importer.logger') + def test_cache_logging(self, mock_logger): + """Test that cache usage is properly logged.""" + # Create a temporary Python file + with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f: + f.write(""" +class TestClass: + pass +""") + temp_file_path = f.name + + try: + # First import (should log import) + dynamically_import_modules([temp_file_path], target_subclasses=[]) + + # Second import (should log cache usage) + dynamically_import_modules([temp_file_path], target_subclasses=[]) + + # Check that debug log for cache usage was called + # Get the actual module name from the temp file + module_name = Path(temp_file_path).stem + mock_logger.debug.assert_called_with( + f"[MODULE CACHE] Using cached module: {module_name}" + ) + + finally: + os.unlink(temp_file_path) \ No newline at end of file diff --git a/tests/test_module_importer.py b/tests/test_module_importer.py new file mode 100644 index 0000000..8e50ac7 --- /dev/null +++ b/tests/test_module_importer.py @@ -0,0 +1,227 @@ +import tempfile +import os +from pathlib import Path +from typing import Type + +import pytest + +from py_spring_core.commons.module_importer import ModuleImporter + + +class TestModuleImporter: + """Test ModuleImporter class functionality.""" + + def setup_method(self): + """Clear module cache before each test.""" + self.importer = ModuleImporter() + + def test_import_module_from_path(self): + """Test importing a module from a file path.""" + # Create a temporary Python file + with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f: + f.write(""" +class TestClass: + def __init__(self): + self.value = "test" + +def test_function(): + return "test" +""") + temp_file_path = f.name + + try: + module = self.importer.import_module_from_path(temp_file_path) + + assert module is not None + assert hasattr(module, 'TestClass') + assert hasattr(module, 'test_function') + + # Test that the class can be instantiated + test_instance = module.TestClass() + assert test_instance.value == "test" + + finally: + os.unlink(temp_file_path) + + def test_import_module_from_path_caching(self): + """Test that modules are cached and reused.""" + # Create a temporary Python file + with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f: + f.write(""" +class TestClass: + pass +""") + temp_file_path = f.name + + try: + # First import + module1 = self.importer.import_module_from_path(temp_file_path) + cache_size1 = self.importer.get_cache_size() + + # Second import + module2 = self.importer.import_module_from_path(temp_file_path) + cache_size2 = self.importer.get_cache_size() + + # Should be the same module object + assert module1 is module2 + # Cache size should not increase + assert cache_size1 == cache_size2 + + finally: + os.unlink(temp_file_path) + + def test_extract_classes_from_module(self): + """Test extracting classes from a module.""" + # Create a temporary Python file + with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f: + f.write(""" +class ClassA: + pass + +class ClassB: + pass + +def function(): + pass + +CONSTANT = 42 +""") + temp_file_path = f.name + + try: + module = self.importer.import_module_from_path(temp_file_path) + classes = self.importer.extract_classes_from_module(module) + + # Should find 2 classes + assert len(classes) == 2 + + # Check class names + class_names = [cls.__name__ for cls in classes] + assert 'ClassA' in class_names + assert 'ClassB' in class_names + + finally: + os.unlink(temp_file_path) + + def test_import_classes_from_paths(self): + """Test importing classes from multiple file paths.""" + # Create temporary Python files + temp_files = [] + try: + for i in range(2): + with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f: + f.write(f""" +class Class{i}: + def __init__(self): + self.id = {i} +""") + temp_files.append(f.name) + + # Import classes from both files + classes = self.importer.import_classes_from_paths(temp_files) + + # Should find 2 classes + assert len(classes) == 2 + + # Check that classes can be instantiated + class_list = list(classes) + instance0 = class_list[0]() + instance1 = class_list[1]() + + assert hasattr(instance0, 'id') + assert hasattr(instance1, 'id') + + finally: + for temp_file in temp_files: + os.unlink(temp_file) + + def test_import_classes_from_paths_with_filtering(self): + """Test importing classes with target subclass filtering.""" + # Create a temporary Python file with different types of classes + with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f: + f.write(""" +class BaseClass: + pass + +class SubClass(BaseClass): + pass + +class UnrelatedClass: + pass +""") + temp_file_path = f.name + + try: + # First import to get the BaseClass + all_classes = self.importer.import_classes_from_paths([temp_file_path]) + base_class = None + for cls in all_classes: + if cls.__name__ == 'BaseClass': + base_class = cls + break + + assert base_class is not None + + # Import only subclasses of BaseClass + classes = self.importer.import_classes_from_paths( + [temp_file_path], + target_subclasses=[base_class] + ) + + # Should only find SubClass + assert len(classes) == 1 + class_list = list(classes) + assert class_list[0].__name__ == 'SubClass' + + finally: + os.unlink(temp_file_path) + + def test_clear_cache(self): + """Test clearing the module cache.""" + # Create a temporary Python file + with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f: + f.write(""" +class TestClass: + pass +""") + temp_file_path = f.name + + try: + # Import module + module1 = self.importer.import_module_from_path(temp_file_path) + assert self.importer.get_cache_size() == 1 + + # Clear cache + self.importer.clear_cache() + assert self.importer.get_cache_size() == 0 + + # Import again (should be a different object) + module2 = self.importer.import_module_from_path(temp_file_path) + assert module1 is not module2 + + finally: + os.unlink(temp_file_path) + + def test_import_nonexistent_file(self): + """Test importing a non-existent file.""" + with pytest.raises(FileNotFoundError): + self.importer.import_module_from_path("/nonexistent/file.py") + + def test_import_invalid_python_file(self): + """Test importing a file with invalid Python syntax.""" + with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f: + f.write(""" +class TestClass: + def __init__(self): + # Invalid syntax - missing colon + if True + pass +""") + temp_file_path = f.name + + try: + with pytest.raises(SyntaxError): + self.importer.import_module_from_path(temp_file_path) + + finally: + os.unlink(temp_file_path) \ No newline at end of file diff --git a/tests/test_rest_controller.py b/tests/test_rest_controller.py deleted file mode 100644 index c2c2669..0000000 --- a/tests/test_rest_controller.py +++ /dev/null @@ -1,29 +0,0 @@ -import pytest -from fastapi import APIRouter, FastAPI - -from py_spring_core.core.entities.controllers.rest_controller import RestController - - -class TestRestController: - @pytest.fixture - def app(self) -> FastAPI: - return FastAPI() - - @pytest.fixture - def router(self) -> APIRouter: - return APIRouter() - - @pytest.fixture - def test_controller(self, app: FastAPI, router: APIRouter) -> RestController: - class TestController(RestController): - def register_routes(self): - self.router.add_api_route("/test", lambda: "test") - - TestController.app = app - TestController.router = router - - return TestController() - - def test_register_routes_successfully(self, test_controller: RestController): - test_controller.register_routes() - assert len(test_controller.router.routes) == 1 diff --git a/tests/test_utils.py b/tests/test_utils.py deleted file mode 100644 index b11b749..0000000 --- a/tests/test_utils.py +++ /dev/null @@ -1,112 +0,0 @@ -import pytest -from py_spring_core.core.utils import ( - check_type_hints_for_callable, - TypeHintError, - check_type_hints_for_class, -) - - -class TestCheckTypeHintsForCallable: - def test_function_with_argument_type_hints_but_no_return_type(self): - def test_func(a: int, b: str): - pass # No return type hint - - with pytest.raises( - TypeHintError, - match="Type hints for 'return type' not provided for the function", - ): - check_type_hints_for_callable(test_func) - - def test_function_with_no_arguments(self): - def test_func() -> None: - pass - - # No exception should be raised for a function with no arguments - check_type_hints_for_callable(test_func) - - def test_function_with_correct_type_hints(self): - def test_func(a: int, b: str) -> None: - pass - - # No exception should be raised - check_type_hints_for_callable(test_func) - - def test_function_with_no_type_hints(self): - def test_func(a, b) -> None: - pass - - with pytest.raises(TypeHintError, match="Type hints not fully provided"): - check_type_hints_for_callable(test_func) - - def test_function_with_mismatched_type_hints(self): - def test_func(a: int, b, c) -> None: - pass - - # Intentionally using only one type hint - - with pytest.raises(TypeHintError, match="Type hints not fully provided"): - check_type_hints_for_callable(test_func) - - -class TestClassFullyTyped: - def method(self, a: int, b: str) -> bool: - return True - - -class TestClassNotReturnTyped: - def method(self, a: int, b: str): - pass # No return type hint - - -class TestClassWithNoArgs: - def method(self) -> bool: - return True - - -class TestClassNotFullyTyped: - def method(self, a, b: str) -> bool: - return True - - -class TestClassWithNotArsAndReturnTyped: - def method(self): - pass # No type hints at all - - -class TestClassWithVariableInside: - def method(self) -> None: - test = "" - pass # No type hints at all - - -class TestCheckTypeHintsForClass: - def test_class_with_method_and_return_type_hints(self): - # No exception should be raised - check_type_hints_for_class(TestClassFullyTyped) - - def test_class_with_method_missing_return_type_hint(self): - with pytest.raises( - TypeHintError, - match="Type hints for 'return type' not provided for the function", - ): - check_type_hints_for_class(TestClassNotReturnTyped) - - def test_class_with_method_missing_argument_type_hint(self): - with pytest.raises(TypeHintError, match="Type hints not fully provided"): - check_type_hints_for_class(TestClassNotFullyTyped) - - def test_class_with_method_having_no_arguments_but_return_type_hint(self): - # No exception should be raised - check_type_hints_for_class(TestClassWithNoArgs) - - def test_class_with_method_having_no_type_hints(self): - with pytest.raises( - TypeHintError, - match="Type hints for 'return type' not provided for the function", - ): - check_type_hints_for_class(TestClassWithNotArsAndReturnTyped) - - def test_class_with_method_having_variable_inside(self): - check_type_hints_for_class(TestClassWithVariableInside) - -