Skip to content

Commit c761f5d

Browse files
authored
Implement Context-based Transactional Session Management (#10)
1 parent da4ccde commit c761f5d

14 files changed

+841
-118
lines changed

py_spring_model/__init__.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,21 @@
11
from py_spring_model.core.model import PySpringModel, Field
2+
from py_spring_model.core.session_context_holder import SessionContextHolder
23
from py_spring_model.py_spring_model_provider import provide_py_spring_model
34
from py_spring_model.repository.crud_repository import CrudRepository
45
from py_spring_model.repository.repository_base import RepositoryBase
56
from py_spring_model.py_spring_model_rest.service.curd_repository_implementation_service.crud_repository_implementation_service import SkipAutoImplmentation
6-
from py_spring_model.py_spring_model_rest.service.query_service.query import Query
7+
from py_spring_model.py_spring_model_rest.service.query_service.query import Query
8+
9+
10+
__all__ = [
11+
"PySpringModel",
12+
"Field",
13+
"SessionContextHolder",
14+
"provide_py_spring_model",
15+
"CrudRepository",
16+
"RepositoryBase",
17+
"SkipAutoImplmentation",
18+
"Query",
19+
]
20+
21+
__version__ = "0.1.0"

py_spring_model/core/model.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
import contextlib
2-
from typing_extensions import Self
3-
from typing import ClassVar, Iterator, Optional, Type
2+
from typing import ClassVar, Iterator, Optional, Type, TypeVar
43
from loguru import logger
54
from sqlalchemy import Engine, MetaData
65
from sqlalchemy.engine.base import Connection
76
from sqlmodel import SQLModel, Field
87

98
from py_spring_model.core.py_spring_session import PySpringSession
109

10+
T = TypeVar("T", bound="PySpringModel")
11+
1112

1213
class PySpringModel(SQLModel):
1314
"""
@@ -81,7 +82,7 @@ def get_model_lookup(cls) -> dict[str, type["PySpringModel"]]:
8182
raise ValueError("[MODEL_LOOKUP NOT SET] Model lookup is not set")
8283
return {str(_model.__tablename__): _model for _model in cls._models}
8384

84-
def clone(self) -> Self:
85+
def clone(self: T) -> T:
8586
return self.model_validate_json(self.model_dump_json())
8687

8788
@classmethod
@@ -105,11 +106,7 @@ def create_managed_session(cls, should_commit: bool = True) -> Iterator[PySpring
105106
logger.debug("[MANAGED SESSION COMMIT] Session committing...")
106107
if should_commit:
107108
session.commit()
108-
logger.debug(
109-
"[MANAGED SESSION COMMIT] Session committed, refreshing instances..."
110-
)
111-
session.refresh_current_session_instances()
112-
logger.success("[MANAGED SESSION COMMIT] Session committed.")
109+
logger.success("[MANAGED SESSION COMMIT] Session committed.")
113110
except Exception as error:
114111
logger.error(error)
115112
logger.error("[MANAGED SESSION ROLLBACK] Session rolling back...")

py_spring_model/core/py_spring_session.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import Iterable
22

3+
from loguru import logger
34
from sqlmodel import Session, SQLModel
45

56

@@ -33,3 +34,12 @@ def add_all(self, instances: Iterable[SQLModel]) -> None:
3334
def refresh_current_session_instances(self) -> None:
3435
for instance in self.current_session_instance:
3536
self.refresh(instance)
37+
38+
def commit(self) -> None:
39+
# Import here to avoid circular import
40+
from py_spring_model.core.session_context_holder import SessionContextHolder
41+
if SessionContextHolder.is_transaction_managed():
42+
logger.warning("Commiting a transaction that is currently being managed by the outermost transaction is strongly discouraged...")
43+
return
44+
super().commit()
45+
self.refresh_current_session_instances()
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
from contextvars import ContextVar
2+
from enum import IntEnum
3+
from functools import wraps
4+
from typing import Any, Callable, ClassVar, Optional, ParamSpec, TypeVar
5+
6+
from py_spring_model.core.model import PySpringModel
7+
from py_spring_model.core.py_spring_session import PySpringSession
8+
9+
class TransactionalDepth(IntEnum):
10+
OUTERMOST = 1
11+
ON_EXIT = 0
12+
13+
14+
P = ParamSpec("P")
15+
RT = TypeVar("RT")
16+
17+
def Transactional(func: Callable[P, RT]) -> Callable[P, RT]:
18+
"""
19+
Decorator for managing database transactions in a nested-safe manner.
20+
21+
This decorator ensures that:
22+
- A new session is created only if there is no active session (i.e., outermost transaction).
23+
- The session is committed, rolled back, and closed only by the outermost function.
24+
- Nested transactional functions share the same session and do not interfere with the commit/rollback behavior.
25+
26+
Behavior Summary:
27+
- If this function is the outermost @Transactional in the call stack:
28+
- A new session is created.
29+
- On success, the session is committed.
30+
- On failure, the session is rolled back.
31+
- The session is closed after execution.
32+
- If this function is called within an existing transaction:
33+
- The existing session is reused.
34+
- No commit, rollback, or close is performed (delegated to the outermost function).
35+
36+
Example:
37+
@Transactional
38+
def outer_operation():
39+
create_user()
40+
update_account()
41+
42+
@Transactional
43+
def create_user():
44+
db.session.add(User(...)) # Uses same session as outer_operation
45+
46+
@Transactional
47+
def update_account():
48+
db.session.add(Account(...)) # Uses same session as outer_operation
49+
50+
Only outer_operation will commit or rollback.
51+
If create_user() or update_account() raises an exception,
52+
the whole transaction will be rolled back.
53+
"""
54+
@wraps(func)
55+
def wrapper(*args: P.args, **kwargs: P.kwargs) -> RT:
56+
# Increment session depth and get session
57+
session_depth = SessionContextHolder.enter_session()
58+
session = SessionContextHolder.get_or_create_session()
59+
try:
60+
result = func(*args, **kwargs)
61+
# Only commit at the outermost level (session_depth == 1)
62+
if session_depth == TransactionalDepth.OUTERMOST.value:
63+
session.commit()
64+
return result
65+
except Exception as error:
66+
# Only rollback at the outermost level (session_depth == 1)
67+
if session_depth == TransactionalDepth.OUTERMOST.value:
68+
session.rollback()
69+
raise error
70+
finally:
71+
# Decrement depth and clean up session if needed
72+
SessionContextHolder.exit_session()
73+
return wrapper
74+
75+
class SessionContextHolder:
76+
"""
77+
A context holder for the session with explicit depth tracking.
78+
This is used to store the session in a context variable so that it can be accessed by the query service.
79+
The depth counter ensures that only the outermost transaction manages commit/rollback operations.
80+
"""
81+
_session: ClassVar[ContextVar[Optional[PySpringSession]]] = ContextVar("session", default=None)
82+
_session_depth: ClassVar[ContextVar[int]] = ContextVar("session_depth", default=0)
83+
84+
@classmethod
85+
def get_or_create_session(cls) -> PySpringSession:
86+
optional_session = cls._session.get()
87+
if optional_session is None:
88+
session = PySpringModel.create_session()
89+
cls._session.set(session)
90+
return session
91+
return optional_session
92+
93+
@classmethod
94+
def has_session(cls) -> bool:
95+
return cls._session.get() is not None
96+
97+
@classmethod
98+
def get_session_depth(cls) -> int:
99+
"""Get the current session depth."""
100+
return cls._session_depth.get()
101+
102+
@classmethod
103+
def enter_session(cls) -> int:
104+
"""
105+
Enter a new session context and increment the depth counter.
106+
Returns the new depth level.
107+
"""
108+
current_depth = cls._session_depth.get()
109+
new_depth = current_depth + 1
110+
cls._session_depth.set(new_depth)
111+
return new_depth
112+
113+
@classmethod
114+
def exit_session(cls) -> int:
115+
"""
116+
Exit the current session context and decrement the depth counter.
117+
If depth reaches 0, clear the session.
118+
Returns the new depth level.
119+
"""
120+
current_depth = cls._session_depth.get()
121+
new_depth = max(0, current_depth - 1) # Prevent negative depth
122+
cls._session_depth.set(new_depth)
123+
124+
# Clear session only when depth reaches 0 (outermost level)
125+
if new_depth == TransactionalDepth.ON_EXIT.value:
126+
cls.clear_session()
127+
128+
return new_depth
129+
130+
@classmethod
131+
def clear_session(cls):
132+
"""Clear the session and reset depth to 0."""
133+
session = cls._session.get()
134+
if session is not None:
135+
session.close()
136+
cls._session.set(None)
137+
cls._session_depth.set(TransactionalDepth.ON_EXIT.value)
138+
139+
@classmethod
140+
def is_transaction_managed(cls) -> bool:
141+
return cls._session_depth.get() > TransactionalDepth.OUTERMOST.value

py_spring_model/py_spring_model_provider.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
from py_spring_model.core.commons import ApplicationFileGroups, PySpringModelProperties
1212
from py_spring_model.core.model import PySpringModel
13+
from py_spring_model.py_spring_model_rest.controller.session_controller import SessionController
1314
from py_spring_model.repository.repository_base import RepositoryBase
1415
from py_spring_model.py_spring_model_rest import PySpringModelRestService
1516
from py_spring_model.py_spring_model_rest.controller.py_spring_model_rest_controller import (
@@ -149,6 +150,7 @@ def provide_py_spring_model() -> EntityProvider:
149150
return PySpringModelProvider(
150151
rest_controller_classes=[
151152
# PySpringModelRestController
153+
SessionController
152154
],
153155
component_classes=[
154156
PySpringModelRestService,
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
from typing import Awaitable, Callable
2+
from fastapi import Request, Response
3+
from py_spring_model.core.session_context_holder import SessionContextHolder
4+
from py_spring_core import RestController
5+
6+
7+
async def session_middleware(request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response:
8+
"""
9+
Middleware to ensure that the database session is properly cleaned up after each HTTP request.
10+
11+
This middleware works with context-based session management using ContextVar.
12+
It guarantees that each request has its own isolated database session context, and
13+
that any session stored in the context is properly closed after the request is handled.
14+
15+
It does NOT create or commit any transactions by itself. It is meant to be used
16+
in combination with a decorator-based transaction manager (e.g., @Transactional)
17+
which controls when to commit or rollback.
18+
19+
This middleware acts as a safety net:
20+
- Ensures that the ContextVar-based session is cleared after each request.
21+
- Prevents session leakage between requests in case of unexpected exceptions or unhandled paths.
22+
- Complements nested transaction logic by guaranteeing session cleanup at request boundaries.
23+
24+
Use this middleware when:
25+
- You are using context-local session handling (via contextvars).
26+
- You want to ensure that long-lived or leaked sessions don't accumulate.
27+
"""
28+
try:
29+
response = await call_next(request)
30+
return response
31+
finally:
32+
SessionContextHolder.clear_session()
33+
34+
class SessionController(RestController):
35+
def post_construct(self) -> None:
36+
self.app.middleware("http")(session_middleware)
37+
38+

py_spring_model/py_spring_model_rest/service/curd_repository_implementation_service/crud_repository_implementation_service.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from sqlmodel.sql.expression import SelectOfScalar
2020

2121
from py_spring_model.core.model import PySpringModel
22+
from py_spring_model.core.session_context_holder import SessionContextHolder, Transactional
2223
from py_spring_model.repository.crud_repository import CrudRepository
2324
from py_spring_model.py_spring_model_rest.service.curd_repository_implementation_service.method_query_builder import (
2425
_MetodQueryBuilder,
@@ -150,14 +151,15 @@ def _get_sql_statement(
150151
query = query.where(filter_condition_stack.pop())
151152
return query
152153

154+
@Transactional
153155
def _session_execute(self, statement: SelectOfScalar, is_one_result: bool) -> Any:
154-
with PySpringModel.create_session() as session:
155-
logger.debug(f"Executing query: \n{str(statement)}")
156-
result = (
157-
session.exec(statement).first()
158-
if is_one_result
159-
else session.exec(statement).fetchall()
160-
)
156+
session = SessionContextHolder.get_or_create_session()
157+
logger.debug(f"Executing query: \n{str(statement)}")
158+
result = (
159+
session.exec(statement).first()
160+
if is_one_result
161+
else session.exec(statement).fetchall()
162+
)
161163
return result
162164

163165
def post_construct(self) -> None:

py_spring_model/py_spring_model_rest/service/py_spring_model_rest_service.py

Lines changed: 31 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from py_spring_core import Component
55

66
from py_spring_model import PySpringModel
7+
from py_spring_model.core.session_context_holder import SessionContextHolder, Transactional
78

89
ID = TypeVar("ID", int, UUID)
910
ModelT = TypeVar("ModelT", bound=PySpringModel)
@@ -20,44 +21,48 @@ class PySpringModelRestService(Component):
2021
5. Updating an existing model,
2122
6. Deleting a model by ID.
2223
"""
23-
2424
def get_all_models(self) -> dict[str, type[PySpringModel]]:
2525
return PySpringModel.get_model_lookup()
2626

27+
@Transactional
2728
def get(self, model_type: Type[ModelT], id: ID) -> Optional[ModelT]:
28-
with PySpringModel.create_managed_session() as session:
29-
return session.get(model_type, id) # type: ignore
29+
session = SessionContextHolder.get_or_create_session()
30+
return session.get(model_type, id) # type: ignore
3031

32+
@Transactional
3133
def get_all_by_ids(self, model_type: Type[ModelT], ids: list[ID]) -> list[ModelT]:
32-
with PySpringModel.create_managed_session() as session:
33-
return session.query(model_type).filter(model_type.id.in_(ids)).all() # type: ignore
34-
34+
session = SessionContextHolder.get_or_create_session()
35+
return session.query(model_type).filter(model_type.id.in_(ids)).all() # type: ignore
36+
@Transactional
3537
def get_all(
3638
self, model_type: Type[ModelT], limit: int, offset: int
3739
) -> list[ModelT]:
38-
with PySpringModel.create_managed_session() as session:
39-
return session.query(model_type).offset(offset).limit(limit).all()
40+
session = SessionContextHolder.get_or_create_session()
41+
return session.query(model_type).offset(offset).limit(limit).all()
4042

43+
@Transactional
4144
def create(self, model: ModelT) -> ModelT:
42-
with PySpringModel.create_managed_session() as session:
43-
session.add(model)
44-
return model
45+
session = SessionContextHolder.get_or_create_session()
46+
session.add(model)
47+
return model
4548

49+
@Transactional
4650
def update(self, id: ID, model: ModelT) -> Optional[ModelT]:
47-
with PySpringModel.create_managed_session() as session:
48-
model_type = type(model)
49-
primary_keys = PySpringModel.get_primary_key_columns(model_type)
50-
optional_model = session.get(model_type, id) # type: ignore
51-
if optional_model is None:
52-
return
53-
54-
for key, value in model.model_dump().items():
55-
if key in primary_keys:
56-
continue
57-
setattr(optional_model, key, value)
58-
session.add(optional_model)
51+
session = SessionContextHolder.get_or_create_session()
52+
model_type = type(model)
53+
primary_keys = PySpringModel.get_primary_key_columns(model_type)
54+
optional_model = session.get(model_type, id) # type: ignore
55+
if optional_model is None:
56+
return
57+
58+
for key, value in model.model_dump().items():
59+
if key in primary_keys:
60+
continue
61+
setattr(optional_model, key, value)
62+
session.add(optional_model)
5963

64+
@Transactional
6065
def delete(self, model_type: Type[ModelT], id: ID) -> None:
61-
with PySpringModel.create_managed_session() as session:
62-
session.query(model_type).filter(model_type.id == id).delete() # type: ignore
63-
session.commit()
66+
session = SessionContextHolder.get_or_create_session()
67+
session.query(model_type).filter(model_type.id == id).delete() # type: ignore
68+
session.commit()

0 commit comments

Comments
 (0)