1
+ import pytest
2
+ from unittest .mock import patch
3
+ from sqlalchemy import create_engine , text
4
+ from sqlmodel import Field , SQLModel
5
+
6
+ from py_spring_model import PySpringModel
7
+ from py_spring_model .core .session_context_holder import SessionContextHolder , Transactional
8
+
9
+
10
+ class TransactionalTestUser (PySpringModel , table = True ):
11
+ """Test model for transactional operations"""
12
+ id : int = Field (default = None , primary_key = True )
13
+ name : str
14
+ email : str
15
+ age : int = Field (default = 0 )
16
+
17
+
18
+ class TestTransactionalDecorator :
19
+ """Test suite for the @Transactional decorator"""
20
+
21
+ @pytest .fixture (autouse = True )
22
+ def setup_and_teardown (self ):
23
+ """Set up test environment with in-memory SQLite database"""
24
+ self .engine = create_engine ("sqlite:///:memory:" , echo = False )
25
+ PySpringModel .set_engine (self .engine )
26
+ PySpringModel .set_metadata (SQLModel .metadata )
27
+ PySpringModel .set_models ([TransactionalTestUser ])
28
+
29
+ # Clear any existing session
30
+ SessionContextHolder .clear_session ()
31
+
32
+ SQLModel .metadata .create_all (self .engine )
33
+
34
+ def teardown_method (self ):
35
+ """Tear down test environment"""
36
+ SQLModel .metadata .drop_all (self .engine )
37
+ SessionContextHolder .clear_session ()
38
+ PySpringModel ._engine = None
39
+ PySpringModel ._metadata = None
40
+ PySpringModel ._connection = None
41
+
42
+ def test_single_transactional_success (self ):
43
+ """Test that a single @Transactional function commits successfully"""
44
+
45
+ @Transactional
46
+ def create_user ():
47
+ session = SessionContextHolder .get_or_create_session ()
48
+ user = TransactionalTestUser (name = "John Doe" , email = "john@example.com" , age = 30 )
49
+ session .add (user )
50
+ session .flush () # To get the ID
51
+ return user
52
+
53
+ # Execute the transactional function
54
+ result = create_user ()
55
+
56
+ # Verify the user was created and committed
57
+ assert result .name == "John Doe"
58
+ assert result .email == "john@example.com"
59
+ assert result .age == 30
60
+
61
+ # Verify session is cleared after transaction
62
+ assert not SessionContextHolder .has_session ()
63
+
64
+ # Verify data persisted to database
65
+ with PySpringModel .create_managed_session () as session :
66
+ users = session .execute (text ("SELECT * FROM transactionaltestuser" )).fetchall ()
67
+ assert len (users ) == 1
68
+ assert users [0 ].name == "John Doe"
69
+
70
+ def test_single_transactional_rollback (self ):
71
+ """Test that a single @Transactional function rolls back on exception"""
72
+
73
+ @Transactional
74
+ def create_user_with_error ():
75
+ session = SessionContextHolder .get_or_create_session ()
76
+ user = TransactionalTestUser (name = "Jane Doe" , email = "jane@example.com" , age = 25 )
77
+ session .add (user )
78
+ session .flush ()
79
+ raise ValueError ("Simulated error" )
80
+
81
+ # Execute the transactional function and expect exception
82
+ with pytest .raises (ValueError , match = "Simulated error" ):
83
+ create_user_with_error ()
84
+
85
+ # Verify session is cleared after rollback
86
+ assert not SessionContextHolder .has_session ()
87
+
88
+ # Verify no data persisted to database
89
+ with PySpringModel .create_managed_session () as session :
90
+ users = session .execute (text ("SELECT * FROM transactionaltestuser" )).fetchall ()
91
+ assert len (users ) == 0
92
+
93
+ def test_nested_transactional_success (self ):
94
+ """Test that nested @Transactional functions share the same session and commit at top level"""
95
+
96
+ @Transactional
97
+ def create_user (name : str , email : str ):
98
+ session = SessionContextHolder .get_or_create_session ()
99
+ user = TransactionalTestUser (name = name , email = email , age = 30 )
100
+ session .add (user )
101
+ session .flush ()
102
+ return user
103
+
104
+ @Transactional
105
+ def update_user_age (user_id : int , new_age : int ):
106
+ session = SessionContextHolder .get_or_create_session ()
107
+ session .execute (text (f"UPDATE transactionaltestuser SET age = { new_age } WHERE id = { user_id } " ))
108
+
109
+ @Transactional
110
+ def create_and_update_user ():
111
+ # This should share the same session with nested calls
112
+ user = create_user ("Alice Smith" , "alice@example.com" )
113
+ update_user_age (user .id , 35 )
114
+ return user
115
+
116
+ # Execute the outer transactional function
117
+ result = create_and_update_user ()
118
+
119
+ # Verify session is cleared after transaction
120
+ assert not SessionContextHolder .has_session ()
121
+
122
+ # Verify both operations were committed
123
+ with PySpringModel .create_managed_session () as session :
124
+ users = session .execute (text ("SELECT * FROM transactionaltestuser" )).fetchall ()
125
+ assert len (users ) == 1
126
+ assert users [0 ].name == "Alice Smith"
127
+ assert users [0 ].age == 35 # Updated by nested function
128
+
129
+ def test_nested_transactional_rollback_from_inner (self ):
130
+ """Test that exception in nested @Transactional causes rollback at top level"""
131
+
132
+ @Transactional
133
+ def create_user (name : str , email : str ):
134
+ session = SessionContextHolder .get_or_create_session ()
135
+ user = TransactionalTestUser (name = name , email = email , age = 30 )
136
+ session .add (user )
137
+ session .flush ()
138
+ return user
139
+
140
+ @Transactional
141
+ def update_user_with_error (user_id : int ):
142
+ session = SessionContextHolder .get_or_create_session ()
143
+ session .execute (text (f"UPDATE transactionaltestuser SET age = 40 WHERE id = { user_id } " ))
144
+ raise RuntimeError ("Update failed" )
145
+
146
+ @Transactional
147
+ def create_and_update_user_with_error ():
148
+ user = create_user ("Bob Johnson" , "bob@example.com" )
149
+ update_user_with_error (user .id ) # This will raise an exception
150
+ return user
151
+
152
+ # Execute and expect exception
153
+ with pytest .raises (RuntimeError , match = "Update failed" ):
154
+ create_and_update_user_with_error ()
155
+
156
+ # Verify session is cleared after rollback
157
+ assert not SessionContextHolder .has_session ()
158
+
159
+ # Verify no data persisted (everything rolled back)
160
+ with PySpringModel .create_managed_session () as session :
161
+ users = session .execute (text ("SELECT * FROM transactionaltestuser" )).fetchall ()
162
+ assert len (users ) == 0
163
+
164
+ def test_nested_transactional_rollback_from_outer (self ):
165
+ """Test that exception in outer @Transactional causes rollback after nested calls"""
166
+
167
+ @Transactional
168
+ def create_user (name : str , email : str ):
169
+ session = SessionContextHolder .get_or_create_session ()
170
+ user = TransactionalTestUser (name = name , email = email , age = 30 )
171
+ session .add (user )
172
+ session .flush ()
173
+ return user
174
+
175
+ @Transactional
176
+ def update_user_age (user_id : int , new_age : int ):
177
+ session = SessionContextHolder .get_or_create_session ()
178
+ session .execute (text (f"UPDATE transactionaltestuser SET age = { new_age } WHERE id = { user_id } " ))
179
+
180
+ @Transactional
181
+ def create_update_and_fail ():
182
+ user = create_user ("Charlie Brown" , "charlie@example.com" )
183
+ update_user_age (user .id , 45 )
184
+ # Both nested operations succeeded, but outer fails
185
+ raise Exception ("Outer operation failed" )
186
+
187
+ # Execute and expect exception
188
+ with pytest .raises (Exception , match = "Outer operation failed" ):
189
+ create_update_and_fail ()
190
+
191
+ # Verify session is cleared after rollback
192
+ assert not SessionContextHolder .has_session ()
193
+
194
+ # Verify no data persisted (everything rolled back)
195
+ with PySpringModel .create_managed_session () as session :
196
+ users = session .execute (text ("SELECT * FROM transactionaltestuser" )).fetchall ()
197
+ assert len (users ) == 0
198
+
199
+ def test_session_sharing_across_nested_transactions (self ):
200
+ """Test that nested @Transactional functions share the same session instance"""
201
+
202
+ captured_sessions = []
203
+
204
+ @Transactional
205
+ def inner_function ():
206
+ session = SessionContextHolder .get_or_create_session ()
207
+ captured_sessions .append (session )
208
+ user = TransactionalTestUser (name = "Inner User" , email = "inner@example.com" , age = 25 )
209
+ session .add (user )
210
+ session .flush ()
211
+ return session
212
+
213
+ @Transactional
214
+ def middle_function ():
215
+ session = SessionContextHolder .get_or_create_session ()
216
+ captured_sessions .append (session )
217
+ inner_session = inner_function ()
218
+ return session , inner_session
219
+
220
+ @Transactional
221
+ def outer_function ():
222
+ session = SessionContextHolder .get_or_create_session ()
223
+ captured_sessions .append (session )
224
+ middle_session , inner_session = middle_function ()
225
+ return session , middle_session , inner_session
226
+
227
+ # Execute nested transactions
228
+ outer_session , middle_session , inner_session = outer_function ()
229
+
230
+ # Verify all functions used the same session instance
231
+ assert len (captured_sessions ) == 3
232
+ assert captured_sessions [0 ] is captured_sessions [1 ] # outer and middle
233
+ assert captured_sessions [1 ] is captured_sessions [2 ] # middle and inner
234
+ assert outer_session is middle_session is inner_session
235
+
236
+ # Verify session is cleared after transaction
237
+ assert not SessionContextHolder .has_session ()
238
+
239
+ def test_transactional_session_context_isolation (self ):
240
+ """Test that @Transactional properly isolates session context"""
241
+
242
+ @Transactional
243
+ def first_transaction ():
244
+ session = SessionContextHolder .get_or_create_session ()
245
+ user1 = TransactionalTestUser (name = "User 1" , email = "user1@example.com" , age = 30 )
246
+ session .add (user1 )
247
+ session .flush ()
248
+ return user1 .id
249
+
250
+ @Transactional
251
+ def second_transaction ():
252
+ session = SessionContextHolder .get_or_create_session ()
253
+ user2 = TransactionalTestUser (name = "User 2" , email = "user2@example.com" , age = 35 )
254
+ session .add (user2 )
255
+ session .flush ()
256
+ return user2 .id
257
+
258
+ # Execute separate transactions
259
+ first_transaction ()
260
+ second_transaction ()
261
+
262
+ # Verify sessions were properly isolated
263
+ assert not SessionContextHolder .has_session ()
264
+
265
+ # Verify both transactions committed separately
266
+ with PySpringModel .create_managed_session () as session :
267
+ users = session .execute (text ("SELECT * FROM transactionaltestuser ORDER BY id" )).fetchall ()
268
+ assert len (users ) == 2
269
+ assert users [0 ].name == "User 1"
270
+ assert users [1 ].name == "User 2"
271
+
272
+ def test_transactional_preserves_function_metadata (self ):
273
+ """Test that @Transactional preserves original function metadata"""
274
+
275
+ @Transactional
276
+ def documented_function (param1 : str , param2 : int = 10 ) -> str :
277
+ """This is a documented function with parameters."""
278
+ return f"{ param1 } _{ param2 } "
279
+
280
+ # Verify function metadata is preserved
281
+ assert documented_function .__name__ == "documented_function"
282
+ assert documented_function .__doc__ is not None and "documented function" in documented_function .__doc__
283
+
284
+ # Verify function still works correctly
285
+ result = documented_function ("test" , 20 )
286
+ assert result == "test_20"
287
+
288
+ def test_transactional_commit_rollback_behavior (self ):
289
+ """Test the core commit/rollback behavior of nested transactions"""
290
+
291
+ commit_calls = []
292
+ rollback_calls = []
293
+
294
+ # Mock session to track commit/rollback calls
295
+ original_session_class = PySpringModel .create_session
296
+
297
+ def mock_create_session ():
298
+ session = original_session_class ()
299
+ original_commit = session .commit
300
+ original_rollback = session .rollback
301
+
302
+ def mock_commit ():
303
+ commit_calls .append ("commit" )
304
+ return original_commit ()
305
+
306
+ def mock_rollback ():
307
+ rollback_calls .append ("rollback" )
308
+ return original_rollback ()
309
+
310
+ session .commit = mock_commit
311
+ session .rollback = mock_rollback
312
+ return session
313
+
314
+ @Transactional
315
+ def inner_operation ():
316
+ session = SessionContextHolder .get_or_create_session ()
317
+ user = TransactionalTestUser (name = "Test" , email = "test@example.com" , age = 30 )
318
+ session .add (user )
319
+ session .flush ()
320
+
321
+ @Transactional
322
+ def outer_operation ():
323
+ inner_operation ()
324
+
325
+ # Test successful nested transaction
326
+ with patch .object (PySpringModel , 'create_session' , side_effect = mock_create_session ):
327
+ outer_operation ()
328
+
329
+ # Only the outermost transaction should commit
330
+ assert len (commit_calls ) == 1
331
+ assert len (rollback_calls ) == 0
332
+
333
+ # Reset counters
334
+ commit_calls .clear ()
335
+ rollback_calls .clear ()
336
+
337
+ @Transactional
338
+ def inner_operation_with_error ():
339
+ session = SessionContextHolder .get_or_create_session ()
340
+ user = TransactionalTestUser (name = "Test2" , email = "test2@example.com" , age = 25 )
341
+ session .add (user )
342
+ session .flush ()
343
+ raise ValueError ("Inner error" )
344
+
345
+ @Transactional
346
+ def outer_operation_with_error ():
347
+ inner_operation_with_error ()
348
+
349
+ # Test failed nested transaction
350
+ with patch .object (PySpringModel , 'create_session' , side_effect = mock_create_session ):
351
+ with pytest .raises (ValueError ):
352
+ outer_operation_with_error ()
353
+
354
+ # Only the outermost transaction should rollback
355
+ assert len (commit_calls ) == 0
356
+ assert len (rollback_calls ) == 1
0 commit comments