1+ import pytest
2+ from unittest .mock import patch
3+ from sqlalchemy import create_engine , text
4+ from sqlmodel import Field , SQLModel
5+
6+ from py_spring_model import PySpringModel
7+ from py_spring_model .core .session_context_holder import SessionContextHolder , Transactional
8+
9+
10+ class TransactionalTestUser (PySpringModel , table = True ):
11+ """Test model for transactional operations"""
12+ id : int = Field (default = None , primary_key = True )
13+ name : str
14+ email : str
15+ age : int = Field (default = 0 )
16+
17+
18+ class TestTransactionalDecorator :
19+ """Test suite for the @Transactional decorator"""
20+
21+ @pytest .fixture (autouse = True )
22+ def setup_and_teardown (self ):
23+ """Set up test environment with in-memory SQLite database"""
24+ self .engine = create_engine ("sqlite:///:memory:" , echo = False )
25+ PySpringModel .set_engine (self .engine )
26+ PySpringModel .set_metadata (SQLModel .metadata )
27+ PySpringModel .set_models ([TransactionalTestUser ])
28+
29+ # Clear any existing session
30+ SessionContextHolder .clear_session ()
31+
32+ SQLModel .metadata .create_all (self .engine )
33+
34+ def teardown_method (self ):
35+ """Tear down test environment"""
36+ SQLModel .metadata .drop_all (self .engine )
37+ SessionContextHolder .clear_session ()
38+ PySpringModel ._engine = None
39+ PySpringModel ._metadata = None
40+ PySpringModel ._connection = None
41+
42+ def test_single_transactional_success (self ):
43+ """Test that a single @Transactional function commits successfully"""
44+
45+ @Transactional
46+ def create_user ():
47+ session = SessionContextHolder .get_or_create_session ()
48+ user = TransactionalTestUser (name = "John Doe" , email = "john@example.com" , age = 30 )
49+ session .add (user )
50+ session .flush () # To get the ID
51+ return user
52+
53+ # Execute the transactional function
54+ result = create_user ()
55+
56+ # Verify the user was created and committed
57+ assert result .name == "John Doe"
58+ assert result .email == "john@example.com"
59+ assert result .age == 30
60+
61+ # Verify session is cleared after transaction
62+ assert not SessionContextHolder .has_session ()
63+
64+ # Verify data persisted to database
65+ with PySpringModel .create_managed_session () as session :
66+ users = session .execute (text ("SELECT * FROM transactionaltestuser" )).fetchall ()
67+ assert len (users ) == 1
68+ assert users [0 ].name == "John Doe"
69+
70+ def test_single_transactional_rollback (self ):
71+ """Test that a single @Transactional function rolls back on exception"""
72+
73+ @Transactional
74+ def create_user_with_error ():
75+ session = SessionContextHolder .get_or_create_session ()
76+ user = TransactionalTestUser (name = "Jane Doe" , email = "jane@example.com" , age = 25 )
77+ session .add (user )
78+ session .flush ()
79+ raise ValueError ("Simulated error" )
80+
81+ # Execute the transactional function and expect exception
82+ with pytest .raises (ValueError , match = "Simulated error" ):
83+ create_user_with_error ()
84+
85+ # Verify session is cleared after rollback
86+ assert not SessionContextHolder .has_session ()
87+
88+ # Verify no data persisted to database
89+ with PySpringModel .create_managed_session () as session :
90+ users = session .execute (text ("SELECT * FROM transactionaltestuser" )).fetchall ()
91+ assert len (users ) == 0
92+
93+ def test_nested_transactional_success (self ):
94+ """Test that nested @Transactional functions share the same session and commit at top level"""
95+
96+ @Transactional
97+ def create_user (name : str , email : str ):
98+ session = SessionContextHolder .get_or_create_session ()
99+ user = TransactionalTestUser (name = name , email = email , age = 30 )
100+ session .add (user )
101+ session .flush ()
102+ return user
103+
104+ @Transactional
105+ def update_user_age (user_id : int , new_age : int ):
106+ session = SessionContextHolder .get_or_create_session ()
107+ session .execute (text (f"UPDATE transactionaltestuser SET age = { new_age } WHERE id = { user_id } " ))
108+
109+ @Transactional
110+ def create_and_update_user ():
111+ # This should share the same session with nested calls
112+ user = create_user ("Alice Smith" , "alice@example.com" )
113+ update_user_age (user .id , 35 )
114+ return user
115+
116+ # Execute the outer transactional function
117+ result = create_and_update_user ()
118+
119+ # Verify session is cleared after transaction
120+ assert not SessionContextHolder .has_session ()
121+
122+ # Verify both operations were committed
123+ with PySpringModel .create_managed_session () as session :
124+ users = session .execute (text ("SELECT * FROM transactionaltestuser" )).fetchall ()
125+ assert len (users ) == 1
126+ assert users [0 ].name == "Alice Smith"
127+ assert users [0 ].age == 35 # Updated by nested function
128+
129+ def test_nested_transactional_rollback_from_inner (self ):
130+ """Test that exception in nested @Transactional causes rollback at top level"""
131+
132+ @Transactional
133+ def create_user (name : str , email : str ):
134+ session = SessionContextHolder .get_or_create_session ()
135+ user = TransactionalTestUser (name = name , email = email , age = 30 )
136+ session .add (user )
137+ session .flush ()
138+ return user
139+
140+ @Transactional
141+ def update_user_with_error (user_id : int ):
142+ session = SessionContextHolder .get_or_create_session ()
143+ session .execute (text (f"UPDATE transactionaltestuser SET age = 40 WHERE id = { user_id } " ))
144+ raise RuntimeError ("Update failed" )
145+
146+ @Transactional
147+ def create_and_update_user_with_error ():
148+ user = create_user ("Bob Johnson" , "bob@example.com" )
149+ update_user_with_error (user .id ) # This will raise an exception
150+ return user
151+
152+ # Execute and expect exception
153+ with pytest .raises (RuntimeError , match = "Update failed" ):
154+ create_and_update_user_with_error ()
155+
156+ # Verify session is cleared after rollback
157+ assert not SessionContextHolder .has_session ()
158+
159+ # Verify no data persisted (everything rolled back)
160+ with PySpringModel .create_managed_session () as session :
161+ users = session .execute (text ("SELECT * FROM transactionaltestuser" )).fetchall ()
162+ assert len (users ) == 0
163+
164+ def test_nested_transactional_rollback_from_outer (self ):
165+ """Test that exception in outer @Transactional causes rollback after nested calls"""
166+
167+ @Transactional
168+ def create_user (name : str , email : str ):
169+ session = SessionContextHolder .get_or_create_session ()
170+ user = TransactionalTestUser (name = name , email = email , age = 30 )
171+ session .add (user )
172+ session .flush ()
173+ return user
174+
175+ @Transactional
176+ def update_user_age (user_id : int , new_age : int ):
177+ session = SessionContextHolder .get_or_create_session ()
178+ session .execute (text (f"UPDATE transactionaltestuser SET age = { new_age } WHERE id = { user_id } " ))
179+
180+ @Transactional
181+ def create_update_and_fail ():
182+ user = create_user ("Charlie Brown" , "charlie@example.com" )
183+ update_user_age (user .id , 45 )
184+ # Both nested operations succeeded, but outer fails
185+ raise Exception ("Outer operation failed" )
186+
187+ # Execute and expect exception
188+ with pytest .raises (Exception , match = "Outer operation failed" ):
189+ create_update_and_fail ()
190+
191+ # Verify session is cleared after rollback
192+ assert not SessionContextHolder .has_session ()
193+
194+ # Verify no data persisted (everything rolled back)
195+ with PySpringModel .create_managed_session () as session :
196+ users = session .execute (text ("SELECT * FROM transactionaltestuser" )).fetchall ()
197+ assert len (users ) == 0
198+
199+ def test_session_sharing_across_nested_transactions (self ):
200+ """Test that nested @Transactional functions share the same session instance"""
201+
202+ captured_sessions = []
203+
204+ @Transactional
205+ def inner_function ():
206+ session = SessionContextHolder .get_or_create_session ()
207+ captured_sessions .append (session )
208+ user = TransactionalTestUser (name = "Inner User" , email = "inner@example.com" , age = 25 )
209+ session .add (user )
210+ session .flush ()
211+ return session
212+
213+ @Transactional
214+ def middle_function ():
215+ session = SessionContextHolder .get_or_create_session ()
216+ captured_sessions .append (session )
217+ inner_session = inner_function ()
218+ return session , inner_session
219+
220+ @Transactional
221+ def outer_function ():
222+ session = SessionContextHolder .get_or_create_session ()
223+ captured_sessions .append (session )
224+ middle_session , inner_session = middle_function ()
225+ return session , middle_session , inner_session
226+
227+ # Execute nested transactions
228+ outer_session , middle_session , inner_session = outer_function ()
229+
230+ # Verify all functions used the same session instance
231+ assert len (captured_sessions ) == 3
232+ assert captured_sessions [0 ] is captured_sessions [1 ] # outer and middle
233+ assert captured_sessions [1 ] is captured_sessions [2 ] # middle and inner
234+ assert outer_session is middle_session is inner_session
235+
236+ # Verify session is cleared after transaction
237+ assert not SessionContextHolder .has_session ()
238+
239+ def test_transactional_session_context_isolation (self ):
240+ """Test that @Transactional properly isolates session context"""
241+
242+ @Transactional
243+ def first_transaction ():
244+ session = SessionContextHolder .get_or_create_session ()
245+ user1 = TransactionalTestUser (name = "User 1" , email = "user1@example.com" , age = 30 )
246+ session .add (user1 )
247+ session .flush ()
248+ return user1 .id
249+
250+ @Transactional
251+ def second_transaction ():
252+ session = SessionContextHolder .get_or_create_session ()
253+ user2 = TransactionalTestUser (name = "User 2" , email = "user2@example.com" , age = 35 )
254+ session .add (user2 )
255+ session .flush ()
256+ return user2 .id
257+
258+ # Execute separate transactions
259+ first_transaction ()
260+ second_transaction ()
261+
262+ # Verify sessions were properly isolated
263+ assert not SessionContextHolder .has_session ()
264+
265+ # Verify both transactions committed separately
266+ with PySpringModel .create_managed_session () as session :
267+ users = session .execute (text ("SELECT * FROM transactionaltestuser ORDER BY id" )).fetchall ()
268+ assert len (users ) == 2
269+ assert users [0 ].name == "User 1"
270+ assert users [1 ].name == "User 2"
271+
272+ def test_transactional_preserves_function_metadata (self ):
273+ """Test that @Transactional preserves original function metadata"""
274+
275+ @Transactional
276+ def documented_function (param1 : str , param2 : int = 10 ) -> str :
277+ """This is a documented function with parameters."""
278+ return f"{ param1 } _{ param2 } "
279+
280+ # Verify function metadata is preserved
281+ assert documented_function .__name__ == "documented_function"
282+ assert documented_function .__doc__ is not None and "documented function" in documented_function .__doc__
283+
284+ # Verify function still works correctly
285+ result = documented_function ("test" , 20 )
286+ assert result == "test_20"
287+
288+ def test_transactional_commit_rollback_behavior (self ):
289+ """Test the core commit/rollback behavior of nested transactions"""
290+
291+ commit_calls = []
292+ rollback_calls = []
293+
294+ # Mock session to track commit/rollback calls
295+ original_session_class = PySpringModel .create_session
296+
297+ def mock_create_session ():
298+ session = original_session_class ()
299+ original_commit = session .commit
300+ original_rollback = session .rollback
301+
302+ def mock_commit ():
303+ commit_calls .append ("commit" )
304+ return original_commit ()
305+
306+ def mock_rollback ():
307+ rollback_calls .append ("rollback" )
308+ return original_rollback ()
309+
310+ session .commit = mock_commit
311+ session .rollback = mock_rollback
312+ return session
313+
314+ @Transactional
315+ def inner_operation ():
316+ session = SessionContextHolder .get_or_create_session ()
317+ user = TransactionalTestUser (name = "Test" , email = "test@example.com" , age = 30 )
318+ session .add (user )
319+ session .flush ()
320+
321+ @Transactional
322+ def outer_operation ():
323+ inner_operation ()
324+
325+ # Test successful nested transaction
326+ with patch .object (PySpringModel , 'create_session' , side_effect = mock_create_session ):
327+ outer_operation ()
328+
329+ # Only the outermost transaction should commit
330+ assert len (commit_calls ) == 1
331+ assert len (rollback_calls ) == 0
332+
333+ # Reset counters
334+ commit_calls .clear ()
335+ rollback_calls .clear ()
336+
337+ @Transactional
338+ def inner_operation_with_error ():
339+ session = SessionContextHolder .get_or_create_session ()
340+ user = TransactionalTestUser (name = "Test2" , email = "test2@example.com" , age = 25 )
341+ session .add (user )
342+ session .flush ()
343+ raise ValueError ("Inner error" )
344+
345+ @Transactional
346+ def outer_operation_with_error ():
347+ inner_operation_with_error ()
348+
349+ # Test failed nested transaction
350+ with patch .object (PySpringModel , 'create_session' , side_effect = mock_create_session ):
351+ with pytest .raises (ValueError ):
352+ outer_operation_with_error ()
353+
354+ # Only the outermost transaction should rollback
355+ assert len (commit_calls ) == 0
356+ assert len (rollback_calls ) == 1
0 commit comments