Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
15 commits
Select commit Hold shift + click to select a range
a7f4a65
Implement Session Management and Transactional Decorator for CRUD Ope…
William-W-Chen Jun 19, 2025
708906d
Refactor CRUD Operations to Utilize SessionContextHolder for Session …
William-W-Chen Jun 19, 2025
6e2631e
Remove commented-out `PySpringModelRestController` from provider conf…
William-W-Chen Jun 19, 2025
4699eaf
feat: Enhanced @Transactional decorator with smart session management
William-W-Chen Jun 19, 2025
c79b134
refactor: Rename variable `is_outermost` to `is_outermost_transaction…
William-W-Chen Jun 19, 2025
7ce0e35
feat: Introduce SessionController and session middleware for improved…
William-W-Chen Jun 19, 2025
7c7cc55
refactor: Clean up documentation in @Transactional decorator
William-W-Chen Jun 19, 2025
4e7b847
fix: Update SessionContextHolder to use PySpringSession
William-W-Chen Jun 19, 2025
8e3efc1
test: Add comprehensive tests for @Transactional decorator functionality
William-W-Chen Jun 19, 2025
b22e3db
feat: Implement explicit session depth tracking in SessionContextHolder
William-W-Chen Jun 20, 2025
b95fe27
fix: Update session clearing condition in SessionContextHolder
William-W-Chen Jun 20, 2025
86e3562
feat: Enhance CRUD repository with @Transactional decorator for impro…
William-W-Chen Jun 20, 2025
9d31839
feat: Enhance type safety in PySpringModel with TypeVar
William-W-Chen Jun 20, 2025
bc2e6a0
feat: Improve type safety and fix variable naming in CRUD repository …
William-W-Chen Jun 20, 2025
4f59f86
feat: Add versioning and export Query in PySpringModel
William-W-Chen Jul 17, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion py_spring_model/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,21 @@
from py_spring_model.core.model import PySpringModel, Field
from py_spring_model.core.session_context_holder import SessionContextHolder
from py_spring_model.py_spring_model_provider import provide_py_spring_model
from py_spring_model.repository.crud_repository import CrudRepository
from py_spring_model.repository.repository_base import RepositoryBase
from py_spring_model.py_spring_model_rest.service.curd_repository_implementation_service.crud_repository_implementation_service import SkipAutoImplmentation
from py_spring_model.py_spring_model_rest.service.query_service.query import Query
from py_spring_model.py_spring_model_rest.service.query_service.query import Query


__all__ = [
"PySpringModel",
"Field",
"SessionContextHolder",
"provide_py_spring_model",
"CrudRepository",
"RepositoryBase",
"SkipAutoImplmentation",
"Query",
]

__version__ = "0.1.0"
13 changes: 5 additions & 8 deletions py_spring_model/core/model.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import contextlib
from typing_extensions import Self
from typing import ClassVar, Iterator, Optional, Type
from typing import ClassVar, Iterator, Optional, Type, TypeVar
from loguru import logger
from sqlalchemy import Engine, MetaData
from sqlalchemy.engine.base import Connection
from sqlmodel import SQLModel, Field

from py_spring_model.core.py_spring_session import PySpringSession

T = TypeVar("T", bound="PySpringModel")


class PySpringModel(SQLModel):
"""
Expand Down Expand Up @@ -81,7 +82,7 @@ def get_model_lookup(cls) -> dict[str, type["PySpringModel"]]:
raise ValueError("[MODEL_LOOKUP NOT SET] Model lookup is not set")
return {str(_model.__tablename__): _model for _model in cls._models}

def clone(self) -> Self:
def clone(self: T) -> T:
return self.model_validate_json(self.model_dump_json())

@classmethod
Expand All @@ -105,11 +106,7 @@ def create_managed_session(cls, should_commit: bool = True) -> Iterator[PySpring
logger.debug("[MANAGED SESSION COMMIT] Session committing...")
if should_commit:
session.commit()
logger.debug(
"[MANAGED SESSION COMMIT] Session committed, refreshing instances..."
)
session.refresh_current_session_instances()
logger.success("[MANAGED SESSION COMMIT] Session committed.")
logger.success("[MANAGED SESSION COMMIT] Session committed.")
except Exception as error:
logger.error(error)
logger.error("[MANAGED SESSION ROLLBACK] Session rolling back...")
Expand Down
10 changes: 10 additions & 0 deletions py_spring_model/core/py_spring_session.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Iterable

from loguru import logger
from sqlmodel import Session, SQLModel


Expand Down Expand Up @@ -33,3 +34,12 @@ def add_all(self, instances: Iterable[SQLModel]) -> None:
def refresh_current_session_instances(self) -> None:
for instance in self.current_session_instance:
self.refresh(instance)

def commit(self) -> None:
# Import here to avoid circular import
from py_spring_model.core.session_context_holder import SessionContextHolder
if SessionContextHolder.is_transaction_managed():
logger.warning("Commiting a transaction that is currently being managed by the outermost transaction is strongly discouraged...")
return
super().commit()
self.refresh_current_session_instances()
141 changes: 141 additions & 0 deletions py_spring_model/core/session_context_holder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
from contextvars import ContextVar
from enum import IntEnum
from functools import wraps
from typing import Any, Callable, ClassVar, Optional, ParamSpec, TypeVar

from py_spring_model.core.model import PySpringModel
from py_spring_model.core.py_spring_session import PySpringSession

class TransactionalDepth(IntEnum):
OUTERMOST = 1
ON_EXIT = 0


P = ParamSpec("P")
RT = TypeVar("RT")

def Transactional(func: Callable[P, RT]) -> Callable[P, RT]:
"""
Decorator for managing database transactions in a nested-safe manner.

This decorator ensures that:
- A new session is created only if there is no active session (i.e., outermost transaction).
- The session is committed, rolled back, and closed only by the outermost function.
- Nested transactional functions share the same session and do not interfere with the commit/rollback behavior.

Behavior Summary:
- If this function is the outermost @Transactional in the call stack:
- A new session is created.
- On success, the session is committed.
- On failure, the session is rolled back.
- The session is closed after execution.
- If this function is called within an existing transaction:
- The existing session is reused.
- No commit, rollback, or close is performed (delegated to the outermost function).

Example:
@Transactional
def outer_operation():
create_user()
update_account()

@Transactional
def create_user():
db.session.add(User(...)) # Uses same session as outer_operation

@Transactional
def update_account():
db.session.add(Account(...)) # Uses same session as outer_operation

Only outer_operation will commit or rollback.
If create_user() or update_account() raises an exception,
the whole transaction will be rolled back.
"""
@wraps(func)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> RT:
# Increment session depth and get session
session_depth = SessionContextHolder.enter_session()
session = SessionContextHolder.get_or_create_session()
try:
result = func(*args, **kwargs)
# Only commit at the outermost level (session_depth == 1)
if session_depth == TransactionalDepth.OUTERMOST.value:
session.commit()
return result
except Exception as error:
# Only rollback at the outermost level (session_depth == 1)
if session_depth == TransactionalDepth.OUTERMOST.value:
session.rollback()
raise error
finally:
# Decrement depth and clean up session if needed
SessionContextHolder.exit_session()
return wrapper

class SessionContextHolder:
"""
A context holder for the session with explicit depth tracking.
This is used to store the session in a context variable so that it can be accessed by the query service.
The depth counter ensures that only the outermost transaction manages commit/rollback operations.
"""
_session: ClassVar[ContextVar[Optional[PySpringSession]]] = ContextVar("session", default=None)
_session_depth: ClassVar[ContextVar[int]] = ContextVar("session_depth", default=0)

@classmethod
def get_or_create_session(cls) -> PySpringSession:
optional_session = cls._session.get()
if optional_session is None:
session = PySpringModel.create_session()
cls._session.set(session)
return session
return optional_session

@classmethod
def has_session(cls) -> bool:
return cls._session.get() is not None

@classmethod
def get_session_depth(cls) -> int:
"""Get the current session depth."""
return cls._session_depth.get()

@classmethod
def enter_session(cls) -> int:
"""
Enter a new session context and increment the depth counter.
Returns the new depth level.
"""
current_depth = cls._session_depth.get()
new_depth = current_depth + 1
cls._session_depth.set(new_depth)
return new_depth

@classmethod
def exit_session(cls) -> int:
"""
Exit the current session context and decrement the depth counter.
If depth reaches 0, clear the session.
Returns the new depth level.
"""
current_depth = cls._session_depth.get()
new_depth = max(0, current_depth - 1) # Prevent negative depth
cls._session_depth.set(new_depth)

# Clear session only when depth reaches 0 (outermost level)
if new_depth == TransactionalDepth.ON_EXIT.value:
cls.clear_session()

return new_depth

@classmethod
def clear_session(cls):
"""Clear the session and reset depth to 0."""
session = cls._session.get()
if session is not None:
session.close()
cls._session.set(None)
cls._session_depth.set(TransactionalDepth.ON_EXIT.value)

@classmethod
def is_transaction_managed(cls) -> bool:
return cls._session_depth.get() > TransactionalDepth.OUTERMOST.value
2 changes: 2 additions & 0 deletions py_spring_model/py_spring_model_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from py_spring_model.core.commons import ApplicationFileGroups, PySpringModelProperties
from py_spring_model.core.model import PySpringModel
from py_spring_model.py_spring_model_rest.controller.session_controller import SessionController
from py_spring_model.repository.repository_base import RepositoryBase
from py_spring_model.py_spring_model_rest import PySpringModelRestService
from py_spring_model.py_spring_model_rest.controller.py_spring_model_rest_controller import (
Expand Down Expand Up @@ -149,6 +150,7 @@ def provide_py_spring_model() -> EntityProvider:
return PySpringModelProvider(
rest_controller_classes=[
# PySpringModelRestController
SessionController
],
component_classes=[
PySpringModelRestService,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from typing import Awaitable, Callable
from fastapi import Request, Response
from py_spring_model.core.session_context_holder import SessionContextHolder
from py_spring_core import RestController


async def session_middleware(request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response:
"""
Middleware to ensure that the database session is properly cleaned up after each HTTP request.

This middleware works with context-based session management using ContextVar.
It guarantees that each request has its own isolated database session context, and
that any session stored in the context is properly closed after the request is handled.

It does NOT create or commit any transactions by itself. It is meant to be used
in combination with a decorator-based transaction manager (e.g., @Transactional)
which controls when to commit or rollback.

This middleware acts as a safety net:
- Ensures that the ContextVar-based session is cleared after each request.
- Prevents session leakage between requests in case of unexpected exceptions or unhandled paths.
- Complements nested transaction logic by guaranteeing session cleanup at request boundaries.

Use this middleware when:
- You are using context-local session handling (via contextvars).
- You want to ensure that long-lived or leaked sessions don't accumulate.
"""
try:
response = await call_next(request)
return response
finally:
SessionContextHolder.clear_session()

class SessionController(RestController):
def post_construct(self) -> None:
self.app.middleware("http")(session_middleware)


Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from sqlmodel.sql.expression import SelectOfScalar

from py_spring_model.core.model import PySpringModel
from py_spring_model.core.session_context_holder import SessionContextHolder, Transactional
from py_spring_model.repository.crud_repository import CrudRepository
from py_spring_model.py_spring_model_rest.service.curd_repository_implementation_service.method_query_builder import (
_MetodQueryBuilder,
Expand Down Expand Up @@ -150,14 +151,15 @@ def _get_sql_statement(
query = query.where(filter_condition_stack.pop())
return query

@Transactional
def _session_execute(self, statement: SelectOfScalar, is_one_result: bool) -> Any:
with PySpringModel.create_session() as session:
logger.debug(f"Executing query: \n{str(statement)}")
result = (
session.exec(statement).first()
if is_one_result
else session.exec(statement).fetchall()
)
session = SessionContextHolder.get_or_create_session()
logger.debug(f"Executing query: \n{str(statement)}")
result = (
session.exec(statement).first()
if is_one_result
else session.exec(statement).fetchall()
)
return result

def post_construct(self) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from py_spring_core import Component

from py_spring_model import PySpringModel
from py_spring_model.core.session_context_holder import SessionContextHolder, Transactional

ID = TypeVar("ID", int, UUID)
ModelT = TypeVar("ModelT", bound=PySpringModel)
Expand All @@ -20,44 +21,48 @@ class PySpringModelRestService(Component):
5. Updating an existing model,
6. Deleting a model by ID.
"""

def get_all_models(self) -> dict[str, type[PySpringModel]]:
return PySpringModel.get_model_lookup()

@Transactional
def get(self, model_type: Type[ModelT], id: ID) -> Optional[ModelT]:
with PySpringModel.create_managed_session() as session:
return session.get(model_type, id) # type: ignore
session = SessionContextHolder.get_or_create_session()
return session.get(model_type, id) # type: ignore

@Transactional
def get_all_by_ids(self, model_type: Type[ModelT], ids: list[ID]) -> list[ModelT]:
with PySpringModel.create_managed_session() as session:
return session.query(model_type).filter(model_type.id.in_(ids)).all() # type: ignore

session = SessionContextHolder.get_or_create_session()
return session.query(model_type).filter(model_type.id.in_(ids)).all() # type: ignore
@Transactional
def get_all(
self, model_type: Type[ModelT], limit: int, offset: int
) -> list[ModelT]:
with PySpringModel.create_managed_session() as session:
return session.query(model_type).offset(offset).limit(limit).all()
session = SessionContextHolder.get_or_create_session()
return session.query(model_type).offset(offset).limit(limit).all()

@Transactional
def create(self, model: ModelT) -> ModelT:
with PySpringModel.create_managed_session() as session:
session.add(model)
return model
session = SessionContextHolder.get_or_create_session()
session.add(model)
return model

@Transactional
def update(self, id: ID, model: ModelT) -> Optional[ModelT]:
with PySpringModel.create_managed_session() as session:
model_type = type(model)
primary_keys = PySpringModel.get_primary_key_columns(model_type)
optional_model = session.get(model_type, id) # type: ignore
if optional_model is None:
return

for key, value in model.model_dump().items():
if key in primary_keys:
continue
setattr(optional_model, key, value)
session.add(optional_model)
session = SessionContextHolder.get_or_create_session()
model_type = type(model)
primary_keys = PySpringModel.get_primary_key_columns(model_type)
optional_model = session.get(model_type, id) # type: ignore
if optional_model is None:
return

for key, value in model.model_dump().items():
if key in primary_keys:
continue
setattr(optional_model, key, value)
session.add(optional_model)

@Transactional
def delete(self, model_type: Type[ModelT], id: ID) -> None:
with PySpringModel.create_managed_session() as session:
session.query(model_type).filter(model_type.id == id).delete() # type: ignore
session.commit()
session = SessionContextHolder.get_or_create_session()
session.query(model_type).filter(model_type.id == id).delete() # type: ignore
session.commit()
Loading