1+ from contextvars import ContextVar
2+ from enum import IntEnum
3+ from functools import wraps
4+ from typing import Any , Callable , ClassVar , Optional , ParamSpec , TypeVar
5+
6+ from py_spring_model .core .model import PySpringModel
7+ from py_spring_model .core .py_spring_session import PySpringSession
8+
9+ class TransactionalDepth (IntEnum ):
10+ OUTERMOST = 1
11+ ON_EXIT = 0
12+
13+
14+ P = ParamSpec ("P" )
15+ RT = TypeVar ("RT" )
16+
17+ def Transactional (func : Callable [P , RT ]) -> Callable [P , RT ]:
18+ """
19+ Decorator for managing database transactions in a nested-safe manner.
20+
21+ This decorator ensures that:
22+ - A new session is created only if there is no active session (i.e., outermost transaction).
23+ - The session is committed, rolled back, and closed only by the outermost function.
24+ - Nested transactional functions share the same session and do not interfere with the commit/rollback behavior.
25+
26+ Behavior Summary:
27+ - If this function is the outermost @Transactional in the call stack:
28+ - A new session is created.
29+ - On success, the session is committed.
30+ - On failure, the session is rolled back.
31+ - The session is closed after execution.
32+ - If this function is called within an existing transaction:
33+ - The existing session is reused.
34+ - No commit, rollback, or close is performed (delegated to the outermost function).
35+
36+ Example:
37+ @Transactional
38+ def outer_operation():
39+ create_user()
40+ update_account()
41+
42+ @Transactional
43+ def create_user():
44+ db.session.add(User(...)) # Uses same session as outer_operation
45+
46+ @Transactional
47+ def update_account():
48+ db.session.add(Account(...)) # Uses same session as outer_operation
49+
50+ Only outer_operation will commit or rollback.
51+ If create_user() or update_account() raises an exception,
52+ the whole transaction will be rolled back.
53+ """
54+ @wraps (func )
55+ def wrapper (* args : P .args , ** kwargs : P .kwargs ) -> RT :
56+ # Increment session depth and get session
57+ session_depth = SessionContextHolder .enter_session ()
58+ session = SessionContextHolder .get_or_create_session ()
59+ try :
60+ result = func (* args , ** kwargs )
61+ # Only commit at the outermost level (session_depth == 1)
62+ if session_depth == TransactionalDepth .OUTERMOST .value :
63+ session .commit ()
64+ return result
65+ except Exception as error :
66+ # Only rollback at the outermost level (session_depth == 1)
67+ if session_depth == TransactionalDepth .OUTERMOST .value :
68+ session .rollback ()
69+ raise error
70+ finally :
71+ # Decrement depth and clean up session if needed
72+ SessionContextHolder .exit_session ()
73+ return wrapper
74+
75+ class SessionContextHolder :
76+ """
77+ A context holder for the session with explicit depth tracking.
78+ This is used to store the session in a context variable so that it can be accessed by the query service.
79+ The depth counter ensures that only the outermost transaction manages commit/rollback operations.
80+ """
81+ _session : ClassVar [ContextVar [Optional [PySpringSession ]]] = ContextVar ("session" , default = None )
82+ _session_depth : ClassVar [ContextVar [int ]] = ContextVar ("session_depth" , default = 0 )
83+
84+ @classmethod
85+ def get_or_create_session (cls ) -> PySpringSession :
86+ optional_session = cls ._session .get ()
87+ if optional_session is None :
88+ session = PySpringModel .create_session ()
89+ cls ._session .set (session )
90+ return session
91+ return optional_session
92+
93+ @classmethod
94+ def has_session (cls ) -> bool :
95+ return cls ._session .get () is not None
96+
97+ @classmethod
98+ def get_session_depth (cls ) -> int :
99+ """Get the current session depth."""
100+ return cls ._session_depth .get ()
101+
102+ @classmethod
103+ def enter_session (cls ) -> int :
104+ """
105+ Enter a new session context and increment the depth counter.
106+ Returns the new depth level.
107+ """
108+ current_depth = cls ._session_depth .get ()
109+ new_depth = current_depth + 1
110+ cls ._session_depth .set (new_depth )
111+ return new_depth
112+
113+ @classmethod
114+ def exit_session (cls ) -> int :
115+ """
116+ Exit the current session context and decrement the depth counter.
117+ If depth reaches 0, clear the session.
118+ Returns the new depth level.
119+ """
120+ current_depth = cls ._session_depth .get ()
121+ new_depth = max (0 , current_depth - 1 ) # Prevent negative depth
122+ cls ._session_depth .set (new_depth )
123+
124+ # Clear session only when depth reaches 0 (outermost level)
125+ if new_depth == TransactionalDepth .ON_EXIT .value :
126+ cls .clear_session ()
127+
128+ return new_depth
129+
130+ @classmethod
131+ def clear_session (cls ):
132+ """Clear the session and reset depth to 0."""
133+ session = cls ._session .get ()
134+ if session is not None :
135+ session .close ()
136+ cls ._session .set (None )
137+ cls ._session_depth .set (TransactionalDepth .ON_EXIT .value )
138+
139+ @classmethod
140+ def is_transaction_managed (cls ) -> bool :
141+ return cls ._session_depth .get () > TransactionalDepth .OUTERMOST .value
0 commit comments