From a7f4a65820f72d57f1d058bb5cc34efc2c49fab1 Mon Sep 17 00:00:00 2001 From: William Chen Date: Thu, 19 Jun 2025 21:45:01 +0800 Subject: [PATCH 01/15] Implement Session Management and Transactional Decorator for CRUD Operations - Introduced `SessionContextHolder` to manage SQLAlchemy sessions using context variables. - Added `Transactional` decorator to ensure session commits and rollbacks during CRUD operations. - Refactored `CrudRepository` methods to utilize the new session management, enhancing code clarity and reducing session handling boilerplate. - Updated tests to reflect changes in the repository methods, ensuring proper functionality without session parameters. --- py_spring_model/__init__.py | 1 + .../core/session_context_holder.py | 50 ++++++ py_spring_model/repository/crud_repository.py | 149 +++++++++--------- tests/test_crud_repository.py | 4 +- 4 files changed, 130 insertions(+), 74 deletions(-) create mode 100644 py_spring_model/core/session_context_holder.py diff --git a/py_spring_model/__init__.py b/py_spring_model/__init__.py index 64c7753..b14de83 100644 --- a/py_spring_model/__init__.py +++ b/py_spring_model/__init__.py @@ -1,4 +1,5 @@ from py_spring_model.core.model import PySpringModel, Field +from py_spring_model.core.session_context_holder import SessionContextHolder from py_spring_model.py_spring_model_provider import provide_py_spring_model from py_spring_model.repository.crud_repository import CrudRepository from py_spring_model.repository.repository_base import RepositoryBase diff --git a/py_spring_model/core/session_context_holder.py b/py_spring_model/core/session_context_holder.py new file mode 100644 index 0000000..2b96421 --- /dev/null +++ b/py_spring_model/core/session_context_holder.py @@ -0,0 +1,50 @@ +from contextvars import ContextVar +from functools import wraps +from typing import ClassVar, Optional +from sqlmodel import Session + +from py_spring_model.core.model import PySpringModel + +def Transactional(func): + """ + A decorator that wraps a function and commits the session if the function is successful. + If the function raises an exception, the session is rolled back. + The session is then closed. + """ + @wraps(func) + def wrapper(*args, **kwargs): + session = SessionContextHolder.get_or_create_session() + try: + result = func(*args, **kwargs) + session.commit() + return result + except Exception: + session.rollback() + raise + finally: + SessionContextHolder.clear_session() + return wrapper + +class SessionContextHolder: + """ + A context holder for the session. + This is used to store the session in a context variable so that it can be accessed by the query service. + This is useful for the query service to be able to access the session without having to pass it in as an argument. + This is also useful for the query service to be able to access the session without having to pass it in as an argument. + """ + _session: ClassVar[ContextVar[Optional[Session]]] = ContextVar("session", default=None) + @classmethod + def get_or_create_session(cls) -> Session: + optional_session = cls._session.get() + if optional_session is None: + session = PySpringModel.create_session() + cls._session.set(session) + return session + return optional_session + @classmethod + def clear_session(cls): + optional_session = cls._session.get() + if optional_session is None: + return + optional_session.close() + cls._session.set(None) \ No newline at end of file diff --git a/py_spring_model/repository/crud_repository.py b/py_spring_model/repository/crud_repository.py index 5aeae0c..c439934 100644 --- a/py_spring_model/repository/crud_repository.py +++ b/py_spring_model/repository/crud_repository.py @@ -16,6 +16,7 @@ from sqlmodel.sql.expression import Select, SelectOfScalar from py_spring_model.core.model import PySpringModel +from py_spring_model.core.session_context_holder import SessionContextHolder, Transactional from py_spring_model.repository.repository_base import RepositoryBase T = TypeVar("T", bound=PySpringModel) @@ -61,120 +62,124 @@ def _get_model_id_type_with_class(cls) -> tuple[Type[ID], Type[T]]: def _find_by_statement( self, statement: Union[Select, SelectOfScalar], - session: Optional[Session] = None, - ) -> tuple[Session, Optional[T]]: - session = session or self._create_session() + ) -> Optional[T]: + session = SessionContextHolder.get_or_create_session() + + return session.exec(statement).first() - return session, session.exec(statement).first() def _find_by_query( self, query_by: dict[str, Any], - session: Optional[Session] = None, - ) -> tuple[Session, Optional[T]]: - session = session or self._create_session() + ) -> Optional[T]: + session = SessionContextHolder.get_or_create_session() statement = select(self.model_class).filter_by(**query_by) - return session, session.exec(statement).first() + return session.exec(statement).first() def _find_all_by_query( self, query_by: dict[str, Any], - session: Optional[Session] = None, ) -> tuple[Session, list[T]]: - session = session or self._create_session() + session = SessionContextHolder.get_or_create_session() statement = select(self.model_class).filter_by(**query_by) return session, list(session.exec(statement).fetchall()) def _find_all_by_statement( self, statement: Union[Select, SelectOfScalar], - session: Optional[Session] = None, - ) -> tuple[Session, list[T]]: - session = session or self._create_session() - return session, list(session.exec(statement).fetchall()) + ) -> list[T]: + session = SessionContextHolder.get_or_create_session() + return list(session.exec(statement).fetchall()) def find_by_id(self, id: ID) -> Optional[T]: - with self.create_managed_session() as session: - statement = select(self.model_class).where(self.model_class.id == id) # type: ignore - optional_entity = session.exec(statement).first() - if optional_entity is None: - return - - return optional_entity.clone() # type: ignore + session = SessionContextHolder.get_or_create_session() + statement = select(self.model_class).where(self.model_class.id == id) # type: ignore + optional_entity = session.exec(statement).first() + if optional_entity is None: + return + return optional_entity def find_all_by_ids(self, ids: list[ID]) -> list[T]: - with self.create_managed_session() as session: - statement = select(self.model_class).where(self.model_class.id.in_(ids)) # type: ignore - return [entity.clone() for entity in session.exec(statement).all()] # type: ignore + session = SessionContextHolder.get_or_create_session() + statement = select(self.model_class).where(self.model_class.id.in_(ids)) # type: ignore + return [entity for entity in session.exec(statement).all()] # type: ignore def find_all(self) -> list[T]: - with self.create_managed_session() as session: - statement = select(self.model_class) # type: ignore - return [entity.clone() for entity in session.exec(statement).all()] # type: ignore + session = SessionContextHolder.get_or_create_session() + statement = select(self.model_class) # type: ignore + return [entity for entity in session.exec(statement).all()] # type: ignore + @Transactional def save(self, entity: T) -> T: - with self.create_managed_session() as session: - session.add(entity) - return entity.clone() # type: ignore + session = SessionContextHolder.get_or_create_session() + session.add(entity) + return entity + @Transactional def save_all( self, entities: Iterable[T], ) -> bool: - with self.create_managed_session() as session: - session.add_all(entities) + session = SessionContextHolder.get_or_create_session() + session.add_all(entities) return True + @Transactional def delete(self, entity: T) -> bool: - with self.create_managed_session() as session: - _, optional_intance = self._find_by_query(entity.model_dump(), session) - if optional_intance is None: - return False - session.delete(optional_intance) + session = SessionContextHolder.get_or_create_session() + optional_intance = self._find_by_query(entity.model_dump()) + if optional_intance is None: + return False + session.delete(optional_intance) return True + @Transactional def delete_all(self, entities: Iterable[T]) -> bool: - with self.create_managed_session() as session: - ids = [entity.id for entity in entities] # type: ignore - - statement = select(self.model_class).where(self.model_class.id.in_(ids)) # type: ignore - _, deleted_entities = self._find_all_by_statement(statement, session) - if deleted_entities is None: - return False - - for entity in deleted_entities: - session.delete(entity) + session = SessionContextHolder.get_or_create_session() + ids = [entity.id for entity in entities] # type: ignore + + statement = select(self.model_class).where(self.model_class.id.in_(ids)) # type: ignore + deleted_entities = self._find_all_by_statement(statement) + if deleted_entities is None: + return False + + for entity in deleted_entities: + session.delete(entity) return True + + @Transactional def delete_by_id(self, _id: ID) -> bool: - with self.create_managed_session() as session: - _, entity = self._find_by_query({"id": _id}, session) - if entity is None: - return False - session.delete(entity) + session = SessionContextHolder.get_or_create_session() + entity = self._find_by_query({"id": _id}) + if entity is None: + return False + session.delete(entity) return True + @Transactional def delete_all_by_ids(self, ids: list[ID]) -> bool: - with self.create_managed_session() as session: - statement = select(self.model_class).where(self.model_class.id.in_(ids)) # type: ignore - _, deleted_entities = self._find_all_by_statement(statement, session) - if deleted_entities is None: - return False - for entity in deleted_entities: - session.delete(entity) - return True + session = SessionContextHolder.get_or_create_session() + statement = select(self.model_class).where(self.model_class.id.in_(ids)) # type: ignore + deleted_entities = self._find_all_by_statement(statement) + if deleted_entities is None: + return False + for entity in deleted_entities: + session.delete(entity) + return True + @Transactional def upsert(self, entity: T, query_by: dict[str, Any]) -> T: - with self.create_managed_session() as session: - statement = select(self.model_class).filter_by(**query_by) # type: ignore - _, existing_entity = self._find_by_statement(statement, session) - if existing_entity is not None: - # If the entity exists, update its attributes - for key, value in entity.model_dump().items(): - setattr(existing_entity, key, value) - session.add(existing_entity) - else: - # If the entity does not exist, insert it - session.add(entity) - return entity + session = SessionContextHolder.get_or_create_session() + statement = select(self.model_class).filter_by(**query_by) # type: ignore + existing_entity = self._find_by_statement(statement) + if existing_entity is not None: + # If the entity exists, update its attributes + for key, value in entity.model_dump().items(): + setattr(existing_entity, key, value) + session.add(existing_entity) + else: + # If the entity does not exist, insert it + session.add(entity) + return entity diff --git a/tests/test_crud_repository.py b/tests/test_crud_repository.py index c2e2af1..2f751c6 100644 --- a/tests/test_crud_repository.py +++ b/tests/test_crud_repository.py @@ -57,12 +57,12 @@ def test_find_all(self, user_repository: UserRepository): def test_find_by_query(self, user_repository: UserRepository): self.create_test_user(user_repository) - _, user = user_repository._find_by_query({"name": "John Doe"}) + user = user_repository._find_by_query({"name": "John Doe"}) assert user is not None assert user.id == 1 assert user.name == "John Doe" - _, email_user = user_repository._find_by_query({"email": "john@example.com"}) + email_user = user_repository._find_by_query({"email": "john@example.com"}) assert email_user is not None assert email_user.id == 1 assert email_user.email == "john@example.com" From 708906da6c16a48ad1a9dfaf799bc19364ddfe21 Mon Sep 17 00:00:00 2001 From: William Chen Date: Thu, 19 Jun 2025 22:03:35 +0800 Subject: [PATCH 02/15] Refactor CRUD Operations to Utilize SessionContextHolder for Session Management - Updated CRUD methods in `PySpringModelRestService` and `CrudRepositoryImplementationService` to use `SessionContextHolder` for session handling, replacing the previous context manager approach. - Added `@Transactional` decorator to relevant methods to ensure proper transaction management. - Enhanced error handling in the `Transactional` decorator to raise specific exceptions. - Updated tests to clear session state before and after tests to maintain isolation and prevent side effects. --- .../core/session_context_holder.py | 4 +- py_spring_model/py_spring_model_provider.py | 2 +- .../crud_repository_implementation_service.py | 16 +++--- .../service/py_spring_model_rest_service.py | 57 ++++++++++--------- tests/test_crud_repository.py | 3 + ..._crud_repository_implementation_service.py | 3 + tests/test_query_modifying_operations.py | 4 ++ 7 files changed, 53 insertions(+), 36 deletions(-) diff --git a/py_spring_model/core/session_context_holder.py b/py_spring_model/core/session_context_holder.py index 2b96421..2d3b234 100644 --- a/py_spring_model/core/session_context_holder.py +++ b/py_spring_model/core/session_context_holder.py @@ -18,9 +18,9 @@ def wrapper(*args, **kwargs): result = func(*args, **kwargs) session.commit() return result - except Exception: + except Exception as error: session.rollback() - raise + raise error finally: SessionContextHolder.clear_session() return wrapper diff --git a/py_spring_model/py_spring_model_provider.py b/py_spring_model/py_spring_model_provider.py index 9c31d33..7f7bdc1 100644 --- a/py_spring_model/py_spring_model_provider.py +++ b/py_spring_model/py_spring_model_provider.py @@ -148,7 +148,7 @@ def provider_init(self) -> None: def provide_py_spring_model() -> EntityProvider: return PySpringModelProvider( rest_controller_classes=[ - # PySpringModelRestController + PySpringModelRestController ], component_classes=[ PySpringModelRestService, diff --git a/py_spring_model/py_spring_model_rest/service/curd_repository_implementation_service/crud_repository_implementation_service.py b/py_spring_model/py_spring_model_rest/service/curd_repository_implementation_service/crud_repository_implementation_service.py index efcfd65..45a03d4 100644 --- a/py_spring_model/py_spring_model_rest/service/curd_repository_implementation_service/crud_repository_implementation_service.py +++ b/py_spring_model/py_spring_model_rest/service/curd_repository_implementation_service/crud_repository_implementation_service.py @@ -19,6 +19,7 @@ from sqlmodel.sql.expression import SelectOfScalar from py_spring_model.core.model import PySpringModel +from py_spring_model.core.session_context_holder import SessionContextHolder, Transactional from py_spring_model.repository.crud_repository import CrudRepository from py_spring_model.py_spring_model_rest.service.curd_repository_implementation_service.method_query_builder import ( _MetodQueryBuilder, @@ -150,14 +151,15 @@ def _get_sql_statement( query = query.where(filter_condition_stack.pop()) return query + @Transactional def _session_execute(self, statement: SelectOfScalar, is_one_result: bool) -> Any: - with PySpringModel.create_session() as session: - logger.debug(f"Executing query: \n{str(statement)}") - result = ( - session.exec(statement).first() - if is_one_result - else session.exec(statement).fetchall() - ) + session = SessionContextHolder.get_or_create_session() + logger.debug(f"Executing query: \n{str(statement)}") + result = ( + session.exec(statement).first() + if is_one_result + else session.exec(statement).fetchall() + ) return result def post_construct(self) -> None: diff --git a/py_spring_model/py_spring_model_rest/service/py_spring_model_rest_service.py b/py_spring_model/py_spring_model_rest/service/py_spring_model_rest_service.py index e619af2..2f81d85 100644 --- a/py_spring_model/py_spring_model_rest/service/py_spring_model_rest_service.py +++ b/py_spring_model/py_spring_model_rest/service/py_spring_model_rest_service.py @@ -4,6 +4,7 @@ from py_spring_core import Component from py_spring_model import PySpringModel +from py_spring_model.core.session_context_holder import SessionContextHolder, Transactional ID = TypeVar("ID", int, UUID) ModelT = TypeVar("ModelT", bound=PySpringModel) @@ -20,44 +21,48 @@ class PySpringModelRestService(Component): 5. Updating an existing model, 6. Deleting a model by ID. """ - def get_all_models(self) -> dict[str, type[PySpringModel]]: return PySpringModel.get_model_lookup() + @Transactional def get(self, model_type: Type[ModelT], id: ID) -> Optional[ModelT]: - with PySpringModel.create_managed_session() as session: - return session.get(model_type, id) # type: ignore + session = SessionContextHolder.get_or_create_session() + return session.get(model_type, id) # type: ignore + @Transactional def get_all_by_ids(self, model_type: Type[ModelT], ids: list[ID]) -> list[ModelT]: - with PySpringModel.create_managed_session() as session: - return session.query(model_type).filter(model_type.id.in_(ids)).all() # type: ignore - + session = SessionContextHolder.get_or_create_session() + return session.query(model_type).filter(model_type.id.in_(ids)).all() # type: ignore + @Transactional def get_all( self, model_type: Type[ModelT], limit: int, offset: int ) -> list[ModelT]: - with PySpringModel.create_managed_session() as session: - return session.query(model_type).offset(offset).limit(limit).all() + session = SessionContextHolder.get_or_create_session() + return session.query(model_type).offset(offset).limit(limit).all() + @Transactional def create(self, model: ModelT) -> ModelT: - with PySpringModel.create_managed_session() as session: - session.add(model) - return model + session = SessionContextHolder.get_or_create_session() + session.add(model) + return model + @Transactional def update(self, id: ID, model: ModelT) -> Optional[ModelT]: - with PySpringModel.create_managed_session() as session: - model_type = type(model) - primary_keys = PySpringModel.get_primary_key_columns(model_type) - optional_model = session.get(model_type, id) # type: ignore - if optional_model is None: - return - - for key, value in model.model_dump().items(): - if key in primary_keys: - continue - setattr(optional_model, key, value) - session.add(optional_model) + session = SessionContextHolder.get_or_create_session() + model_type = type(model) + primary_keys = PySpringModel.get_primary_key_columns(model_type) + optional_model = session.get(model_type, id) # type: ignore + if optional_model is None: + return + + for key, value in model.model_dump().items(): + if key in primary_keys: + continue + setattr(optional_model, key, value) + session.add(optional_model) + @Transactional def delete(self, model_type: Type[ModelT], id: ID) -> None: - with PySpringModel.create_managed_session() as session: - session.query(model_type).filter(model_type.id == id).delete() # type: ignore - session.commit() + session = SessionContextHolder.get_or_create_session() + session.query(model_type).filter(model_type.id == id).delete() # type: ignore + session.commit() diff --git a/tests/test_crud_repository.py b/tests/test_crud_repository.py index 2f751c6..df7854b 100644 --- a/tests/test_crud_repository.py +++ b/tests/test_crud_repository.py @@ -4,6 +4,7 @@ from sqlmodel import Field, SQLModel from py_spring_model import PySpringModel +from py_spring_model.core.session_context_holder import SessionContextHolder from py_spring_model.repository.crud_repository import CrudRepository class User(PySpringModel, table=True): @@ -19,11 +20,13 @@ def setup_method(self): logger.info("Setting up test environment...") self.engine = create_engine("sqlite:///:memory:", echo=True) PySpringModel._engine = self.engine + SessionContextHolder.clear_session() SQLModel.metadata.create_all(self.engine) def teardown_method(self): logger.info("Tearing down test environment...") SQLModel.metadata.drop_all(self.engine) + SessionContextHolder.clear_session() @pytest.fixture def user_repository(self): diff --git a/tests/test_crud_repository_implementation_service.py b/tests/test_crud_repository_implementation_service.py index 77f5428..1be059c 100644 --- a/tests/test_crud_repository_implementation_service.py +++ b/tests/test_crud_repository_implementation_service.py @@ -6,6 +6,7 @@ from sqlalchemy import create_engine from sqlmodel import SQLModel from py_spring_model import PySpringModel, Field, CrudRepository, Query +from py_spring_model.core.session_context_holder import SessionContextHolder from py_spring_model.py_spring_model_rest.service.curd_repository_implementation_service.crud_repository_implementation_service import CrudRepositoryImplementationService from py_spring_model.py_spring_model_rest.service.curd_repository_implementation_service.method_query_builder import _MetodQueryBuilder @@ -32,11 +33,13 @@ def setup_method(self): logger.info("Setting up test environment...") self.engine = create_engine("sqlite:///:memory:", echo=True) PySpringModel._engine = self.engine + SessionContextHolder.clear_session() SQLModel.metadata.create_all(self.engine) def teardown_method(self): logger.info("Tearing down test environment...") SQLModel.metadata.drop_all(self.engine) + SessionContextHolder.clear_session() @pytest.fixture def user_repository(self): diff --git a/tests/test_query_modifying_operations.py b/tests/test_query_modifying_operations.py index f06a422..40416c2 100644 --- a/tests/test_query_modifying_operations.py +++ b/tests/test_query_modifying_operations.py @@ -6,6 +6,7 @@ from unittest.mock import patch, MagicMock from py_spring_model import PySpringModel +from py_spring_model.core.session_context_holder import SessionContextHolder from py_spring_model.py_spring_model_rest.service.query_service.query import Query, QueryExecutionService from py_spring_model.repository.crud_repository import CrudRepository @@ -88,14 +89,17 @@ def setup_method(self): PySpringModel.set_engine(self.engine) PySpringModel.set_metadata(SQLModel.metadata) PySpringModel.set_models([TestUser]) + SessionContextHolder.clear_session() SQLModel.metadata.create_all(self.engine) self.repository = TestUserRepository() + self.repository.insert_user_with_commit(name="John Doe", email="john@example.com", age=30) def teardown_method(self): """Clean up test environment""" logger.info("Tearing down test environment...") SQLModel.metadata.drop_all(self.engine) + SessionContextHolder.clear_session() def test_insert_with_commit_true(self): """Test INSERT operation with is_modifying=True (should commit)""" From 6e2631e5af48a84a5a3da0a939b2cd7d6e79e472 Mon Sep 17 00:00:00 2001 From: William Chen Date: Thu, 19 Jun 2025 22:09:10 +0800 Subject: [PATCH 03/15] Remove commented-out `PySpringModelRestController` from provider configuration for clarity --- py_spring_model/py_spring_model_provider.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py_spring_model/py_spring_model_provider.py b/py_spring_model/py_spring_model_provider.py index 7f7bdc1..9c31d33 100644 --- a/py_spring_model/py_spring_model_provider.py +++ b/py_spring_model/py_spring_model_provider.py @@ -148,7 +148,7 @@ def provider_init(self) -> None: def provide_py_spring_model() -> EntityProvider: return PySpringModelProvider( rest_controller_classes=[ - PySpringModelRestController + # PySpringModelRestController ], component_classes=[ PySpringModelRestService, From 4699eaf8493f11db24ffa23bb9c581dc7abbca0d Mon Sep 17 00:00:00 2001 From: William Chen Date: Thu, 19 Jun 2025 22:20:50 +0800 Subject: [PATCH 04/15] feat: Enhanced @Transactional decorator with smart session management - Implement outermost transaction detection using SessionContextHolder.has_session() - Session lifecycle (commit/rollback/close) only managed by outermost @Transactional - Nested @Transactional methods reuse existing session without interference - Prevent premature session closure in nested transaction scenarios - Maintain transaction integrity across multiple decorated method calls --- .../core/session_context_holder.py | 30 +++++++++++++------ 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/py_spring_model/core/session_context_holder.py b/py_spring_model/core/session_context_holder.py index 2d3b234..80725ec 100644 --- a/py_spring_model/core/session_context_holder.py +++ b/py_spring_model/core/session_context_holder.py @@ -1,28 +1,36 @@ from contextvars import ContextVar from functools import wraps -from typing import ClassVar, Optional +from typing import Any, Callable, ClassVar, Optional from sqlmodel import Session from py_spring_model.core.model import PySpringModel -def Transactional(func): +def Transactional(func: Callable[..., Any]) -> Callable[..., Any]: """ A decorator that wraps a function and commits the session if the function is successful. If the function raises an exception, the session is rolled back. The session is then closed. + If the function is the outermost function, the session is committed. + If the function is not the outermost function, the session is not committed. + If the function is not the outermost function, the session is not rolled back. + If the function is not the outermost function, the session is not closed. """ @wraps(func) def wrapper(*args, **kwargs): + is_outermost = not SessionContextHolder.has_session() session = SessionContextHolder.get_or_create_session() try: result = func(*args, **kwargs) - session.commit() + if is_outermost: + session.commit() return result except Exception as error: - session.rollback() + if is_outermost: + session.rollback() raise error finally: - SessionContextHolder.clear_session() + if is_outermost: + SessionContextHolder.clear_session() return wrapper class SessionContextHolder: @@ -41,10 +49,14 @@ def get_or_create_session(cls) -> Session: cls._session.set(session) return session return optional_session + + @classmethod + def has_session(cls) -> bool: + return cls._session.get() is not None + @classmethod def clear_session(cls): - optional_session = cls._session.get() - if optional_session is None: - return - optional_session.close() + session = cls._session.get() + if session is not None: + session.close() cls._session.set(None) \ No newline at end of file From c79b134173f23f8677dc5785501c0ab3f8d6d3b6 Mon Sep 17 00:00:00 2001 From: William Chen Date: Thu, 19 Jun 2025 22:22:47 +0800 Subject: [PATCH 05/15] refactor: Rename variable `is_outermost` to `is_outermost_transaction` for improving readability. --- py_spring_model/core/session_context_holder.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/py_spring_model/core/session_context_holder.py b/py_spring_model/core/session_context_holder.py index 80725ec..fb76419 100644 --- a/py_spring_model/core/session_context_holder.py +++ b/py_spring_model/core/session_context_holder.py @@ -17,19 +17,19 @@ def Transactional(func: Callable[..., Any]) -> Callable[..., Any]: """ @wraps(func) def wrapper(*args, **kwargs): - is_outermost = not SessionContextHolder.has_session() + is_outermost_transaction = not SessionContextHolder.has_session() session = SessionContextHolder.get_or_create_session() try: result = func(*args, **kwargs) - if is_outermost: + if is_outermost_transaction: session.commit() return result except Exception as error: - if is_outermost: + if is_outermost_transaction: session.rollback() raise error finally: - if is_outermost: + if is_outermost_transaction: SessionContextHolder.clear_session() return wrapper From 7ce0e35c8f2db1aa00cfac6efb61725682165a03 Mon Sep 17 00:00:00 2001 From: William Chen Date: Thu, 19 Jun 2025 22:33:18 +0800 Subject: [PATCH 06/15] feat: Introduce SessionController and session middleware for improved session management --- .../core/session_context_holder.py | 43 ++++++++++++++++--- py_spring_model/py_spring_model_provider.py | 2 + .../controller/session_controller.py | 38 ++++++++++++++++ 3 files changed, 76 insertions(+), 7 deletions(-) create mode 100644 py_spring_model/py_spring_model_rest/controller/session_controller.py diff --git a/py_spring_model/core/session_context_holder.py b/py_spring_model/core/session_context_holder.py index fb76419..df45196 100644 --- a/py_spring_model/core/session_context_holder.py +++ b/py_spring_model/core/session_context_holder.py @@ -7,13 +7,42 @@ def Transactional(func: Callable[..., Any]) -> Callable[..., Any]: """ - A decorator that wraps a function and commits the session if the function is successful. - If the function raises an exception, the session is rolled back. - The session is then closed. - If the function is the outermost function, the session is committed. - If the function is not the outermost function, the session is not committed. - If the function is not the outermost function, the session is not rolled back. - If the function is not the outermost function, the session is not closed. + Decorator for managing database transactions in a nested-safe manner. + + This decorator ensures that: + - A new session is created only if there is no active session (i.e., outermost transaction). + - The session is committed, rolled back, and closed only by the outermost function. + - Nested transactional functions share the same session and do not interfere with the commit/rollback behavior. + + Behavior Summary: + - If this function is the outermost @Transactional in the call stack: + - A new session is created. + - On success, the session is committed. + - On failure, the session is rolled back. + - The session is closed after execution. + - If this function is called within an existing transaction: + - The existing session is reused. + - No commit, rollback, or close is performed (delegated to the outermost function). + + Example: + @Transactional + def outer_operation(): + create_user() + update_account() + + @Transactional + def create_user(): + db.session.add(User(...)) # Uses same session as outer_operation + + @Transactional + def update_account(): + db.session.add(Account(...)) # Uses same session as outer_operation + + # Only outer_operation will commit or rollback. + # If create_user() or update_account() raises an exception, + # the whole transaction will be rolled back. + + This design is similar to Spring's @Transactional """ @wraps(func) def wrapper(*args, **kwargs): diff --git a/py_spring_model/py_spring_model_provider.py b/py_spring_model/py_spring_model_provider.py index 9c31d33..2c874f2 100644 --- a/py_spring_model/py_spring_model_provider.py +++ b/py_spring_model/py_spring_model_provider.py @@ -10,6 +10,7 @@ from py_spring_model.core.commons import ApplicationFileGroups, PySpringModelProperties from py_spring_model.core.model import PySpringModel +from py_spring_model.py_spring_model_rest.controller.session_controller import SessionController from py_spring_model.repository.repository_base import RepositoryBase from py_spring_model.py_spring_model_rest import PySpringModelRestService from py_spring_model.py_spring_model_rest.controller.py_spring_model_rest_controller import ( @@ -149,6 +150,7 @@ def provide_py_spring_model() -> EntityProvider: return PySpringModelProvider( rest_controller_classes=[ # PySpringModelRestController + SessionController ], component_classes=[ PySpringModelRestService, diff --git a/py_spring_model/py_spring_model_rest/controller/session_controller.py b/py_spring_model/py_spring_model_rest/controller/session_controller.py new file mode 100644 index 0000000..925cd2f --- /dev/null +++ b/py_spring_model/py_spring_model_rest/controller/session_controller.py @@ -0,0 +1,38 @@ +from typing import Awaitable, Callable +from fastapi import Request, Response +from py_spring_model.core.session_context_holder import SessionContextHolder +from py_spring_core import RestController + + +async def session_middleware(request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response: + """ + Middleware to ensure that the database session is properly cleaned up after each HTTP request. + + This middleware works with context-based session management using ContextVar. + It guarantees that each request has its own isolated database session context, and + that any session stored in the context is properly closed after the request is handled. + + It does NOT create or commit any transactions by itself. It is meant to be used + in combination with a decorator-based transaction manager (e.g., @Transactional) + which controls when to commit or rollback. + + This middleware acts as a safety net: + - Ensures that the ContextVar-based session is cleared after each request. + - Prevents session leakage between requests in case of unexpected exceptions or unhandled paths. + - Complements nested transaction logic by guaranteeing session cleanup at request boundaries. + + Use this middleware when: + - You are using context-local session handling (via contextvars). + - You want to ensure that long-lived or leaked sessions don't accumulate. + """ + try: + response = await call_next(request) + return response + finally: + SessionContextHolder.clear_session() + +class SessionController(RestController): + def post_construct(self) -> None: + self.app.middleware("http")(session_middleware) + + \ No newline at end of file From 7c7cc55b8fa068542480b909a64174f1a89b780e Mon Sep 17 00:00:00 2001 From: William Chen Date: Thu, 19 Jun 2025 22:35:48 +0800 Subject: [PATCH 07/15] refactor: Clean up documentation in @Transactional decorator - Removed commented-out explanations for clarity. - Streamlined the docstring to focus on the transaction behavior of the decorator. --- py_spring_model/core/session_context_holder.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/py_spring_model/core/session_context_holder.py b/py_spring_model/core/session_context_holder.py index df45196..4b2e1fb 100644 --- a/py_spring_model/core/session_context_holder.py +++ b/py_spring_model/core/session_context_holder.py @@ -38,11 +38,9 @@ def create_user(): def update_account(): db.session.add(Account(...)) # Uses same session as outer_operation - # Only outer_operation will commit or rollback. - # If create_user() or update_account() raises an exception, - # the whole transaction will be rolled back. - - This design is similar to Spring's @Transactional + Only outer_operation will commit or rollback. + If create_user() or update_account() raises an exception, + the whole transaction will be rolled back. """ @wraps(func) def wrapper(*args, **kwargs): From 4e7b847321e4ceed4307f52ec6edeb4beb9a854c Mon Sep 17 00:00:00 2001 From: William Chen Date: Fri, 20 Jun 2025 02:08:51 +0800 Subject: [PATCH 08/15] fix: Update SessionContextHolder to use PySpringSession --- py_spring_model/core/session_context_holder.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/py_spring_model/core/session_context_holder.py b/py_spring_model/core/session_context_holder.py index 4b2e1fb..b990872 100644 --- a/py_spring_model/core/session_context_holder.py +++ b/py_spring_model/core/session_context_holder.py @@ -1,7 +1,8 @@ from contextvars import ContextVar from functools import wraps from typing import Any, Callable, ClassVar, Optional -from sqlmodel import Session + +from py_spring_model.core.py_spring_session import PySpringSession from py_spring_model.core.model import PySpringModel @@ -67,9 +68,9 @@ class SessionContextHolder: This is useful for the query service to be able to access the session without having to pass it in as an argument. This is also useful for the query service to be able to access the session without having to pass it in as an argument. """ - _session: ClassVar[ContextVar[Optional[Session]]] = ContextVar("session", default=None) + _session: ClassVar[ContextVar[Optional[PySpringSession]]] = ContextVar("session", default=None) @classmethod - def get_or_create_session(cls) -> Session: + def get_or_create_session(cls) -> PySpringSession: optional_session = cls._session.get() if optional_session is None: session = PySpringModel.create_session() From 8e3efc1e7519b3af6d367ccc2f84acd49d83c511 Mon Sep 17 00:00:00 2001 From: William Chen Date: Fri, 20 Jun 2025 02:16:48 +0800 Subject: [PATCH 09/15] test: Add comprehensive tests for @Transactional decorator functionality --- tests/test_transactional_decorator.py | 356 ++++++++++++++++++++++++++ 1 file changed, 356 insertions(+) create mode 100644 tests/test_transactional_decorator.py diff --git a/tests/test_transactional_decorator.py b/tests/test_transactional_decorator.py new file mode 100644 index 0000000..e219cf4 --- /dev/null +++ b/tests/test_transactional_decorator.py @@ -0,0 +1,356 @@ +import pytest +from unittest.mock import patch +from sqlalchemy import create_engine, text +from sqlmodel import Field, SQLModel + +from py_spring_model import PySpringModel +from py_spring_model.core.session_context_holder import SessionContextHolder, Transactional + + +class TransactionalTestUser(PySpringModel, table=True): + """Test model for transactional operations""" + id: int = Field(default=None, primary_key=True) + name: str + email: str + age: int = Field(default=0) + + +class TestTransactionalDecorator: + """Test suite for the @Transactional decorator""" + + @pytest.fixture(autouse=True) + def setup_and_teardown(self): + """Set up test environment with in-memory SQLite database""" + self.engine = create_engine("sqlite:///:memory:", echo=False) + PySpringModel.set_engine(self.engine) + PySpringModel.set_metadata(SQLModel.metadata) + PySpringModel.set_models([TransactionalTestUser]) + + # Clear any existing session + SessionContextHolder.clear_session() + + SQLModel.metadata.create_all(self.engine) + + def teardown_method(self): + """Tear down test environment""" + SQLModel.metadata.drop_all(self.engine) + SessionContextHolder.clear_session() + PySpringModel._engine = None + PySpringModel._metadata = None + PySpringModel._connection = None + + def test_single_transactional_success(self): + """Test that a single @Transactional function commits successfully""" + + @Transactional + def create_user(): + session = SessionContextHolder.get_or_create_session() + user = TransactionalTestUser(name="John Doe", email="john@example.com", age=30) + session.add(user) + session.flush() # To get the ID + return user + + # Execute the transactional function + result = create_user() + + # Verify the user was created and committed + assert result.name == "John Doe" + assert result.email == "john@example.com" + assert result.age == 30 + + # Verify session is cleared after transaction + assert not SessionContextHolder.has_session() + + # Verify data persisted to database + with PySpringModel.create_managed_session() as session: + users = session.execute(text("SELECT * FROM transactionaltestuser")).fetchall() + assert len(users) == 1 + assert users[0].name == "John Doe" + + def test_single_transactional_rollback(self): + """Test that a single @Transactional function rolls back on exception""" + + @Transactional + def create_user_with_error(): + session = SessionContextHolder.get_or_create_session() + user = TransactionalTestUser(name="Jane Doe", email="jane@example.com", age=25) + session.add(user) + session.flush() + raise ValueError("Simulated error") + + # Execute the transactional function and expect exception + with pytest.raises(ValueError, match="Simulated error"): + create_user_with_error() + + # Verify session is cleared after rollback + assert not SessionContextHolder.has_session() + + # Verify no data persisted to database + with PySpringModel.create_managed_session() as session: + users = session.execute(text("SELECT * FROM transactionaltestuser")).fetchall() + assert len(users) == 0 + + def test_nested_transactional_success(self): + """Test that nested @Transactional functions share the same session and commit at top level""" + + @Transactional + def create_user(name: str, email: str): + session = SessionContextHolder.get_or_create_session() + user = TransactionalTestUser(name=name, email=email, age=30) + session.add(user) + session.flush() + return user + + @Transactional + def update_user_age(user_id: int, new_age: int): + session = SessionContextHolder.get_or_create_session() + session.execute(text(f"UPDATE transactionaltestuser SET age = {new_age} WHERE id = {user_id}")) + + @Transactional + def create_and_update_user(): + # This should share the same session with nested calls + user = create_user("Alice Smith", "alice@example.com") + update_user_age(user.id, 35) + return user + + # Execute the outer transactional function + result = create_and_update_user() + + # Verify session is cleared after transaction + assert not SessionContextHolder.has_session() + + # Verify both operations were committed + with PySpringModel.create_managed_session() as session: + users = session.execute(text("SELECT * FROM transactionaltestuser")).fetchall() + assert len(users) == 1 + assert users[0].name == "Alice Smith" + assert users[0].age == 35 # Updated by nested function + + def test_nested_transactional_rollback_from_inner(self): + """Test that exception in nested @Transactional causes rollback at top level""" + + @Transactional + def create_user(name: str, email: str): + session = SessionContextHolder.get_or_create_session() + user = TransactionalTestUser(name=name, email=email, age=30) + session.add(user) + session.flush() + return user + + @Transactional + def update_user_with_error(user_id: int): + session = SessionContextHolder.get_or_create_session() + session.execute(text(f"UPDATE transactionaltestuser SET age = 40 WHERE id = {user_id}")) + raise RuntimeError("Update failed") + + @Transactional + def create_and_update_user_with_error(): + user = create_user("Bob Johnson", "bob@example.com") + update_user_with_error(user.id) # This will raise an exception + return user + + # Execute and expect exception + with pytest.raises(RuntimeError, match="Update failed"): + create_and_update_user_with_error() + + # Verify session is cleared after rollback + assert not SessionContextHolder.has_session() + + # Verify no data persisted (everything rolled back) + with PySpringModel.create_managed_session() as session: + users = session.execute(text("SELECT * FROM transactionaltestuser")).fetchall() + assert len(users) == 0 + + def test_nested_transactional_rollback_from_outer(self): + """Test that exception in outer @Transactional causes rollback after nested calls""" + + @Transactional + def create_user(name: str, email: str): + session = SessionContextHolder.get_or_create_session() + user = TransactionalTestUser(name=name, email=email, age=30) + session.add(user) + session.flush() + return user + + @Transactional + def update_user_age(user_id: int, new_age: int): + session = SessionContextHolder.get_or_create_session() + session.execute(text(f"UPDATE transactionaltestuser SET age = {new_age} WHERE id = {user_id}")) + + @Transactional + def create_update_and_fail(): + user = create_user("Charlie Brown", "charlie@example.com") + update_user_age(user.id, 45) + # Both nested operations succeeded, but outer fails + raise Exception("Outer operation failed") + + # Execute and expect exception + with pytest.raises(Exception, match="Outer operation failed"): + create_update_and_fail() + + # Verify session is cleared after rollback + assert not SessionContextHolder.has_session() + + # Verify no data persisted (everything rolled back) + with PySpringModel.create_managed_session() as session: + users = session.execute(text("SELECT * FROM transactionaltestuser")).fetchall() + assert len(users) == 0 + + def test_session_sharing_across_nested_transactions(self): + """Test that nested @Transactional functions share the same session instance""" + + captured_sessions = [] + + @Transactional + def inner_function(): + session = SessionContextHolder.get_or_create_session() + captured_sessions.append(session) + user = TransactionalTestUser(name="Inner User", email="inner@example.com", age=25) + session.add(user) + session.flush() + return session + + @Transactional + def middle_function(): + session = SessionContextHolder.get_or_create_session() + captured_sessions.append(session) + inner_session = inner_function() + return session, inner_session + + @Transactional + def outer_function(): + session = SessionContextHolder.get_or_create_session() + captured_sessions.append(session) + middle_session, inner_session = middle_function() + return session, middle_session, inner_session + + # Execute nested transactions + outer_session, middle_session, inner_session = outer_function() + + # Verify all functions used the same session instance + assert len(captured_sessions) == 3 + assert captured_sessions[0] is captured_sessions[1] # outer and middle + assert captured_sessions[1] is captured_sessions[2] # middle and inner + assert outer_session is middle_session is inner_session + + # Verify session is cleared after transaction + assert not SessionContextHolder.has_session() + + def test_transactional_session_context_isolation(self): + """Test that @Transactional properly isolates session context""" + + @Transactional + def first_transaction(): + session = SessionContextHolder.get_or_create_session() + user1 = TransactionalTestUser(name="User 1", email="user1@example.com", age=30) + session.add(user1) + session.flush() + return user1.id + + @Transactional + def second_transaction(): + session = SessionContextHolder.get_or_create_session() + user2 = TransactionalTestUser(name="User 2", email="user2@example.com", age=35) + session.add(user2) + session.flush() + return user2.id + + # Execute separate transactions + first_transaction() + second_transaction() + + # Verify sessions were properly isolated + assert not SessionContextHolder.has_session() + + # Verify both transactions committed separately + with PySpringModel.create_managed_session() as session: + users = session.execute(text("SELECT * FROM transactionaltestuser ORDER BY id")).fetchall() + assert len(users) == 2 + assert users[0].name == "User 1" + assert users[1].name == "User 2" + + def test_transactional_preserves_function_metadata(self): + """Test that @Transactional preserves original function metadata""" + + @Transactional + def documented_function(param1: str, param2: int = 10) -> str: + """This is a documented function with parameters.""" + return f"{param1}_{param2}" + + # Verify function metadata is preserved + assert documented_function.__name__ == "documented_function" + assert documented_function.__doc__ is not None and "documented function" in documented_function.__doc__ + + # Verify function still works correctly + result = documented_function("test", 20) + assert result == "test_20" + + def test_transactional_commit_rollback_behavior(self): + """Test the core commit/rollback behavior of nested transactions""" + + commit_calls = [] + rollback_calls = [] + + # Mock session to track commit/rollback calls + original_session_class = PySpringModel.create_session + + def mock_create_session(): + session = original_session_class() + original_commit = session.commit + original_rollback = session.rollback + + def mock_commit(): + commit_calls.append("commit") + return original_commit() + + def mock_rollback(): + rollback_calls.append("rollback") + return original_rollback() + + session.commit = mock_commit + session.rollback = mock_rollback + return session + + @Transactional + def inner_operation(): + session = SessionContextHolder.get_or_create_session() + user = TransactionalTestUser(name="Test", email="test@example.com", age=30) + session.add(user) + session.flush() + + @Transactional + def outer_operation(): + inner_operation() + + # Test successful nested transaction + with patch.object(PySpringModel, 'create_session', side_effect=mock_create_session): + outer_operation() + + # Only the outermost transaction should commit + assert len(commit_calls) == 1 + assert len(rollback_calls) == 0 + + # Reset counters + commit_calls.clear() + rollback_calls.clear() + + @Transactional + def inner_operation_with_error(): + session = SessionContextHolder.get_or_create_session() + user = TransactionalTestUser(name="Test2", email="test2@example.com", age=25) + session.add(user) + session.flush() + raise ValueError("Inner error") + + @Transactional + def outer_operation_with_error(): + inner_operation_with_error() + + # Test failed nested transaction + with patch.object(PySpringModel, 'create_session', side_effect=mock_create_session): + with pytest.raises(ValueError): + outer_operation_with_error() + + # Only the outermost transaction should rollback + assert len(commit_calls) == 0 + assert len(rollback_calls) == 1 \ No newline at end of file From b22e3db0bccc42509a27e643256a4504c1257918 Mon Sep 17 00:00:00 2001 From: William Chen Date: Fri, 20 Jun 2025 17:36:00 +0800 Subject: [PATCH 10/15] feat: Implement explicit session depth tracking in SessionContextHolder - Added TransactionalDepth enum to define transaction levels. - Enhanced SessionContextHolder with methods to manage session depth. - Updated @Transactional decorator to utilize session depth for commit/rollback operations. - Introduced tests for session depth tracking and behavior in nested transactions. --- .../core/session_context_holder.py | 62 ++++++++-- tests/test_session_depth.py | 112 ++++++++++++++++++ 2 files changed, 165 insertions(+), 9 deletions(-) create mode 100644 tests/test_session_depth.py diff --git a/py_spring_model/core/session_context_holder.py b/py_spring_model/core/session_context_holder.py index b990872..753868d 100644 --- a/py_spring_model/core/session_context_holder.py +++ b/py_spring_model/core/session_context_holder.py @@ -1,4 +1,5 @@ from contextvars import ContextVar +from enum import IntEnum from functools import wraps from typing import Any, Callable, ClassVar, Optional @@ -6,6 +7,10 @@ from py_spring_model.core.model import PySpringModel +class TransactionalDepth(IntEnum): + OUTERMOST = 1 + ON_EXIT = 0 + def Transactional(func: Callable[..., Any]) -> Callable[..., Any]: """ Decorator for managing database transactions in a nested-safe manner. @@ -45,30 +50,34 @@ def update_account(): """ @wraps(func) def wrapper(*args, **kwargs): - is_outermost_transaction = not SessionContextHolder.has_session() + # Increment session depth and get session + session_depth = SessionContextHolder.enter_session() session = SessionContextHolder.get_or_create_session() try: result = func(*args, **kwargs) - if is_outermost_transaction: + # Only commit at the outermost level (session_depth == 1) + if session_depth == TransactionalDepth.OUTERMOST.value: session.commit() return result except Exception as error: - if is_outermost_transaction: + # Only rollback at the outermost level (session_depth == 1) + if session_depth == TransactionalDepth.OUTERMOST.value: session.rollback() raise error finally: - if is_outermost_transaction: - SessionContextHolder.clear_session() + # Decrement depth and clean up session if needed + SessionContextHolder.exit_session() return wrapper class SessionContextHolder: """ - A context holder for the session. + A context holder for the session with explicit depth tracking. This is used to store the session in a context variable so that it can be accessed by the query service. - This is useful for the query service to be able to access the session without having to pass it in as an argument. - This is also useful for the query service to be able to access the session without having to pass it in as an argument. + The depth counter ensures that only the outermost transaction manages commit/rollback operations. """ _session: ClassVar[ContextVar[Optional[PySpringSession]]] = ContextVar("session", default=None) + _session_depth: ClassVar[ContextVar[int]] = ContextVar("session_depth", default=0) + @classmethod def get_or_create_session(cls) -> PySpringSession: optional_session = cls._session.get() @@ -82,9 +91,44 @@ def get_or_create_session(cls) -> PySpringSession: def has_session(cls) -> bool: return cls._session.get() is not None + @classmethod + def get_session_depth(cls) -> int: + """Get the current session depth.""" + return cls._session_depth.get() + + @classmethod + def enter_session(cls) -> int: + """ + Enter a new session context and increment the depth counter. + Returns the new depth level. + """ + current_depth = cls._session_depth.get() + new_depth = current_depth + 1 + cls._session_depth.set(new_depth) + return new_depth + + @classmethod + def exit_session(cls) -> int: + """ + Exit the current session context and decrement the depth counter. + If depth reaches 0, clear the session. + Returns the new depth level. + """ + current_depth = cls._session_depth.get() + new_depth = max(0, current_depth - 1) # Prevent negative depth + cls._session_depth.set(new_depth) + + # Clear session only when depth reaches 0 (outermost level) + if new_depth == 0: + cls.clear_session() + + return new_depth + @classmethod def clear_session(cls): + """Clear the session and reset depth to 0.""" session = cls._session.get() if session is not None: session.close() - cls._session.set(None) \ No newline at end of file + cls._session.set(None) + cls._session_depth.set(TransactionalDepth.ON_EXIT.value) \ No newline at end of file diff --git a/tests/test_session_depth.py b/tests/test_session_depth.py new file mode 100644 index 0000000..bb4c687 --- /dev/null +++ b/tests/test_session_depth.py @@ -0,0 +1,112 @@ +import pytest +from py_spring_model.core.session_context_holder import SessionContextHolder, Transactional +from py_spring_model.core.model import PySpringModel + + +class TestSessionDepth: + """Test the explicit session depth tracking functionality""" + + def setup_method(self): + """Clean up any existing sessions before each test""" + SessionContextHolder.clear_session() + + def teardown_method(self): + """Clean up after each test""" + SessionContextHolder.clear_session() + + def test_session_depth_starts_at_zero(self): + """Test that session depth starts at 0""" + assert SessionContextHolder.get_session_depth() == 0 + + def test_session_depth_increments_and_decrements(self): + """Test that session depth properly increments and decrements""" + # Initially 0 + assert SessionContextHolder.get_session_depth() == 0 + + # Enter first level + depth1 = SessionContextHolder.enter_session() + assert depth1 == 1 + assert SessionContextHolder.get_session_depth() == 1 + + # Enter second level + depth2 = SessionContextHolder.enter_session() + assert depth2 == 2 + assert SessionContextHolder.get_session_depth() == 2 + + # Exit second level + depth_after_exit1 = SessionContextHolder.exit_session() + assert depth_after_exit1 == 1 + assert SessionContextHolder.get_session_depth() == 1 + + # Exit first level + depth_after_exit2 = SessionContextHolder.exit_session() + assert depth_after_exit2 == 0 + assert SessionContextHolder.get_session_depth() == 0 + + def test_session_cleared_only_at_outermost_level(self): + """Test that session is only cleared when depth reaches 0""" + # Enter first level and create session + SessionContextHolder.enter_session() + session = SessionContextHolder.get_or_create_session() + assert SessionContextHolder.has_session() + + # Enter second level - session should still exist + SessionContextHolder.enter_session() + assert SessionContextHolder.has_session() + assert SessionContextHolder.get_session_depth() == 2 + + # Exit second level - session should still exist + SessionContextHolder.exit_session() + assert SessionContextHolder.has_session() + assert SessionContextHolder.get_session_depth() == 1 + + # Exit first level - session should be cleared + SessionContextHolder.exit_session() + assert not SessionContextHolder.has_session() + assert SessionContextHolder.get_session_depth() == 0 + + def test_clear_session_resets_depth(self): + """Test that clear_session() resets the depth to 0""" + SessionContextHolder.enter_session() + SessionContextHolder.enter_session() + assert SessionContextHolder.get_session_depth() == 2 + + SessionContextHolder.clear_session() + assert SessionContextHolder.get_session_depth() == 0 + + @pytest.mark.parametrize("nesting_levels", [1, 2, 3, 5]) + def test_transactional_depth_tracking(self, nesting_levels): + """Test that @Transactional properly tracks depth at various nesting levels""" + depth_records = [] + + def create_nested_function(level: int): + @Transactional + def nested_func(): + current_depth = SessionContextHolder.get_session_depth() + depth_records.append(current_depth) + if level > 1: + create_nested_function(level - 1)() + return nested_func + + # Create and execute nested function + outermost_func = create_nested_function(nesting_levels) + outermost_func() + + # Verify depth progression + assert len(depth_records) == nesting_levels + for i, recorded_depth in enumerate(depth_records): + expected_depth = i + 1 + assert recorded_depth == expected_depth, f"At level {i+1}, expected depth {expected_depth}, got {recorded_depth}" + + # Verify session is cleaned up + assert SessionContextHolder.get_session_depth() == 0 + assert not SessionContextHolder.has_session() + + def test_depth_prevents_negative_values(self): + """Test that depth counter prevents going below 0""" + assert SessionContextHolder.get_session_depth() == 0 + + # Try to exit when already at 0 + new_depth = SessionContextHolder.exit_session() + assert new_depth == 0 + assert SessionContextHolder.get_session_depth() == 0 \ No newline at end of file From b95fe273b61b9ad83abd0eb82cd86c7f1554c73d Mon Sep 17 00:00:00 2001 From: William Chen Date: Fri, 20 Jun 2025 17:52:51 +0800 Subject: [PATCH 11/15] fix: Update session clearing condition in SessionContextHolder - Changed the condition for clearing the session from checking if the depth is 0 to using TransactionalDepth.ON_EXIT.value for better clarity and alignment with transaction management. --- py_spring_model/core/session_context_holder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py_spring_model/core/session_context_holder.py b/py_spring_model/core/session_context_holder.py index 753868d..ed3643d 100644 --- a/py_spring_model/core/session_context_holder.py +++ b/py_spring_model/core/session_context_holder.py @@ -119,7 +119,7 @@ def exit_session(cls) -> int: cls._session_depth.set(new_depth) # Clear session only when depth reaches 0 (outermost level) - if new_depth == 0: + if new_depth == TransactionalDepth.ON_EXIT.value: cls.clear_session() return new_depth From 86e35625147d6d08b70e9d7432fb48b4959bf242 Mon Sep 17 00:00:00 2001 From: William Chen Date: Fri, 20 Jun 2025 20:48:14 +0800 Subject: [PATCH 12/15] feat: Enhance CRUD repository with @Transactional decorator for improved transaction management - Added @Transactional decorator to multiple CRUD methods in CrudRepository to ensure proper transaction handling. - Updated commit logic in PySpringSession to warn against committing managed transactions. - Refactored clone method in PySpringModel to return the correct type. - Improved logging during session commit in PySpringModel for better clarity. - Introduced a method in SessionContextHolder to check if a transaction is managed. --- py_spring_model/core/model.py | 9 ++------- py_spring_model/core/py_spring_session.py | 10 ++++++++++ py_spring_model/core/session_context_holder.py | 9 ++++++--- py_spring_model/repository/crud_repository.py | 10 +++++++++- tests/test_crud_repository.py | 15 ++++++++++++++- 5 files changed, 41 insertions(+), 12 deletions(-) diff --git a/py_spring_model/core/model.py b/py_spring_model/core/model.py index 9f7987f..32ed6fb 100644 --- a/py_spring_model/core/model.py +++ b/py_spring_model/core/model.py @@ -1,5 +1,4 @@ import contextlib -from typing_extensions import Self from typing import ClassVar, Iterator, Optional, Type from loguru import logger from sqlalchemy import Engine, MetaData @@ -81,7 +80,7 @@ def get_model_lookup(cls) -> dict[str, type["PySpringModel"]]: raise ValueError("[MODEL_LOOKUP NOT SET] Model lookup is not set") return {str(_model.__tablename__): _model for _model in cls._models} - def clone(self) -> Self: + def clone(self) -> "PySpringModel": return self.model_validate_json(self.model_dump_json()) @classmethod @@ -105,11 +104,7 @@ def create_managed_session(cls, should_commit: bool = True) -> Iterator[PySpring logger.debug("[MANAGED SESSION COMMIT] Session committing...") if should_commit: session.commit() - logger.debug( - "[MANAGED SESSION COMMIT] Session committed, refreshing instances..." - ) - session.refresh_current_session_instances() - logger.success("[MANAGED SESSION COMMIT] Session committed.") + logger.success("[MANAGED SESSION COMMIT] Session committed.") except Exception as error: logger.error(error) logger.error("[MANAGED SESSION ROLLBACK] Session rolling back...") diff --git a/py_spring_model/core/py_spring_session.py b/py_spring_model/core/py_spring_session.py index 7a97dfe..7080445 100644 --- a/py_spring_model/core/py_spring_session.py +++ b/py_spring_model/core/py_spring_session.py @@ -1,5 +1,6 @@ from typing import Iterable +from loguru import logger from sqlmodel import Session, SQLModel @@ -33,3 +34,12 @@ def add_all(self, instances: Iterable[SQLModel]) -> None: def refresh_current_session_instances(self) -> None: for instance in self.current_session_instance: self.refresh(instance) + + def commit(self) -> None: + # Import here to avoid circular import + from py_spring_model.core.session_context_holder import SessionContextHolder + if SessionContextHolder.is_transaction_managed(): + logger.warning("Commiting a transaction that is currently being managed by the outermost transaction is strongly discouraged...") + return + super().commit() + self.refresh_current_session_instances() \ No newline at end of file diff --git a/py_spring_model/core/session_context_holder.py b/py_spring_model/core/session_context_holder.py index ed3643d..daf8c97 100644 --- a/py_spring_model/core/session_context_holder.py +++ b/py_spring_model/core/session_context_holder.py @@ -3,9 +3,8 @@ from functools import wraps from typing import Any, Callable, ClassVar, Optional -from py_spring_model.core.py_spring_session import PySpringSession - from py_spring_model.core.model import PySpringModel +from py_spring_model.core.py_spring_session import PySpringSession class TransactionalDepth(IntEnum): OUTERMOST = 1 @@ -131,4 +130,8 @@ def clear_session(cls): if session is not None: session.close() cls._session.set(None) - cls._session_depth.set(TransactionalDepth.ON_EXIT.value) \ No newline at end of file + cls._session_depth.set(TransactionalDepth.ON_EXIT.value) + + @classmethod + def is_transaction_managed(cls) -> bool: + return cls._session_depth.get() > TransactionalDepth.OUTERMOST.value \ No newline at end of file diff --git a/py_spring_model/repository/crud_repository.py b/py_spring_model/repository/crud_repository.py index c439934..0b0ed64 100644 --- a/py_spring_model/repository/crud_repository.py +++ b/py_spring_model/repository/crud_repository.py @@ -59,6 +59,7 @@ def __init__(self) -> None: def _get_model_id_type_with_class(cls) -> tuple[Type[ID], Type[T]]: return get_args(tp=cls.__mro__[0].__orig_bases__[0]) + @Transactional def _find_by_statement( self, statement: Union[Select, SelectOfScalar], @@ -67,7 +68,7 @@ def _find_by_statement( return session.exec(statement).first() - + @Transactional def _find_by_query( self, query_by: dict[str, Any], @@ -76,6 +77,8 @@ def _find_by_query( statement = select(self.model_class).filter_by(**query_by) return session.exec(statement).first() + + @Transactional def _find_all_by_query( self, query_by: dict[str, Any], @@ -84,6 +87,7 @@ def _find_all_by_query( statement = select(self.model_class).filter_by(**query_by) return session, list(session.exec(statement).fetchall()) + @Transactional def _find_all_by_statement( self, statement: Union[Select, SelectOfScalar], @@ -91,6 +95,7 @@ def _find_all_by_statement( session = SessionContextHolder.get_or_create_session() return list(session.exec(statement).fetchall()) + @Transactional def find_by_id(self, id: ID) -> Optional[T]: session = SessionContextHolder.get_or_create_session() statement = select(self.model_class).where(self.model_class.id == id) # type: ignore @@ -99,11 +104,14 @@ def find_by_id(self, id: ID) -> Optional[T]: return return optional_entity + + @Transactional def find_all_by_ids(self, ids: list[ID]) -> list[T]: session = SessionContextHolder.get_or_create_session() statement = select(self.model_class).where(self.model_class.id.in_(ids)) # type: ignore return [entity for entity in session.exec(statement).all()] # type: ignore + @Transactional def find_all(self) -> list[T]: session = SessionContextHolder.get_or_create_session() statement = select(self.model_class) # type: ignore diff --git a/tests/test_crud_repository.py b/tests/test_crud_repository.py index df7854b..d3dc628 100644 --- a/tests/test_crud_repository.py +++ b/tests/test_crud_repository.py @@ -135,4 +135,17 @@ def test_upsert_for_new_user(self, user_repository: UserRepository): new_user = user_repository.find_by_id(1) assert new_user is not None assert new_user.name == "John Doe" - assert new_user.email == "john@example.com" \ No newline at end of file + assert new_user.email == "john@example.com" + + def test_update_user(self, user_repository: UserRepository): + self.create_test_user(user_repository) + user = user_repository.find_by_id(1) + assert user is not None + user.name = "William Chen" + user.email = "william.chen@example.com" + user_repository.save(user) + updated_user = user_repository.find_by_id(1) + assert updated_user is not None + assert updated_user.name == "William Chen" + assert updated_user.email == "william.chen@example.com" + \ No newline at end of file From 9d31839c0ff1fbf783e9956fb4cdae8c4547f825 Mon Sep 17 00:00:00 2001 From: William Chen Date: Fri, 20 Jun 2025 22:14:37 +0800 Subject: [PATCH 13/15] feat: Enhance type safety in PySpringModel with TypeVar - Introduced TypeVar for PySpringModel to improve type hinting in the clone method. - Updated the clone method signature to return the correct type, enhancing clarity and type safety. --- py_spring_model/core/model.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/py_spring_model/core/model.py b/py_spring_model/core/model.py index 32ed6fb..d5166d2 100644 --- a/py_spring_model/core/model.py +++ b/py_spring_model/core/model.py @@ -1,5 +1,5 @@ import contextlib -from typing import ClassVar, Iterator, Optional, Type +from typing import ClassVar, Iterator, Optional, Type, TypeVar from loguru import logger from sqlalchemy import Engine, MetaData from sqlalchemy.engine.base import Connection @@ -7,6 +7,8 @@ from py_spring_model.core.py_spring_session import PySpringSession +T = TypeVar("T", bound="PySpringModel") + class PySpringModel(SQLModel): """ @@ -80,7 +82,7 @@ def get_model_lookup(cls) -> dict[str, type["PySpringModel"]]: raise ValueError("[MODEL_LOOKUP NOT SET] Model lookup is not set") return {str(_model.__tablename__): _model for _model in cls._models} - def clone(self) -> "PySpringModel": + def clone(self: T) -> T: return self.model_validate_json(self.model_dump_json()) @classmethod From bc2e6a0455ba6b2bcfad2a696d74f06a646114af Mon Sep 17 00:00:00 2001 From: William Chen Date: Fri, 20 Jun 2025 22:43:23 +0800 Subject: [PATCH 14/15] feat: Improve type safety and fix variable naming in CRUD repository and session context --- py_spring_model/core/session_context_holder.py | 10 +++++++--- py_spring_model/repository/crud_repository.py | 6 +++--- tests/test_crud_repository.py | 11 ++++++++++- 3 files changed, 20 insertions(+), 7 deletions(-) diff --git a/py_spring_model/core/session_context_holder.py b/py_spring_model/core/session_context_holder.py index daf8c97..1f7eafc 100644 --- a/py_spring_model/core/session_context_holder.py +++ b/py_spring_model/core/session_context_holder.py @@ -1,7 +1,7 @@ from contextvars import ContextVar from enum import IntEnum from functools import wraps -from typing import Any, Callable, ClassVar, Optional +from typing import Any, Callable, ClassVar, Optional, ParamSpec, TypeVar from py_spring_model.core.model import PySpringModel from py_spring_model.core.py_spring_session import PySpringSession @@ -10,7 +10,11 @@ class TransactionalDepth(IntEnum): OUTERMOST = 1 ON_EXIT = 0 -def Transactional(func: Callable[..., Any]) -> Callable[..., Any]: + +P = ParamSpec("P") +RT = TypeVar("RT") + +def Transactional(func: Callable[P, RT]) -> Callable[P, RT]: """ Decorator for managing database transactions in a nested-safe manner. @@ -48,7 +52,7 @@ def update_account(): the whole transaction will be rolled back. """ @wraps(func) - def wrapper(*args, **kwargs): + def wrapper(*args: P.args, **kwargs: P.kwargs) -> RT: # Increment session depth and get session session_depth = SessionContextHolder.enter_session() session = SessionContextHolder.get_or_create_session() diff --git a/py_spring_model/repository/crud_repository.py b/py_spring_model/repository/crud_repository.py index 0b0ed64..131492f 100644 --- a/py_spring_model/repository/crud_repository.py +++ b/py_spring_model/repository/crud_repository.py @@ -135,10 +135,10 @@ def save_all( @Transactional def delete(self, entity: T) -> bool: session = SessionContextHolder.get_or_create_session() - optional_intance = self._find_by_query(entity.model_dump()) - if optional_intance is None: + optional_instance = self._find_by_query(entity.model_dump()) + if optional_instance is None: return False - session.delete(optional_intance) + session.delete(optional_instance) return True @Transactional diff --git a/tests/test_crud_repository.py b/tests/test_crud_repository.py index d3dc628..17d7234 100644 --- a/tests/test_crud_repository.py +++ b/tests/test_crud_repository.py @@ -148,4 +148,13 @@ def test_update_user(self, user_repository: UserRepository): assert updated_user is not None assert updated_user.name == "William Chen" assert updated_user.email == "william.chen@example.com" - \ No newline at end of file + + def test_delete_user_with_user_found(self, user_repository: UserRepository): + self.create_test_user(user_repository) + user = user_repository.find_by_id(1) + assert user is not None + assert user_repository.delete(user) + assert user_repository.find_by_id(1) is None + + def test_delete_user_with_user_not_found(self, user_repository: UserRepository): + assert user_repository.delete(User(id=1, name="John Doe", email="john@example.com")) is False \ No newline at end of file From 4f59f861c64757320c6a373cd90c3dc9f19093e5 Mon Sep 17 00:00:00 2001 From: William Chen Date: Thu, 17 Jul 2025 20:03:03 +0800 Subject: [PATCH 15/15] feat: Add versioning and export Query in PySpringModel - Introduced __version__ attribute to specify the current version of the package. - Updated __all__ to include Query for better module accessibility. --- py_spring_model/__init__.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/py_spring_model/__init__.py b/py_spring_model/__init__.py index b14de83..5571a52 100644 --- a/py_spring_model/__init__.py +++ b/py_spring_model/__init__.py @@ -4,4 +4,18 @@ from py_spring_model.repository.crud_repository import CrudRepository from py_spring_model.repository.repository_base import RepositoryBase from py_spring_model.py_spring_model_rest.service.curd_repository_implementation_service.crud_repository_implementation_service import SkipAutoImplmentation -from py_spring_model.py_spring_model_rest.service.query_service.query import Query \ No newline at end of file +from py_spring_model.py_spring_model_rest.service.query_service.query import Query + + +__all__ = [ + "PySpringModel", + "Field", + "SessionContextHolder", + "provide_py_spring_model", + "CrudRepository", + "RepositoryBase", + "SkipAutoImplmentation", + "Query", +] + +__version__ = "0.1.0" \ No newline at end of file