Skip to content

Commit 4699eaf

Browse files
feat: Enhanced @transactional decorator with smart session management
- Implement outermost transaction detection using SessionContextHolder.has_session() - Session lifecycle (commit/rollback/close) only managed by outermost @transactional - Nested @transactional methods reuse existing session without interference - Prevent premature session closure in nested transaction scenarios - Maintain transaction integrity across multiple decorated method calls
1 parent 6e2631e commit 4699eaf

File tree

1 file changed

+21
-9
lines changed

1 file changed

+21
-9
lines changed

py_spring_model/core/session_context_holder.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,36 @@
11
from contextvars import ContextVar
22
from functools import wraps
3-
from typing import ClassVar, Optional
3+
from typing import Any, Callable, ClassVar, Optional
44
from sqlmodel import Session
55

66
from py_spring_model.core.model import PySpringModel
77

8-
def Transactional(func):
8+
def Transactional(func: Callable[..., Any]) -> Callable[..., Any]:
99
"""
1010
A decorator that wraps a function and commits the session if the function is successful.
1111
If the function raises an exception, the session is rolled back.
1212
The session is then closed.
13+
If the function is the outermost function, the session is committed.
14+
If the function is not the outermost function, the session is not committed.
15+
If the function is not the outermost function, the session is not rolled back.
16+
If the function is not the outermost function, the session is not closed.
1317
"""
1418
@wraps(func)
1519
def wrapper(*args, **kwargs):
20+
is_outermost = not SessionContextHolder.has_session()
1621
session = SessionContextHolder.get_or_create_session()
1722
try:
1823
result = func(*args, **kwargs)
19-
session.commit()
24+
if is_outermost:
25+
session.commit()
2026
return result
2127
except Exception as error:
22-
session.rollback()
28+
if is_outermost:
29+
session.rollback()
2330
raise error
2431
finally:
25-
SessionContextHolder.clear_session()
32+
if is_outermost:
33+
SessionContextHolder.clear_session()
2634
return wrapper
2735

2836
class SessionContextHolder:
@@ -41,10 +49,14 @@ def get_or_create_session(cls) -> Session:
4149
cls._session.set(session)
4250
return session
4351
return optional_session
52+
53+
@classmethod
54+
def has_session(cls) -> bool:
55+
return cls._session.get() is not None
56+
4457
@classmethod
4558
def clear_session(cls):
46-
optional_session = cls._session.get()
47-
if optional_session is None:
48-
return
49-
optional_session.close()
59+
session = cls._session.get()
60+
if session is not None:
61+
session.close()
5062
cls._session.set(None)

0 commit comments

Comments
 (0)