Skip to content

Commit 86e3562

Browse files
feat: Enhance CRUD repository with @transactional decorator for improved transaction management
- Added @transactional decorator to multiple CRUD methods in CrudRepository to ensure proper transaction handling. - Updated commit logic in PySpringSession to warn against committing managed transactions. - Refactored clone method in PySpringModel to return the correct type. - Improved logging during session commit in PySpringModel for better clarity. - Introduced a method in SessionContextHolder to check if a transaction is managed.
1 parent b95fe27 commit 86e3562

File tree

5 files changed

+41
-12
lines changed

5 files changed

+41
-12
lines changed

py_spring_model/core/model.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import contextlib
2-
from typing_extensions import Self
32
from typing import ClassVar, Iterator, Optional, Type
43
from loguru import logger
54
from sqlalchemy import Engine, MetaData
@@ -81,7 +80,7 @@ def get_model_lookup(cls) -> dict[str, type["PySpringModel"]]:
8180
raise ValueError("[MODEL_LOOKUP NOT SET] Model lookup is not set")
8281
return {str(_model.__tablename__): _model for _model in cls._models}
8382

84-
def clone(self) -> Self:
83+
def clone(self) -> "PySpringModel":
8584
return self.model_validate_json(self.model_dump_json())
8685

8786
@classmethod
@@ -105,11 +104,7 @@ def create_managed_session(cls, should_commit: bool = True) -> Iterator[PySpring
105104
logger.debug("[MANAGED SESSION COMMIT] Session committing...")
106105
if should_commit:
107106
session.commit()
108-
logger.debug(
109-
"[MANAGED SESSION COMMIT] Session committed, refreshing instances..."
110-
)
111-
session.refresh_current_session_instances()
112-
logger.success("[MANAGED SESSION COMMIT] Session committed.")
107+
logger.success("[MANAGED SESSION COMMIT] Session committed.")
113108
except Exception as error:
114109
logger.error(error)
115110
logger.error("[MANAGED SESSION ROLLBACK] Session rolling back...")

py_spring_model/core/py_spring_session.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import Iterable
22

3+
from loguru import logger
34
from sqlmodel import Session, SQLModel
45

56

@@ -33,3 +34,12 @@ def add_all(self, instances: Iterable[SQLModel]) -> None:
3334
def refresh_current_session_instances(self) -> None:
3435
for instance in self.current_session_instance:
3536
self.refresh(instance)
37+
38+
def commit(self) -> None:
39+
# Import here to avoid circular import
40+
from py_spring_model.core.session_context_holder import SessionContextHolder
41+
if SessionContextHolder.is_transaction_managed():
42+
logger.warning("Commiting a transaction that is currently being managed by the outermost transaction is strongly discouraged...")
43+
return
44+
super().commit()
45+
self.refresh_current_session_instances()

py_spring_model/core/session_context_holder.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,8 @@
33
from functools import wraps
44
from typing import Any, Callable, ClassVar, Optional
55

6-
from py_spring_model.core.py_spring_session import PySpringSession
7-
86
from py_spring_model.core.model import PySpringModel
7+
from py_spring_model.core.py_spring_session import PySpringSession
98

109
class TransactionalDepth(IntEnum):
1110
OUTERMOST = 1
@@ -131,4 +130,8 @@ def clear_session(cls):
131130
if session is not None:
132131
session.close()
133132
cls._session.set(None)
134-
cls._session_depth.set(TransactionalDepth.ON_EXIT.value)
133+
cls._session_depth.set(TransactionalDepth.ON_EXIT.value)
134+
135+
@classmethod
136+
def is_transaction_managed(cls) -> bool:
137+
return cls._session_depth.get() > TransactionalDepth.OUTERMOST.value

py_spring_model/repository/crud_repository.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ def __init__(self) -> None:
5959
def _get_model_id_type_with_class(cls) -> tuple[Type[ID], Type[T]]:
6060
return get_args(tp=cls.__mro__[0].__orig_bases__[0])
6161

62+
@Transactional
6263
def _find_by_statement(
6364
self,
6465
statement: Union[Select, SelectOfScalar],
@@ -67,7 +68,7 @@ def _find_by_statement(
6768

6869
return session.exec(statement).first()
6970

70-
71+
@Transactional
7172
def _find_by_query(
7273
self,
7374
query_by: dict[str, Any],
@@ -76,6 +77,8 @@ def _find_by_query(
7677
statement = select(self.model_class).filter_by(**query_by)
7778
return session.exec(statement).first()
7879

80+
81+
@Transactional
7982
def _find_all_by_query(
8083
self,
8184
query_by: dict[str, Any],
@@ -84,13 +87,15 @@ def _find_all_by_query(
8487
statement = select(self.model_class).filter_by(**query_by)
8588
return session, list(session.exec(statement).fetchall())
8689

90+
@Transactional
8791
def _find_all_by_statement(
8892
self,
8993
statement: Union[Select, SelectOfScalar],
9094
) -> list[T]:
9195
session = SessionContextHolder.get_or_create_session()
9296
return list(session.exec(statement).fetchall())
9397

98+
@Transactional
9499
def find_by_id(self, id: ID) -> Optional[T]:
95100
session = SessionContextHolder.get_or_create_session()
96101
statement = select(self.model_class).where(self.model_class.id == id) # type: ignore
@@ -99,11 +104,14 @@ def find_by_id(self, id: ID) -> Optional[T]:
99104
return
100105

101106
return optional_entity
107+
108+
@Transactional
102109
def find_all_by_ids(self, ids: list[ID]) -> list[T]:
103110
session = SessionContextHolder.get_or_create_session()
104111
statement = select(self.model_class).where(self.model_class.id.in_(ids)) # type: ignore
105112
return [entity for entity in session.exec(statement).all()] # type: ignore
106113

114+
@Transactional
107115
def find_all(self) -> list[T]:
108116
session = SessionContextHolder.get_or_create_session()
109117
statement = select(self.model_class) # type: ignore

tests/test_crud_repository.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,4 +135,17 @@ def test_upsert_for_new_user(self, user_repository: UserRepository):
135135
new_user = user_repository.find_by_id(1)
136136
assert new_user is not None
137137
assert new_user.name == "John Doe"
138-
assert new_user.email == "john@example.com"
138+
assert new_user.email == "john@example.com"
139+
140+
def test_update_user(self, user_repository: UserRepository):
141+
self.create_test_user(user_repository)
142+
user = user_repository.find_by_id(1)
143+
assert user is not None
144+
user.name = "William Chen"
145+
user.email = "william.chen@example.com"
146+
user_repository.save(user)
147+
updated_user = user_repository.find_by_id(1)
148+
assert updated_user is not None
149+
assert updated_user.name == "William Chen"
150+
assert updated_user.email == "william.chen@example.com"
151+

0 commit comments

Comments
 (0)