diff --git a/py_spring_model/__init__.py b/py_spring_model/__init__.py index 64c7753..5571a52 100644 --- a/py_spring_model/__init__.py +++ b/py_spring_model/__init__.py @@ -1,6 +1,21 @@ from py_spring_model.core.model import PySpringModel, Field +from py_spring_model.core.session_context_holder import SessionContextHolder from py_spring_model.py_spring_model_provider import provide_py_spring_model from py_spring_model.repository.crud_repository import CrudRepository from py_spring_model.repository.repository_base import RepositoryBase from py_spring_model.py_spring_model_rest.service.curd_repository_implementation_service.crud_repository_implementation_service import SkipAutoImplmentation -from py_spring_model.py_spring_model_rest.service.query_service.query import Query \ No newline at end of file +from py_spring_model.py_spring_model_rest.service.query_service.query import Query + + +__all__ = [ + "PySpringModel", + "Field", + "SessionContextHolder", + "provide_py_spring_model", + "CrudRepository", + "RepositoryBase", + "SkipAutoImplmentation", + "Query", +] + +__version__ = "0.1.0" \ No newline at end of file diff --git a/py_spring_model/core/model.py b/py_spring_model/core/model.py index 9f7987f..d5166d2 100644 --- a/py_spring_model/core/model.py +++ b/py_spring_model/core/model.py @@ -1,6 +1,5 @@ import contextlib -from typing_extensions import Self -from typing import ClassVar, Iterator, Optional, Type +from typing import ClassVar, Iterator, Optional, Type, TypeVar from loguru import logger from sqlalchemy import Engine, MetaData from sqlalchemy.engine.base import Connection @@ -8,6 +7,8 @@ from py_spring_model.core.py_spring_session import PySpringSession +T = TypeVar("T", bound="PySpringModel") + class PySpringModel(SQLModel): """ @@ -81,7 +82,7 @@ def get_model_lookup(cls) -> dict[str, type["PySpringModel"]]: raise ValueError("[MODEL_LOOKUP NOT SET] Model lookup is not set") return {str(_model.__tablename__): _model for _model in cls._models} - def clone(self) -> Self: + def clone(self: T) -> T: return self.model_validate_json(self.model_dump_json()) @classmethod @@ -105,11 +106,7 @@ def create_managed_session(cls, should_commit: bool = True) -> Iterator[PySpring logger.debug("[MANAGED SESSION COMMIT] Session committing...") if should_commit: session.commit() - logger.debug( - "[MANAGED SESSION COMMIT] Session committed, refreshing instances..." - ) - session.refresh_current_session_instances() - logger.success("[MANAGED SESSION COMMIT] Session committed.") + logger.success("[MANAGED SESSION COMMIT] Session committed.") except Exception as error: logger.error(error) logger.error("[MANAGED SESSION ROLLBACK] Session rolling back...") diff --git a/py_spring_model/core/py_spring_session.py b/py_spring_model/core/py_spring_session.py index 7a97dfe..7080445 100644 --- a/py_spring_model/core/py_spring_session.py +++ b/py_spring_model/core/py_spring_session.py @@ -1,5 +1,6 @@ from typing import Iterable +from loguru import logger from sqlmodel import Session, SQLModel @@ -33,3 +34,12 @@ def add_all(self, instances: Iterable[SQLModel]) -> None: def refresh_current_session_instances(self) -> None: for instance in self.current_session_instance: self.refresh(instance) + + def commit(self) -> None: + # Import here to avoid circular import + from py_spring_model.core.session_context_holder import SessionContextHolder + if SessionContextHolder.is_transaction_managed(): + logger.warning("Commiting a transaction that is currently being managed by the outermost transaction is strongly discouraged...") + return + super().commit() + self.refresh_current_session_instances() \ No newline at end of file diff --git a/py_spring_model/core/session_context_holder.py b/py_spring_model/core/session_context_holder.py new file mode 100644 index 0000000..1f7eafc --- /dev/null +++ b/py_spring_model/core/session_context_holder.py @@ -0,0 +1,141 @@ +from contextvars import ContextVar +from enum import IntEnum +from functools import wraps +from typing import Any, Callable, ClassVar, Optional, ParamSpec, TypeVar + +from py_spring_model.core.model import PySpringModel +from py_spring_model.core.py_spring_session import PySpringSession + +class TransactionalDepth(IntEnum): + OUTERMOST = 1 + ON_EXIT = 0 + + +P = ParamSpec("P") +RT = TypeVar("RT") + +def Transactional(func: Callable[P, RT]) -> Callable[P, RT]: + """ + Decorator for managing database transactions in a nested-safe manner. + + This decorator ensures that: + - A new session is created only if there is no active session (i.e., outermost transaction). + - The session is committed, rolled back, and closed only by the outermost function. + - Nested transactional functions share the same session and do not interfere with the commit/rollback behavior. + + Behavior Summary: + - If this function is the outermost @Transactional in the call stack: + - A new session is created. + - On success, the session is committed. + - On failure, the session is rolled back. + - The session is closed after execution. + - If this function is called within an existing transaction: + - The existing session is reused. + - No commit, rollback, or close is performed (delegated to the outermost function). + + Example: + @Transactional + def outer_operation(): + create_user() + update_account() + + @Transactional + def create_user(): + db.session.add(User(...)) # Uses same session as outer_operation + + @Transactional + def update_account(): + db.session.add(Account(...)) # Uses same session as outer_operation + + Only outer_operation will commit or rollback. + If create_user() or update_account() raises an exception, + the whole transaction will be rolled back. + """ + @wraps(func) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> RT: + # Increment session depth and get session + session_depth = SessionContextHolder.enter_session() + session = SessionContextHolder.get_or_create_session() + try: + result = func(*args, **kwargs) + # Only commit at the outermost level (session_depth == 1) + if session_depth == TransactionalDepth.OUTERMOST.value: + session.commit() + return result + except Exception as error: + # Only rollback at the outermost level (session_depth == 1) + if session_depth == TransactionalDepth.OUTERMOST.value: + session.rollback() + raise error + finally: + # Decrement depth and clean up session if needed + SessionContextHolder.exit_session() + return wrapper + +class SessionContextHolder: + """ + A context holder for the session with explicit depth tracking. + This is used to store the session in a context variable so that it can be accessed by the query service. + The depth counter ensures that only the outermost transaction manages commit/rollback operations. + """ + _session: ClassVar[ContextVar[Optional[PySpringSession]]] = ContextVar("session", default=None) + _session_depth: ClassVar[ContextVar[int]] = ContextVar("session_depth", default=0) + + @classmethod + def get_or_create_session(cls) -> PySpringSession: + optional_session = cls._session.get() + if optional_session is None: + session = PySpringModel.create_session() + cls._session.set(session) + return session + return optional_session + + @classmethod + def has_session(cls) -> bool: + return cls._session.get() is not None + + @classmethod + def get_session_depth(cls) -> int: + """Get the current session depth.""" + return cls._session_depth.get() + + @classmethod + def enter_session(cls) -> int: + """ + Enter a new session context and increment the depth counter. + Returns the new depth level. + """ + current_depth = cls._session_depth.get() + new_depth = current_depth + 1 + cls._session_depth.set(new_depth) + return new_depth + + @classmethod + def exit_session(cls) -> int: + """ + Exit the current session context and decrement the depth counter. + If depth reaches 0, clear the session. + Returns the new depth level. + """ + current_depth = cls._session_depth.get() + new_depth = max(0, current_depth - 1) # Prevent negative depth + cls._session_depth.set(new_depth) + + # Clear session only when depth reaches 0 (outermost level) + if new_depth == TransactionalDepth.ON_EXIT.value: + cls.clear_session() + + return new_depth + + @classmethod + def clear_session(cls): + """Clear the session and reset depth to 0.""" + session = cls._session.get() + if session is not None: + session.close() + cls._session.set(None) + cls._session_depth.set(TransactionalDepth.ON_EXIT.value) + + @classmethod + def is_transaction_managed(cls) -> bool: + return cls._session_depth.get() > TransactionalDepth.OUTERMOST.value \ No newline at end of file diff --git a/py_spring_model/py_spring_model_provider.py b/py_spring_model/py_spring_model_provider.py index 9c31d33..2c874f2 100644 --- a/py_spring_model/py_spring_model_provider.py +++ b/py_spring_model/py_spring_model_provider.py @@ -10,6 +10,7 @@ from py_spring_model.core.commons import ApplicationFileGroups, PySpringModelProperties from py_spring_model.core.model import PySpringModel +from py_spring_model.py_spring_model_rest.controller.session_controller import SessionController from py_spring_model.repository.repository_base import RepositoryBase from py_spring_model.py_spring_model_rest import PySpringModelRestService from py_spring_model.py_spring_model_rest.controller.py_spring_model_rest_controller import ( @@ -149,6 +150,7 @@ def provide_py_spring_model() -> EntityProvider: return PySpringModelProvider( rest_controller_classes=[ # PySpringModelRestController + SessionController ], component_classes=[ PySpringModelRestService, diff --git a/py_spring_model/py_spring_model_rest/controller/session_controller.py b/py_spring_model/py_spring_model_rest/controller/session_controller.py new file mode 100644 index 0000000..925cd2f --- /dev/null +++ b/py_spring_model/py_spring_model_rest/controller/session_controller.py @@ -0,0 +1,38 @@ +from typing import Awaitable, Callable +from fastapi import Request, Response +from py_spring_model.core.session_context_holder import SessionContextHolder +from py_spring_core import RestController + + +async def session_middleware(request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response: + """ + Middleware to ensure that the database session is properly cleaned up after each HTTP request. + + This middleware works with context-based session management using ContextVar. + It guarantees that each request has its own isolated database session context, and + that any session stored in the context is properly closed after the request is handled. + + It does NOT create or commit any transactions by itself. It is meant to be used + in combination with a decorator-based transaction manager (e.g., @Transactional) + which controls when to commit or rollback. + + This middleware acts as a safety net: + - Ensures that the ContextVar-based session is cleared after each request. + - Prevents session leakage between requests in case of unexpected exceptions or unhandled paths. + - Complements nested transaction logic by guaranteeing session cleanup at request boundaries. + + Use this middleware when: + - You are using context-local session handling (via contextvars). + - You want to ensure that long-lived or leaked sessions don't accumulate. + """ + try: + response = await call_next(request) + return response + finally: + SessionContextHolder.clear_session() + +class SessionController(RestController): + def post_construct(self) -> None: + self.app.middleware("http")(session_middleware) + + \ No newline at end of file diff --git a/py_spring_model/py_spring_model_rest/service/curd_repository_implementation_service/crud_repository_implementation_service.py b/py_spring_model/py_spring_model_rest/service/curd_repository_implementation_service/crud_repository_implementation_service.py index efcfd65..45a03d4 100644 --- a/py_spring_model/py_spring_model_rest/service/curd_repository_implementation_service/crud_repository_implementation_service.py +++ b/py_spring_model/py_spring_model_rest/service/curd_repository_implementation_service/crud_repository_implementation_service.py @@ -19,6 +19,7 @@ from sqlmodel.sql.expression import SelectOfScalar from py_spring_model.core.model import PySpringModel +from py_spring_model.core.session_context_holder import SessionContextHolder, Transactional from py_spring_model.repository.crud_repository import CrudRepository from py_spring_model.py_spring_model_rest.service.curd_repository_implementation_service.method_query_builder import ( _MetodQueryBuilder, @@ -150,14 +151,15 @@ def _get_sql_statement( query = query.where(filter_condition_stack.pop()) return query + @Transactional def _session_execute(self, statement: SelectOfScalar, is_one_result: bool) -> Any: - with PySpringModel.create_session() as session: - logger.debug(f"Executing query: \n{str(statement)}") - result = ( - session.exec(statement).first() - if is_one_result - else session.exec(statement).fetchall() - ) + session = SessionContextHolder.get_or_create_session() + logger.debug(f"Executing query: \n{str(statement)}") + result = ( + session.exec(statement).first() + if is_one_result + else session.exec(statement).fetchall() + ) return result def post_construct(self) -> None: diff --git a/py_spring_model/py_spring_model_rest/service/py_spring_model_rest_service.py b/py_spring_model/py_spring_model_rest/service/py_spring_model_rest_service.py index e619af2..2f81d85 100644 --- a/py_spring_model/py_spring_model_rest/service/py_spring_model_rest_service.py +++ b/py_spring_model/py_spring_model_rest/service/py_spring_model_rest_service.py @@ -4,6 +4,7 @@ from py_spring_core import Component from py_spring_model import PySpringModel +from py_spring_model.core.session_context_holder import SessionContextHolder, Transactional ID = TypeVar("ID", int, UUID) ModelT = TypeVar("ModelT", bound=PySpringModel) @@ -20,44 +21,48 @@ class PySpringModelRestService(Component): 5. Updating an existing model, 6. Deleting a model by ID. """ - def get_all_models(self) -> dict[str, type[PySpringModel]]: return PySpringModel.get_model_lookup() + @Transactional def get(self, model_type: Type[ModelT], id: ID) -> Optional[ModelT]: - with PySpringModel.create_managed_session() as session: - return session.get(model_type, id) # type: ignore + session = SessionContextHolder.get_or_create_session() + return session.get(model_type, id) # type: ignore + @Transactional def get_all_by_ids(self, model_type: Type[ModelT], ids: list[ID]) -> list[ModelT]: - with PySpringModel.create_managed_session() as session: - return session.query(model_type).filter(model_type.id.in_(ids)).all() # type: ignore - + session = SessionContextHolder.get_or_create_session() + return session.query(model_type).filter(model_type.id.in_(ids)).all() # type: ignore + @Transactional def get_all( self, model_type: Type[ModelT], limit: int, offset: int ) -> list[ModelT]: - with PySpringModel.create_managed_session() as session: - return session.query(model_type).offset(offset).limit(limit).all() + session = SessionContextHolder.get_or_create_session() + return session.query(model_type).offset(offset).limit(limit).all() + @Transactional def create(self, model: ModelT) -> ModelT: - with PySpringModel.create_managed_session() as session: - session.add(model) - return model + session = SessionContextHolder.get_or_create_session() + session.add(model) + return model + @Transactional def update(self, id: ID, model: ModelT) -> Optional[ModelT]: - with PySpringModel.create_managed_session() as session: - model_type = type(model) - primary_keys = PySpringModel.get_primary_key_columns(model_type) - optional_model = session.get(model_type, id) # type: ignore - if optional_model is None: - return - - for key, value in model.model_dump().items(): - if key in primary_keys: - continue - setattr(optional_model, key, value) - session.add(optional_model) + session = SessionContextHolder.get_or_create_session() + model_type = type(model) + primary_keys = PySpringModel.get_primary_key_columns(model_type) + optional_model = session.get(model_type, id) # type: ignore + if optional_model is None: + return + + for key, value in model.model_dump().items(): + if key in primary_keys: + continue + setattr(optional_model, key, value) + session.add(optional_model) + @Transactional def delete(self, model_type: Type[ModelT], id: ID) -> None: - with PySpringModel.create_managed_session() as session: - session.query(model_type).filter(model_type.id == id).delete() # type: ignore - session.commit() + session = SessionContextHolder.get_or_create_session() + session.query(model_type).filter(model_type.id == id).delete() # type: ignore + session.commit() diff --git a/py_spring_model/repository/crud_repository.py b/py_spring_model/repository/crud_repository.py index 5aeae0c..131492f 100644 --- a/py_spring_model/repository/crud_repository.py +++ b/py_spring_model/repository/crud_repository.py @@ -16,6 +16,7 @@ from sqlmodel.sql.expression import Select, SelectOfScalar from py_spring_model.core.model import PySpringModel +from py_spring_model.core.session_context_holder import SessionContextHolder, Transactional from py_spring_model.repository.repository_base import RepositoryBase T = TypeVar("T", bound=PySpringModel) @@ -58,123 +59,135 @@ def __init__(self) -> None: def _get_model_id_type_with_class(cls) -> tuple[Type[ID], Type[T]]: return get_args(tp=cls.__mro__[0].__orig_bases__[0]) + @Transactional def _find_by_statement( self, statement: Union[Select, SelectOfScalar], - session: Optional[Session] = None, - ) -> tuple[Session, Optional[T]]: - session = session or self._create_session() + ) -> Optional[T]: + session = SessionContextHolder.get_or_create_session() - return session, session.exec(statement).first() + return session.exec(statement).first() + @Transactional def _find_by_query( self, query_by: dict[str, Any], - session: Optional[Session] = None, - ) -> tuple[Session, Optional[T]]: - session = session or self._create_session() + ) -> Optional[T]: + session = SessionContextHolder.get_or_create_session() statement = select(self.model_class).filter_by(**query_by) - return session, session.exec(statement).first() + return session.exec(statement).first() + + @Transactional def _find_all_by_query( self, query_by: dict[str, Any], - session: Optional[Session] = None, ) -> tuple[Session, list[T]]: - session = session or self._create_session() + session = SessionContextHolder.get_or_create_session() statement = select(self.model_class).filter_by(**query_by) return session, list(session.exec(statement).fetchall()) + @Transactional def _find_all_by_statement( self, statement: Union[Select, SelectOfScalar], - session: Optional[Session] = None, - ) -> tuple[Session, list[T]]: - session = session or self._create_session() - return session, list(session.exec(statement).fetchall()) + ) -> list[T]: + session = SessionContextHolder.get_or_create_session() + return list(session.exec(statement).fetchall()) + @Transactional def find_by_id(self, id: ID) -> Optional[T]: - with self.create_managed_session() as session: - statement = select(self.model_class).where(self.model_class.id == id) # type: ignore - optional_entity = session.exec(statement).first() - if optional_entity is None: - return - - return optional_entity.clone() # type: ignore - + session = SessionContextHolder.get_or_create_session() + statement = select(self.model_class).where(self.model_class.id == id) # type: ignore + optional_entity = session.exec(statement).first() + if optional_entity is None: + return + + return optional_entity + + @Transactional def find_all_by_ids(self, ids: list[ID]) -> list[T]: - with self.create_managed_session() as session: - statement = select(self.model_class).where(self.model_class.id.in_(ids)) # type: ignore - return [entity.clone() for entity in session.exec(statement).all()] # type: ignore + session = SessionContextHolder.get_or_create_session() + statement = select(self.model_class).where(self.model_class.id.in_(ids)) # type: ignore + return [entity for entity in session.exec(statement).all()] # type: ignore + @Transactional def find_all(self) -> list[T]: - with self.create_managed_session() as session: - statement = select(self.model_class) # type: ignore - return [entity.clone() for entity in session.exec(statement).all()] # type: ignore + session = SessionContextHolder.get_or_create_session() + statement = select(self.model_class) # type: ignore + return [entity for entity in session.exec(statement).all()] # type: ignore + @Transactional def save(self, entity: T) -> T: - with self.create_managed_session() as session: - session.add(entity) - return entity.clone() # type: ignore + session = SessionContextHolder.get_or_create_session() + session.add(entity) + return entity + @Transactional def save_all( self, entities: Iterable[T], ) -> bool: - with self.create_managed_session() as session: - session.add_all(entities) + session = SessionContextHolder.get_or_create_session() + session.add_all(entities) return True + @Transactional def delete(self, entity: T) -> bool: - with self.create_managed_session() as session: - _, optional_intance = self._find_by_query(entity.model_dump(), session) - if optional_intance is None: - return False - session.delete(optional_intance) + session = SessionContextHolder.get_or_create_session() + optional_instance = self._find_by_query(entity.model_dump()) + if optional_instance is None: + return False + session.delete(optional_instance) return True + @Transactional def delete_all(self, entities: Iterable[T]) -> bool: - with self.create_managed_session() as session: - ids = [entity.id for entity in entities] # type: ignore - - statement = select(self.model_class).where(self.model_class.id.in_(ids)) # type: ignore - _, deleted_entities = self._find_all_by_statement(statement, session) - if deleted_entities is None: - return False - - for entity in deleted_entities: - session.delete(entity) + session = SessionContextHolder.get_or_create_session() + ids = [entity.id for entity in entities] # type: ignore + + statement = select(self.model_class).where(self.model_class.id.in_(ids)) # type: ignore + deleted_entities = self._find_all_by_statement(statement) + if deleted_entities is None: + return False + + for entity in deleted_entities: + session.delete(entity) return True + + @Transactional def delete_by_id(self, _id: ID) -> bool: - with self.create_managed_session() as session: - _, entity = self._find_by_query({"id": _id}, session) - if entity is None: - return False - session.delete(entity) + session = SessionContextHolder.get_or_create_session() + entity = self._find_by_query({"id": _id}) + if entity is None: + return False + session.delete(entity) return True + @Transactional def delete_all_by_ids(self, ids: list[ID]) -> bool: - with self.create_managed_session() as session: - statement = select(self.model_class).where(self.model_class.id.in_(ids)) # type: ignore - _, deleted_entities = self._find_all_by_statement(statement, session) - if deleted_entities is None: - return False - for entity in deleted_entities: - session.delete(entity) - return True + session = SessionContextHolder.get_or_create_session() + statement = select(self.model_class).where(self.model_class.id.in_(ids)) # type: ignore + deleted_entities = self._find_all_by_statement(statement) + if deleted_entities is None: + return False + for entity in deleted_entities: + session.delete(entity) + return True + @Transactional def upsert(self, entity: T, query_by: dict[str, Any]) -> T: - with self.create_managed_session() as session: - statement = select(self.model_class).filter_by(**query_by) # type: ignore - _, existing_entity = self._find_by_statement(statement, session) - if existing_entity is not None: - # If the entity exists, update its attributes - for key, value in entity.model_dump().items(): - setattr(existing_entity, key, value) - session.add(existing_entity) - else: - # If the entity does not exist, insert it - session.add(entity) - return entity + session = SessionContextHolder.get_or_create_session() + statement = select(self.model_class).filter_by(**query_by) # type: ignore + existing_entity = self._find_by_statement(statement) + if existing_entity is not None: + # If the entity exists, update its attributes + for key, value in entity.model_dump().items(): + setattr(existing_entity, key, value) + session.add(existing_entity) + else: + # If the entity does not exist, insert it + session.add(entity) + return entity diff --git a/tests/test_crud_repository.py b/tests/test_crud_repository.py index c2e2af1..17d7234 100644 --- a/tests/test_crud_repository.py +++ b/tests/test_crud_repository.py @@ -4,6 +4,7 @@ from sqlmodel import Field, SQLModel from py_spring_model import PySpringModel +from py_spring_model.core.session_context_holder import SessionContextHolder from py_spring_model.repository.crud_repository import CrudRepository class User(PySpringModel, table=True): @@ -19,11 +20,13 @@ def setup_method(self): logger.info("Setting up test environment...") self.engine = create_engine("sqlite:///:memory:", echo=True) PySpringModel._engine = self.engine + SessionContextHolder.clear_session() SQLModel.metadata.create_all(self.engine) def teardown_method(self): logger.info("Tearing down test environment...") SQLModel.metadata.drop_all(self.engine) + SessionContextHolder.clear_session() @pytest.fixture def user_repository(self): @@ -57,12 +60,12 @@ def test_find_all(self, user_repository: UserRepository): def test_find_by_query(self, user_repository: UserRepository): self.create_test_user(user_repository) - _, user = user_repository._find_by_query({"name": "John Doe"}) + user = user_repository._find_by_query({"name": "John Doe"}) assert user is not None assert user.id == 1 assert user.name == "John Doe" - _, email_user = user_repository._find_by_query({"email": "john@example.com"}) + email_user = user_repository._find_by_query({"email": "john@example.com"}) assert email_user is not None assert email_user.id == 1 assert email_user.email == "john@example.com" @@ -132,4 +135,26 @@ def test_upsert_for_new_user(self, user_repository: UserRepository): new_user = user_repository.find_by_id(1) assert new_user is not None assert new_user.name == "John Doe" - assert new_user.email == "john@example.com" \ No newline at end of file + assert new_user.email == "john@example.com" + + def test_update_user(self, user_repository: UserRepository): + self.create_test_user(user_repository) + user = user_repository.find_by_id(1) + assert user is not None + user.name = "William Chen" + user.email = "william.chen@example.com" + user_repository.save(user) + updated_user = user_repository.find_by_id(1) + assert updated_user is not None + assert updated_user.name == "William Chen" + assert updated_user.email == "william.chen@example.com" + + def test_delete_user_with_user_found(self, user_repository: UserRepository): + self.create_test_user(user_repository) + user = user_repository.find_by_id(1) + assert user is not None + assert user_repository.delete(user) + assert user_repository.find_by_id(1) is None + + def test_delete_user_with_user_not_found(self, user_repository: UserRepository): + assert user_repository.delete(User(id=1, name="John Doe", email="john@example.com")) is False \ No newline at end of file diff --git a/tests/test_crud_repository_implementation_service.py b/tests/test_crud_repository_implementation_service.py index 77f5428..1be059c 100644 --- a/tests/test_crud_repository_implementation_service.py +++ b/tests/test_crud_repository_implementation_service.py @@ -6,6 +6,7 @@ from sqlalchemy import create_engine from sqlmodel import SQLModel from py_spring_model import PySpringModel, Field, CrudRepository, Query +from py_spring_model.core.session_context_holder import SessionContextHolder from py_spring_model.py_spring_model_rest.service.curd_repository_implementation_service.crud_repository_implementation_service import CrudRepositoryImplementationService from py_spring_model.py_spring_model_rest.service.curd_repository_implementation_service.method_query_builder import _MetodQueryBuilder @@ -32,11 +33,13 @@ def setup_method(self): logger.info("Setting up test environment...") self.engine = create_engine("sqlite:///:memory:", echo=True) PySpringModel._engine = self.engine + SessionContextHolder.clear_session() SQLModel.metadata.create_all(self.engine) def teardown_method(self): logger.info("Tearing down test environment...") SQLModel.metadata.drop_all(self.engine) + SessionContextHolder.clear_session() @pytest.fixture def user_repository(self): diff --git a/tests/test_query_modifying_operations.py b/tests/test_query_modifying_operations.py index f06a422..40416c2 100644 --- a/tests/test_query_modifying_operations.py +++ b/tests/test_query_modifying_operations.py @@ -6,6 +6,7 @@ from unittest.mock import patch, MagicMock from py_spring_model import PySpringModel +from py_spring_model.core.session_context_holder import SessionContextHolder from py_spring_model.py_spring_model_rest.service.query_service.query import Query, QueryExecutionService from py_spring_model.repository.crud_repository import CrudRepository @@ -88,14 +89,17 @@ def setup_method(self): PySpringModel.set_engine(self.engine) PySpringModel.set_metadata(SQLModel.metadata) PySpringModel.set_models([TestUser]) + SessionContextHolder.clear_session() SQLModel.metadata.create_all(self.engine) self.repository = TestUserRepository() + self.repository.insert_user_with_commit(name="John Doe", email="john@example.com", age=30) def teardown_method(self): """Clean up test environment""" logger.info("Tearing down test environment...") SQLModel.metadata.drop_all(self.engine) + SessionContextHolder.clear_session() def test_insert_with_commit_true(self): """Test INSERT operation with is_modifying=True (should commit)""" diff --git a/tests/test_session_depth.py b/tests/test_session_depth.py new file mode 100644 index 0000000..bb4c687 --- /dev/null +++ b/tests/test_session_depth.py @@ -0,0 +1,112 @@ +import pytest +from py_spring_model.core.session_context_holder import SessionContextHolder, Transactional +from py_spring_model.core.model import PySpringModel + + +class TestSessionDepth: + """Test the explicit session depth tracking functionality""" + + def setup_method(self): + """Clean up any existing sessions before each test""" + SessionContextHolder.clear_session() + + def teardown_method(self): + """Clean up after each test""" + SessionContextHolder.clear_session() + + def test_session_depth_starts_at_zero(self): + """Test that session depth starts at 0""" + assert SessionContextHolder.get_session_depth() == 0 + + def test_session_depth_increments_and_decrements(self): + """Test that session depth properly increments and decrements""" + # Initially 0 + assert SessionContextHolder.get_session_depth() == 0 + + # Enter first level + depth1 = SessionContextHolder.enter_session() + assert depth1 == 1 + assert SessionContextHolder.get_session_depth() == 1 + + # Enter second level + depth2 = SessionContextHolder.enter_session() + assert depth2 == 2 + assert SessionContextHolder.get_session_depth() == 2 + + # Exit second level + depth_after_exit1 = SessionContextHolder.exit_session() + assert depth_after_exit1 == 1 + assert SessionContextHolder.get_session_depth() == 1 + + # Exit first level + depth_after_exit2 = SessionContextHolder.exit_session() + assert depth_after_exit2 == 0 + assert SessionContextHolder.get_session_depth() == 0 + + def test_session_cleared_only_at_outermost_level(self): + """Test that session is only cleared when depth reaches 0""" + # Enter first level and create session + SessionContextHolder.enter_session() + session = SessionContextHolder.get_or_create_session() + assert SessionContextHolder.has_session() + + # Enter second level - session should still exist + SessionContextHolder.enter_session() + assert SessionContextHolder.has_session() + assert SessionContextHolder.get_session_depth() == 2 + + # Exit second level - session should still exist + SessionContextHolder.exit_session() + assert SessionContextHolder.has_session() + assert SessionContextHolder.get_session_depth() == 1 + + # Exit first level - session should be cleared + SessionContextHolder.exit_session() + assert not SessionContextHolder.has_session() + assert SessionContextHolder.get_session_depth() == 0 + + def test_clear_session_resets_depth(self): + """Test that clear_session() resets the depth to 0""" + SessionContextHolder.enter_session() + SessionContextHolder.enter_session() + assert SessionContextHolder.get_session_depth() == 2 + + SessionContextHolder.clear_session() + assert SessionContextHolder.get_session_depth() == 0 + + @pytest.mark.parametrize("nesting_levels", [1, 2, 3, 5]) + def test_transactional_depth_tracking(self, nesting_levels): + """Test that @Transactional properly tracks depth at various nesting levels""" + depth_records = [] + + def create_nested_function(level: int): + @Transactional + def nested_func(): + current_depth = SessionContextHolder.get_session_depth() + depth_records.append(current_depth) + if level > 1: + create_nested_function(level - 1)() + return nested_func + + # Create and execute nested function + outermost_func = create_nested_function(nesting_levels) + outermost_func() + + # Verify depth progression + assert len(depth_records) == nesting_levels + for i, recorded_depth in enumerate(depth_records): + expected_depth = i + 1 + assert recorded_depth == expected_depth, f"At level {i+1}, expected depth {expected_depth}, got {recorded_depth}" + + # Verify session is cleaned up + assert SessionContextHolder.get_session_depth() == 0 + assert not SessionContextHolder.has_session() + + def test_depth_prevents_negative_values(self): + """Test that depth counter prevents going below 0""" + assert SessionContextHolder.get_session_depth() == 0 + + # Try to exit when already at 0 + new_depth = SessionContextHolder.exit_session() + assert new_depth == 0 + assert SessionContextHolder.get_session_depth() == 0 \ No newline at end of file diff --git a/tests/test_transactional_decorator.py b/tests/test_transactional_decorator.py new file mode 100644 index 0000000..e219cf4 --- /dev/null +++ b/tests/test_transactional_decorator.py @@ -0,0 +1,356 @@ +import pytest +from unittest.mock import patch +from sqlalchemy import create_engine, text +from sqlmodel import Field, SQLModel + +from py_spring_model import PySpringModel +from py_spring_model.core.session_context_holder import SessionContextHolder, Transactional + + +class TransactionalTestUser(PySpringModel, table=True): + """Test model for transactional operations""" + id: int = Field(default=None, primary_key=True) + name: str + email: str + age: int = Field(default=0) + + +class TestTransactionalDecorator: + """Test suite for the @Transactional decorator""" + + @pytest.fixture(autouse=True) + def setup_and_teardown(self): + """Set up test environment with in-memory SQLite database""" + self.engine = create_engine("sqlite:///:memory:", echo=False) + PySpringModel.set_engine(self.engine) + PySpringModel.set_metadata(SQLModel.metadata) + PySpringModel.set_models([TransactionalTestUser]) + + # Clear any existing session + SessionContextHolder.clear_session() + + SQLModel.metadata.create_all(self.engine) + + def teardown_method(self): + """Tear down test environment""" + SQLModel.metadata.drop_all(self.engine) + SessionContextHolder.clear_session() + PySpringModel._engine = None + PySpringModel._metadata = None + PySpringModel._connection = None + + def test_single_transactional_success(self): + """Test that a single @Transactional function commits successfully""" + + @Transactional + def create_user(): + session = SessionContextHolder.get_or_create_session() + user = TransactionalTestUser(name="John Doe", email="john@example.com", age=30) + session.add(user) + session.flush() # To get the ID + return user + + # Execute the transactional function + result = create_user() + + # Verify the user was created and committed + assert result.name == "John Doe" + assert result.email == "john@example.com" + assert result.age == 30 + + # Verify session is cleared after transaction + assert not SessionContextHolder.has_session() + + # Verify data persisted to database + with PySpringModel.create_managed_session() as session: + users = session.execute(text("SELECT * FROM transactionaltestuser")).fetchall() + assert len(users) == 1 + assert users[0].name == "John Doe" + + def test_single_transactional_rollback(self): + """Test that a single @Transactional function rolls back on exception""" + + @Transactional + def create_user_with_error(): + session = SessionContextHolder.get_or_create_session() + user = TransactionalTestUser(name="Jane Doe", email="jane@example.com", age=25) + session.add(user) + session.flush() + raise ValueError("Simulated error") + + # Execute the transactional function and expect exception + with pytest.raises(ValueError, match="Simulated error"): + create_user_with_error() + + # Verify session is cleared after rollback + assert not SessionContextHolder.has_session() + + # Verify no data persisted to database + with PySpringModel.create_managed_session() as session: + users = session.execute(text("SELECT * FROM transactionaltestuser")).fetchall() + assert len(users) == 0 + + def test_nested_transactional_success(self): + """Test that nested @Transactional functions share the same session and commit at top level""" + + @Transactional + def create_user(name: str, email: str): + session = SessionContextHolder.get_or_create_session() + user = TransactionalTestUser(name=name, email=email, age=30) + session.add(user) + session.flush() + return user + + @Transactional + def update_user_age(user_id: int, new_age: int): + session = SessionContextHolder.get_or_create_session() + session.execute(text(f"UPDATE transactionaltestuser SET age = {new_age} WHERE id = {user_id}")) + + @Transactional + def create_and_update_user(): + # This should share the same session with nested calls + user = create_user("Alice Smith", "alice@example.com") + update_user_age(user.id, 35) + return user + + # Execute the outer transactional function + result = create_and_update_user() + + # Verify session is cleared after transaction + assert not SessionContextHolder.has_session() + + # Verify both operations were committed + with PySpringModel.create_managed_session() as session: + users = session.execute(text("SELECT * FROM transactionaltestuser")).fetchall() + assert len(users) == 1 + assert users[0].name == "Alice Smith" + assert users[0].age == 35 # Updated by nested function + + def test_nested_transactional_rollback_from_inner(self): + """Test that exception in nested @Transactional causes rollback at top level""" + + @Transactional + def create_user(name: str, email: str): + session = SessionContextHolder.get_or_create_session() + user = TransactionalTestUser(name=name, email=email, age=30) + session.add(user) + session.flush() + return user + + @Transactional + def update_user_with_error(user_id: int): + session = SessionContextHolder.get_or_create_session() + session.execute(text(f"UPDATE transactionaltestuser SET age = 40 WHERE id = {user_id}")) + raise RuntimeError("Update failed") + + @Transactional + def create_and_update_user_with_error(): + user = create_user("Bob Johnson", "bob@example.com") + update_user_with_error(user.id) # This will raise an exception + return user + + # Execute and expect exception + with pytest.raises(RuntimeError, match="Update failed"): + create_and_update_user_with_error() + + # Verify session is cleared after rollback + assert not SessionContextHolder.has_session() + + # Verify no data persisted (everything rolled back) + with PySpringModel.create_managed_session() as session: + users = session.execute(text("SELECT * FROM transactionaltestuser")).fetchall() + assert len(users) == 0 + + def test_nested_transactional_rollback_from_outer(self): + """Test that exception in outer @Transactional causes rollback after nested calls""" + + @Transactional + def create_user(name: str, email: str): + session = SessionContextHolder.get_or_create_session() + user = TransactionalTestUser(name=name, email=email, age=30) + session.add(user) + session.flush() + return user + + @Transactional + def update_user_age(user_id: int, new_age: int): + session = SessionContextHolder.get_or_create_session() + session.execute(text(f"UPDATE transactionaltestuser SET age = {new_age} WHERE id = {user_id}")) + + @Transactional + def create_update_and_fail(): + user = create_user("Charlie Brown", "charlie@example.com") + update_user_age(user.id, 45) + # Both nested operations succeeded, but outer fails + raise Exception("Outer operation failed") + + # Execute and expect exception + with pytest.raises(Exception, match="Outer operation failed"): + create_update_and_fail() + + # Verify session is cleared after rollback + assert not SessionContextHolder.has_session() + + # Verify no data persisted (everything rolled back) + with PySpringModel.create_managed_session() as session: + users = session.execute(text("SELECT * FROM transactionaltestuser")).fetchall() + assert len(users) == 0 + + def test_session_sharing_across_nested_transactions(self): + """Test that nested @Transactional functions share the same session instance""" + + captured_sessions = [] + + @Transactional + def inner_function(): + session = SessionContextHolder.get_or_create_session() + captured_sessions.append(session) + user = TransactionalTestUser(name="Inner User", email="inner@example.com", age=25) + session.add(user) + session.flush() + return session + + @Transactional + def middle_function(): + session = SessionContextHolder.get_or_create_session() + captured_sessions.append(session) + inner_session = inner_function() + return session, inner_session + + @Transactional + def outer_function(): + session = SessionContextHolder.get_or_create_session() + captured_sessions.append(session) + middle_session, inner_session = middle_function() + return session, middle_session, inner_session + + # Execute nested transactions + outer_session, middle_session, inner_session = outer_function() + + # Verify all functions used the same session instance + assert len(captured_sessions) == 3 + assert captured_sessions[0] is captured_sessions[1] # outer and middle + assert captured_sessions[1] is captured_sessions[2] # middle and inner + assert outer_session is middle_session is inner_session + + # Verify session is cleared after transaction + assert not SessionContextHolder.has_session() + + def test_transactional_session_context_isolation(self): + """Test that @Transactional properly isolates session context""" + + @Transactional + def first_transaction(): + session = SessionContextHolder.get_or_create_session() + user1 = TransactionalTestUser(name="User 1", email="user1@example.com", age=30) + session.add(user1) + session.flush() + return user1.id + + @Transactional + def second_transaction(): + session = SessionContextHolder.get_or_create_session() + user2 = TransactionalTestUser(name="User 2", email="user2@example.com", age=35) + session.add(user2) + session.flush() + return user2.id + + # Execute separate transactions + first_transaction() + second_transaction() + + # Verify sessions were properly isolated + assert not SessionContextHolder.has_session() + + # Verify both transactions committed separately + with PySpringModel.create_managed_session() as session: + users = session.execute(text("SELECT * FROM transactionaltestuser ORDER BY id")).fetchall() + assert len(users) == 2 + assert users[0].name == "User 1" + assert users[1].name == "User 2" + + def test_transactional_preserves_function_metadata(self): + """Test that @Transactional preserves original function metadata""" + + @Transactional + def documented_function(param1: str, param2: int = 10) -> str: + """This is a documented function with parameters.""" + return f"{param1}_{param2}" + + # Verify function metadata is preserved + assert documented_function.__name__ == "documented_function" + assert documented_function.__doc__ is not None and "documented function" in documented_function.__doc__ + + # Verify function still works correctly + result = documented_function("test", 20) + assert result == "test_20" + + def test_transactional_commit_rollback_behavior(self): + """Test the core commit/rollback behavior of nested transactions""" + + commit_calls = [] + rollback_calls = [] + + # Mock session to track commit/rollback calls + original_session_class = PySpringModel.create_session + + def mock_create_session(): + session = original_session_class() + original_commit = session.commit + original_rollback = session.rollback + + def mock_commit(): + commit_calls.append("commit") + return original_commit() + + def mock_rollback(): + rollback_calls.append("rollback") + return original_rollback() + + session.commit = mock_commit + session.rollback = mock_rollback + return session + + @Transactional + def inner_operation(): + session = SessionContextHolder.get_or_create_session() + user = TransactionalTestUser(name="Test", email="test@example.com", age=30) + session.add(user) + session.flush() + + @Transactional + def outer_operation(): + inner_operation() + + # Test successful nested transaction + with patch.object(PySpringModel, 'create_session', side_effect=mock_create_session): + outer_operation() + + # Only the outermost transaction should commit + assert len(commit_calls) == 1 + assert len(rollback_calls) == 0 + + # Reset counters + commit_calls.clear() + rollback_calls.clear() + + @Transactional + def inner_operation_with_error(): + session = SessionContextHolder.get_or_create_session() + user = TransactionalTestUser(name="Test2", email="test2@example.com", age=25) + session.add(user) + session.flush() + raise ValueError("Inner error") + + @Transactional + def outer_operation_with_error(): + inner_operation_with_error() + + # Test failed nested transaction + with patch.object(PySpringModel, 'create_session', side_effect=mock_create_session): + with pytest.raises(ValueError): + outer_operation_with_error() + + # Only the outermost transaction should rollback + assert len(commit_calls) == 0 + assert len(rollback_calls) == 1 \ No newline at end of file