Skip to content

Commit 8e3efc1

Browse files
test: Add comprehensive tests for @transactional decorator functionality
1 parent 4e7b847 commit 8e3efc1

File tree

1 file changed

+356
-0
lines changed

1 file changed

+356
-0
lines changed

tests/test_transactional_decorator.py

Lines changed: 356 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,356 @@
1+
import pytest
2+
from unittest.mock import patch
3+
from sqlalchemy import create_engine, text
4+
from sqlmodel import Field, SQLModel
5+
6+
from py_spring_model import PySpringModel
7+
from py_spring_model.core.session_context_holder import SessionContextHolder, Transactional
8+
9+
10+
class TransactionalTestUser(PySpringModel, table=True):
11+
"""Test model for transactional operations"""
12+
id: int = Field(default=None, primary_key=True)
13+
name: str
14+
email: str
15+
age: int = Field(default=0)
16+
17+
18+
class TestTransactionalDecorator:
19+
"""Test suite for the @Transactional decorator"""
20+
21+
@pytest.fixture(autouse=True)
22+
def setup_and_teardown(self):
23+
"""Set up test environment with in-memory SQLite database"""
24+
self.engine = create_engine("sqlite:///:memory:", echo=False)
25+
PySpringModel.set_engine(self.engine)
26+
PySpringModel.set_metadata(SQLModel.metadata)
27+
PySpringModel.set_models([TransactionalTestUser])
28+
29+
# Clear any existing session
30+
SessionContextHolder.clear_session()
31+
32+
SQLModel.metadata.create_all(self.engine)
33+
34+
def teardown_method(self):
35+
"""Tear down test environment"""
36+
SQLModel.metadata.drop_all(self.engine)
37+
SessionContextHolder.clear_session()
38+
PySpringModel._engine = None
39+
PySpringModel._metadata = None
40+
PySpringModel._connection = None
41+
42+
def test_single_transactional_success(self):
43+
"""Test that a single @Transactional function commits successfully"""
44+
45+
@Transactional
46+
def create_user():
47+
session = SessionContextHolder.get_or_create_session()
48+
user = TransactionalTestUser(name="John Doe", email="john@example.com", age=30)
49+
session.add(user)
50+
session.flush() # To get the ID
51+
return user
52+
53+
# Execute the transactional function
54+
result = create_user()
55+
56+
# Verify the user was created and committed
57+
assert result.name == "John Doe"
58+
assert result.email == "john@example.com"
59+
assert result.age == 30
60+
61+
# Verify session is cleared after transaction
62+
assert not SessionContextHolder.has_session()
63+
64+
# Verify data persisted to database
65+
with PySpringModel.create_managed_session() as session:
66+
users = session.execute(text("SELECT * FROM transactionaltestuser")).fetchall()
67+
assert len(users) == 1
68+
assert users[0].name == "John Doe"
69+
70+
def test_single_transactional_rollback(self):
71+
"""Test that a single @Transactional function rolls back on exception"""
72+
73+
@Transactional
74+
def create_user_with_error():
75+
session = SessionContextHolder.get_or_create_session()
76+
user = TransactionalTestUser(name="Jane Doe", email="jane@example.com", age=25)
77+
session.add(user)
78+
session.flush()
79+
raise ValueError("Simulated error")
80+
81+
# Execute the transactional function and expect exception
82+
with pytest.raises(ValueError, match="Simulated error"):
83+
create_user_with_error()
84+
85+
# Verify session is cleared after rollback
86+
assert not SessionContextHolder.has_session()
87+
88+
# Verify no data persisted to database
89+
with PySpringModel.create_managed_session() as session:
90+
users = session.execute(text("SELECT * FROM transactionaltestuser")).fetchall()
91+
assert len(users) == 0
92+
93+
def test_nested_transactional_success(self):
94+
"""Test that nested @Transactional functions share the same session and commit at top level"""
95+
96+
@Transactional
97+
def create_user(name: str, email: str):
98+
session = SessionContextHolder.get_or_create_session()
99+
user = TransactionalTestUser(name=name, email=email, age=30)
100+
session.add(user)
101+
session.flush()
102+
return user
103+
104+
@Transactional
105+
def update_user_age(user_id: int, new_age: int):
106+
session = SessionContextHolder.get_or_create_session()
107+
session.execute(text(f"UPDATE transactionaltestuser SET age = {new_age} WHERE id = {user_id}"))
108+
109+
@Transactional
110+
def create_and_update_user():
111+
# This should share the same session with nested calls
112+
user = create_user("Alice Smith", "alice@example.com")
113+
update_user_age(user.id, 35)
114+
return user
115+
116+
# Execute the outer transactional function
117+
result = create_and_update_user()
118+
119+
# Verify session is cleared after transaction
120+
assert not SessionContextHolder.has_session()
121+
122+
# Verify both operations were committed
123+
with PySpringModel.create_managed_session() as session:
124+
users = session.execute(text("SELECT * FROM transactionaltestuser")).fetchall()
125+
assert len(users) == 1
126+
assert users[0].name == "Alice Smith"
127+
assert users[0].age == 35 # Updated by nested function
128+
129+
def test_nested_transactional_rollback_from_inner(self):
130+
"""Test that exception in nested @Transactional causes rollback at top level"""
131+
132+
@Transactional
133+
def create_user(name: str, email: str):
134+
session = SessionContextHolder.get_or_create_session()
135+
user = TransactionalTestUser(name=name, email=email, age=30)
136+
session.add(user)
137+
session.flush()
138+
return user
139+
140+
@Transactional
141+
def update_user_with_error(user_id: int):
142+
session = SessionContextHolder.get_or_create_session()
143+
session.execute(text(f"UPDATE transactionaltestuser SET age = 40 WHERE id = {user_id}"))
144+
raise RuntimeError("Update failed")
145+
146+
@Transactional
147+
def create_and_update_user_with_error():
148+
user = create_user("Bob Johnson", "bob@example.com")
149+
update_user_with_error(user.id) # This will raise an exception
150+
return user
151+
152+
# Execute and expect exception
153+
with pytest.raises(RuntimeError, match="Update failed"):
154+
create_and_update_user_with_error()
155+
156+
# Verify session is cleared after rollback
157+
assert not SessionContextHolder.has_session()
158+
159+
# Verify no data persisted (everything rolled back)
160+
with PySpringModel.create_managed_session() as session:
161+
users = session.execute(text("SELECT * FROM transactionaltestuser")).fetchall()
162+
assert len(users) == 0
163+
164+
def test_nested_transactional_rollback_from_outer(self):
165+
"""Test that exception in outer @Transactional causes rollback after nested calls"""
166+
167+
@Transactional
168+
def create_user(name: str, email: str):
169+
session = SessionContextHolder.get_or_create_session()
170+
user = TransactionalTestUser(name=name, email=email, age=30)
171+
session.add(user)
172+
session.flush()
173+
return user
174+
175+
@Transactional
176+
def update_user_age(user_id: int, new_age: int):
177+
session = SessionContextHolder.get_or_create_session()
178+
session.execute(text(f"UPDATE transactionaltestuser SET age = {new_age} WHERE id = {user_id}"))
179+
180+
@Transactional
181+
def create_update_and_fail():
182+
user = create_user("Charlie Brown", "charlie@example.com")
183+
update_user_age(user.id, 45)
184+
# Both nested operations succeeded, but outer fails
185+
raise Exception("Outer operation failed")
186+
187+
# Execute and expect exception
188+
with pytest.raises(Exception, match="Outer operation failed"):
189+
create_update_and_fail()
190+
191+
# Verify session is cleared after rollback
192+
assert not SessionContextHolder.has_session()
193+
194+
# Verify no data persisted (everything rolled back)
195+
with PySpringModel.create_managed_session() as session:
196+
users = session.execute(text("SELECT * FROM transactionaltestuser")).fetchall()
197+
assert len(users) == 0
198+
199+
def test_session_sharing_across_nested_transactions(self):
200+
"""Test that nested @Transactional functions share the same session instance"""
201+
202+
captured_sessions = []
203+
204+
@Transactional
205+
def inner_function():
206+
session = SessionContextHolder.get_or_create_session()
207+
captured_sessions.append(session)
208+
user = TransactionalTestUser(name="Inner User", email="inner@example.com", age=25)
209+
session.add(user)
210+
session.flush()
211+
return session
212+
213+
@Transactional
214+
def middle_function():
215+
session = SessionContextHolder.get_or_create_session()
216+
captured_sessions.append(session)
217+
inner_session = inner_function()
218+
return session, inner_session
219+
220+
@Transactional
221+
def outer_function():
222+
session = SessionContextHolder.get_or_create_session()
223+
captured_sessions.append(session)
224+
middle_session, inner_session = middle_function()
225+
return session, middle_session, inner_session
226+
227+
# Execute nested transactions
228+
outer_session, middle_session, inner_session = outer_function()
229+
230+
# Verify all functions used the same session instance
231+
assert len(captured_sessions) == 3
232+
assert captured_sessions[0] is captured_sessions[1] # outer and middle
233+
assert captured_sessions[1] is captured_sessions[2] # middle and inner
234+
assert outer_session is middle_session is inner_session
235+
236+
# Verify session is cleared after transaction
237+
assert not SessionContextHolder.has_session()
238+
239+
def test_transactional_session_context_isolation(self):
240+
"""Test that @Transactional properly isolates session context"""
241+
242+
@Transactional
243+
def first_transaction():
244+
session = SessionContextHolder.get_or_create_session()
245+
user1 = TransactionalTestUser(name="User 1", email="user1@example.com", age=30)
246+
session.add(user1)
247+
session.flush()
248+
return user1.id
249+
250+
@Transactional
251+
def second_transaction():
252+
session = SessionContextHolder.get_or_create_session()
253+
user2 = TransactionalTestUser(name="User 2", email="user2@example.com", age=35)
254+
session.add(user2)
255+
session.flush()
256+
return user2.id
257+
258+
# Execute separate transactions
259+
first_transaction()
260+
second_transaction()
261+
262+
# Verify sessions were properly isolated
263+
assert not SessionContextHolder.has_session()
264+
265+
# Verify both transactions committed separately
266+
with PySpringModel.create_managed_session() as session:
267+
users = session.execute(text("SELECT * FROM transactionaltestuser ORDER BY id")).fetchall()
268+
assert len(users) == 2
269+
assert users[0].name == "User 1"
270+
assert users[1].name == "User 2"
271+
272+
def test_transactional_preserves_function_metadata(self):
273+
"""Test that @Transactional preserves original function metadata"""
274+
275+
@Transactional
276+
def documented_function(param1: str, param2: int = 10) -> str:
277+
"""This is a documented function with parameters."""
278+
return f"{param1}_{param2}"
279+
280+
# Verify function metadata is preserved
281+
assert documented_function.__name__ == "documented_function"
282+
assert documented_function.__doc__ is not None and "documented function" in documented_function.__doc__
283+
284+
# Verify function still works correctly
285+
result = documented_function("test", 20)
286+
assert result == "test_20"
287+
288+
def test_transactional_commit_rollback_behavior(self):
289+
"""Test the core commit/rollback behavior of nested transactions"""
290+
291+
commit_calls = []
292+
rollback_calls = []
293+
294+
# Mock session to track commit/rollback calls
295+
original_session_class = PySpringModel.create_session
296+
297+
def mock_create_session():
298+
session = original_session_class()
299+
original_commit = session.commit
300+
original_rollback = session.rollback
301+
302+
def mock_commit():
303+
commit_calls.append("commit")
304+
return original_commit()
305+
306+
def mock_rollback():
307+
rollback_calls.append("rollback")
308+
return original_rollback()
309+
310+
session.commit = mock_commit
311+
session.rollback = mock_rollback
312+
return session
313+
314+
@Transactional
315+
def inner_operation():
316+
session = SessionContextHolder.get_or_create_session()
317+
user = TransactionalTestUser(name="Test", email="test@example.com", age=30)
318+
session.add(user)
319+
session.flush()
320+
321+
@Transactional
322+
def outer_operation():
323+
inner_operation()
324+
325+
# Test successful nested transaction
326+
with patch.object(PySpringModel, 'create_session', side_effect=mock_create_session):
327+
outer_operation()
328+
329+
# Only the outermost transaction should commit
330+
assert len(commit_calls) == 1
331+
assert len(rollback_calls) == 0
332+
333+
# Reset counters
334+
commit_calls.clear()
335+
rollback_calls.clear()
336+
337+
@Transactional
338+
def inner_operation_with_error():
339+
session = SessionContextHolder.get_or_create_session()
340+
user = TransactionalTestUser(name="Test2", email="test2@example.com", age=25)
341+
session.add(user)
342+
session.flush()
343+
raise ValueError("Inner error")
344+
345+
@Transactional
346+
def outer_operation_with_error():
347+
inner_operation_with_error()
348+
349+
# Test failed nested transaction
350+
with patch.object(PySpringModel, 'create_session', side_effect=mock_create_session):
351+
with pytest.raises(ValueError):
352+
outer_operation_with_error()
353+
354+
# Only the outermost transaction should rollback
355+
assert len(commit_calls) == 0
356+
assert len(rollback_calls) == 1

0 commit comments

Comments
 (0)