1
+ from contextvars import ContextVar
2
+ from enum import IntEnum
3
+ from functools import wraps
4
+ from typing import Any , Callable , ClassVar , Optional , ParamSpec , TypeVar
5
+
6
+ from py_spring_model .core .model import PySpringModel
7
+ from py_spring_model .core .py_spring_session import PySpringSession
8
+
9
+ class TransactionalDepth (IntEnum ):
10
+ OUTERMOST = 1
11
+ ON_EXIT = 0
12
+
13
+
14
+ P = ParamSpec ("P" )
15
+ RT = TypeVar ("RT" )
16
+
17
+ def Transactional (func : Callable [P , RT ]) -> Callable [P , RT ]:
18
+ """
19
+ Decorator for managing database transactions in a nested-safe manner.
20
+
21
+ This decorator ensures that:
22
+ - A new session is created only if there is no active session (i.e., outermost transaction).
23
+ - The session is committed, rolled back, and closed only by the outermost function.
24
+ - Nested transactional functions share the same session and do not interfere with the commit/rollback behavior.
25
+
26
+ Behavior Summary:
27
+ - If this function is the outermost @Transactional in the call stack:
28
+ - A new session is created.
29
+ - On success, the session is committed.
30
+ - On failure, the session is rolled back.
31
+ - The session is closed after execution.
32
+ - If this function is called within an existing transaction:
33
+ - The existing session is reused.
34
+ - No commit, rollback, or close is performed (delegated to the outermost function).
35
+
36
+ Example:
37
+ @Transactional
38
+ def outer_operation():
39
+ create_user()
40
+ update_account()
41
+
42
+ @Transactional
43
+ def create_user():
44
+ db.session.add(User(...)) # Uses same session as outer_operation
45
+
46
+ @Transactional
47
+ def update_account():
48
+ db.session.add(Account(...)) # Uses same session as outer_operation
49
+
50
+ Only outer_operation will commit or rollback.
51
+ If create_user() or update_account() raises an exception,
52
+ the whole transaction will be rolled back.
53
+ """
54
+ @wraps (func )
55
+ def wrapper (* args : P .args , ** kwargs : P .kwargs ) -> RT :
56
+ # Increment session depth and get session
57
+ session_depth = SessionContextHolder .enter_session ()
58
+ session = SessionContextHolder .get_or_create_session ()
59
+ try :
60
+ result = func (* args , ** kwargs )
61
+ # Only commit at the outermost level (session_depth == 1)
62
+ if session_depth == TransactionalDepth .OUTERMOST .value :
63
+ session .commit ()
64
+ return result
65
+ except Exception as error :
66
+ # Only rollback at the outermost level (session_depth == 1)
67
+ if session_depth == TransactionalDepth .OUTERMOST .value :
68
+ session .rollback ()
69
+ raise error
70
+ finally :
71
+ # Decrement depth and clean up session if needed
72
+ SessionContextHolder .exit_session ()
73
+ return wrapper
74
+
75
+ class SessionContextHolder :
76
+ """
77
+ A context holder for the session with explicit depth tracking.
78
+ This is used to store the session in a context variable so that it can be accessed by the query service.
79
+ The depth counter ensures that only the outermost transaction manages commit/rollback operations.
80
+ """
81
+ _session : ClassVar [ContextVar [Optional [PySpringSession ]]] = ContextVar ("session" , default = None )
82
+ _session_depth : ClassVar [ContextVar [int ]] = ContextVar ("session_depth" , default = 0 )
83
+
84
+ @classmethod
85
+ def get_or_create_session (cls ) -> PySpringSession :
86
+ optional_session = cls ._session .get ()
87
+ if optional_session is None :
88
+ session = PySpringModel .create_session ()
89
+ cls ._session .set (session )
90
+ return session
91
+ return optional_session
92
+
93
+ @classmethod
94
+ def has_session (cls ) -> bool :
95
+ return cls ._session .get () is not None
96
+
97
+ @classmethod
98
+ def get_session_depth (cls ) -> int :
99
+ """Get the current session depth."""
100
+ return cls ._session_depth .get ()
101
+
102
+ @classmethod
103
+ def enter_session (cls ) -> int :
104
+ """
105
+ Enter a new session context and increment the depth counter.
106
+ Returns the new depth level.
107
+ """
108
+ current_depth = cls ._session_depth .get ()
109
+ new_depth = current_depth + 1
110
+ cls ._session_depth .set (new_depth )
111
+ return new_depth
112
+
113
+ @classmethod
114
+ def exit_session (cls ) -> int :
115
+ """
116
+ Exit the current session context and decrement the depth counter.
117
+ If depth reaches 0, clear the session.
118
+ Returns the new depth level.
119
+ """
120
+ current_depth = cls ._session_depth .get ()
121
+ new_depth = max (0 , current_depth - 1 ) # Prevent negative depth
122
+ cls ._session_depth .set (new_depth )
123
+
124
+ # Clear session only when depth reaches 0 (outermost level)
125
+ if new_depth == TransactionalDepth .ON_EXIT .value :
126
+ cls .clear_session ()
127
+
128
+ return new_depth
129
+
130
+ @classmethod
131
+ def clear_session (cls ):
132
+ """Clear the session and reset depth to 0."""
133
+ session = cls ._session .get ()
134
+ if session is not None :
135
+ session .close ()
136
+ cls ._session .set (None )
137
+ cls ._session_depth .set (TransactionalDepth .ON_EXIT .value )
138
+
139
+ @classmethod
140
+ def is_transaction_managed (cls ) -> bool :
141
+ return cls ._session_depth .get () > TransactionalDepth .OUTERMOST .value
0 commit comments