Skip to content

Commit b22e3db

Browse files
feat: Implement explicit session depth tracking in SessionContextHolder
- Added TransactionalDepth enum to define transaction levels. - Enhanced SessionContextHolder with methods to manage session depth. - Updated @transactional decorator to utilize session depth for commit/rollback operations. - Introduced tests for session depth tracking and behavior in nested transactions.
1 parent 8e3efc1 commit b22e3db

File tree

2 files changed

+165
-9
lines changed

2 files changed

+165
-9
lines changed

py_spring_model/core/session_context_holder.py

Lines changed: 53 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
11
from contextvars import ContextVar
2+
from enum import IntEnum
23
from functools import wraps
34
from typing import Any, Callable, ClassVar, Optional
45

56
from py_spring_model.core.py_spring_session import PySpringSession
67

78
from py_spring_model.core.model import PySpringModel
89

10+
class TransactionalDepth(IntEnum):
11+
OUTERMOST = 1
12+
ON_EXIT = 0
13+
914
def Transactional(func: Callable[..., Any]) -> Callable[..., Any]:
1015
"""
1116
Decorator for managing database transactions in a nested-safe manner.
@@ -45,30 +50,34 @@ def update_account():
4550
"""
4651
@wraps(func)
4752
def wrapper(*args, **kwargs):
48-
is_outermost_transaction = not SessionContextHolder.has_session()
53+
# Increment session depth and get session
54+
session_depth = SessionContextHolder.enter_session()
4955
session = SessionContextHolder.get_or_create_session()
5056
try:
5157
result = func(*args, **kwargs)
52-
if is_outermost_transaction:
58+
# Only commit at the outermost level (session_depth == 1)
59+
if session_depth == TransactionalDepth.OUTERMOST.value:
5360
session.commit()
5461
return result
5562
except Exception as error:
56-
if is_outermost_transaction:
63+
# Only rollback at the outermost level (session_depth == 1)
64+
if session_depth == TransactionalDepth.OUTERMOST.value:
5765
session.rollback()
5866
raise error
5967
finally:
60-
if is_outermost_transaction:
61-
SessionContextHolder.clear_session()
68+
# Decrement depth and clean up session if needed
69+
SessionContextHolder.exit_session()
6270
return wrapper
6371

6472
class SessionContextHolder:
6573
"""
66-
A context holder for the session.
74+
A context holder for the session with explicit depth tracking.
6775
This is used to store the session in a context variable so that it can be accessed by the query service.
68-
This is useful for the query service to be able to access the session without having to pass it in as an argument.
69-
This is also useful for the query service to be able to access the session without having to pass it in as an argument.
76+
The depth counter ensures that only the outermost transaction manages commit/rollback operations.
7077
"""
7178
_session: ClassVar[ContextVar[Optional[PySpringSession]]] = ContextVar("session", default=None)
79+
_session_depth: ClassVar[ContextVar[int]] = ContextVar("session_depth", default=0)
80+
7281
@classmethod
7382
def get_or_create_session(cls) -> PySpringSession:
7483
optional_session = cls._session.get()
@@ -82,9 +91,44 @@ def get_or_create_session(cls) -> PySpringSession:
8291
def has_session(cls) -> bool:
8392
return cls._session.get() is not None
8493

94+
@classmethod
95+
def get_session_depth(cls) -> int:
96+
"""Get the current session depth."""
97+
return cls._session_depth.get()
98+
99+
@classmethod
100+
def enter_session(cls) -> int:
101+
"""
102+
Enter a new session context and increment the depth counter.
103+
Returns the new depth level.
104+
"""
105+
current_depth = cls._session_depth.get()
106+
new_depth = current_depth + 1
107+
cls._session_depth.set(new_depth)
108+
return new_depth
109+
110+
@classmethod
111+
def exit_session(cls) -> int:
112+
"""
113+
Exit the current session context and decrement the depth counter.
114+
If depth reaches 0, clear the session.
115+
Returns the new depth level.
116+
"""
117+
current_depth = cls._session_depth.get()
118+
new_depth = max(0, current_depth - 1) # Prevent negative depth
119+
cls._session_depth.set(new_depth)
120+
121+
# Clear session only when depth reaches 0 (outermost level)
122+
if new_depth == 0:
123+
cls.clear_session()
124+
125+
return new_depth
126+
85127
@classmethod
86128
def clear_session(cls):
129+
"""Clear the session and reset depth to 0."""
87130
session = cls._session.get()
88131
if session is not None:
89132
session.close()
90-
cls._session.set(None)
133+
cls._session.set(None)
134+
cls._session_depth.set(TransactionalDepth.ON_EXIT.value)

tests/test_session_depth.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
import pytest
2+
from py_spring_model.core.session_context_holder import SessionContextHolder, Transactional
3+
from py_spring_model.core.model import PySpringModel
4+
5+
6+
class TestSessionDepth:
7+
"""Test the explicit session depth tracking functionality"""
8+
9+
def setup_method(self):
10+
"""Clean up any existing sessions before each test"""
11+
SessionContextHolder.clear_session()
12+
13+
def teardown_method(self):
14+
"""Clean up after each test"""
15+
SessionContextHolder.clear_session()
16+
17+
def test_session_depth_starts_at_zero(self):
18+
"""Test that session depth starts at 0"""
19+
assert SessionContextHolder.get_session_depth() == 0
20+
21+
def test_session_depth_increments_and_decrements(self):
22+
"""Test that session depth properly increments and decrements"""
23+
# Initially 0
24+
assert SessionContextHolder.get_session_depth() == 0
25+
26+
# Enter first level
27+
depth1 = SessionContextHolder.enter_session()
28+
assert depth1 == 1
29+
assert SessionContextHolder.get_session_depth() == 1
30+
31+
# Enter second level
32+
depth2 = SessionContextHolder.enter_session()
33+
assert depth2 == 2
34+
assert SessionContextHolder.get_session_depth() == 2
35+
36+
# Exit second level
37+
depth_after_exit1 = SessionContextHolder.exit_session()
38+
assert depth_after_exit1 == 1
39+
assert SessionContextHolder.get_session_depth() == 1
40+
41+
# Exit first level
42+
depth_after_exit2 = SessionContextHolder.exit_session()
43+
assert depth_after_exit2 == 0
44+
assert SessionContextHolder.get_session_depth() == 0
45+
46+
def test_session_cleared_only_at_outermost_level(self):
47+
"""Test that session is only cleared when depth reaches 0"""
48+
# Enter first level and create session
49+
SessionContextHolder.enter_session()
50+
session = SessionContextHolder.get_or_create_session()
51+
assert SessionContextHolder.has_session()
52+
53+
# Enter second level - session should still exist
54+
SessionContextHolder.enter_session()
55+
assert SessionContextHolder.has_session()
56+
assert SessionContextHolder.get_session_depth() == 2
57+
58+
# Exit second level - session should still exist
59+
SessionContextHolder.exit_session()
60+
assert SessionContextHolder.has_session()
61+
assert SessionContextHolder.get_session_depth() == 1
62+
63+
# Exit first level - session should be cleared
64+
SessionContextHolder.exit_session()
65+
assert not SessionContextHolder.has_session()
66+
assert SessionContextHolder.get_session_depth() == 0
67+
68+
def test_clear_session_resets_depth(self):
69+
"""Test that clear_session() resets the depth to 0"""
70+
SessionContextHolder.enter_session()
71+
SessionContextHolder.enter_session()
72+
assert SessionContextHolder.get_session_depth() == 2
73+
74+
SessionContextHolder.clear_session()
75+
assert SessionContextHolder.get_session_depth() == 0
76+
77+
@pytest.mark.parametrize("nesting_levels", [1, 2, 3, 5])
78+
def test_transactional_depth_tracking(self, nesting_levels):
79+
"""Test that @Transactional properly tracks depth at various nesting levels"""
80+
depth_records = []
81+
82+
def create_nested_function(level: int):
83+
@Transactional
84+
def nested_func():
85+
current_depth = SessionContextHolder.get_session_depth()
86+
depth_records.append(current_depth)
87+
if level > 1:
88+
create_nested_function(level - 1)()
89+
return nested_func
90+
91+
# Create and execute nested function
92+
outermost_func = create_nested_function(nesting_levels)
93+
outermost_func()
94+
95+
# Verify depth progression
96+
assert len(depth_records) == nesting_levels
97+
for i, recorded_depth in enumerate(depth_records):
98+
expected_depth = i + 1
99+
assert recorded_depth == expected_depth, f"At level {i+1}, expected depth {expected_depth}, got {recorded_depth}"
100+
101+
# Verify session is cleaned up
102+
assert SessionContextHolder.get_session_depth() == 0
103+
assert not SessionContextHolder.has_session()
104+
105+
def test_depth_prevents_negative_values(self):
106+
"""Test that depth counter prevents going below 0"""
107+
assert SessionContextHolder.get_session_depth() == 0
108+
109+
# Try to exit when already at 0
110+
new_depth = SessionContextHolder.exit_session()
111+
assert new_depth == 0
112+
assert SessionContextHolder.get_session_depth() == 0

0 commit comments

Comments
 (0)