Skip to content

Commit a7f4a65

Browse files
Implement Session Management and Transactional Decorator for CRUD Operations
- Introduced `SessionContextHolder` to manage SQLAlchemy sessions using context variables. - Added `Transactional` decorator to ensure session commits and rollbacks during CRUD operations. - Refactored `CrudRepository` methods to utilize the new session management, enhancing code clarity and reducing session handling boilerplate. - Updated tests to reflect changes in the repository methods, ensuring proper functionality without session parameters.
1 parent da4ccde commit a7f4a65

File tree

4 files changed

+130
-74
lines changed

4 files changed

+130
-74
lines changed

py_spring_model/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from py_spring_model.core.model import PySpringModel, Field
2+
from py_spring_model.core.session_context_holder import SessionContextHolder
23
from py_spring_model.py_spring_model_provider import provide_py_spring_model
34
from py_spring_model.repository.crud_repository import CrudRepository
45
from py_spring_model.repository.repository_base import RepositoryBase
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
from contextvars import ContextVar
2+
from functools import wraps
3+
from typing import ClassVar, Optional
4+
from sqlmodel import Session
5+
6+
from py_spring_model.core.model import PySpringModel
7+
8+
def Transactional(func):
9+
"""
10+
A decorator that wraps a function and commits the session if the function is successful.
11+
If the function raises an exception, the session is rolled back.
12+
The session is then closed.
13+
"""
14+
@wraps(func)
15+
def wrapper(*args, **kwargs):
16+
session = SessionContextHolder.get_or_create_session()
17+
try:
18+
result = func(*args, **kwargs)
19+
session.commit()
20+
return result
21+
except Exception:
22+
session.rollback()
23+
raise
24+
finally:
25+
SessionContextHolder.clear_session()
26+
return wrapper
27+
28+
class SessionContextHolder:
29+
"""
30+
A context holder for the session.
31+
This is used to store the session in a context variable so that it can be accessed by the query service.
32+
This is useful for the query service to be able to access the session without having to pass it in as an argument.
33+
This is also useful for the query service to be able to access the session without having to pass it in as an argument.
34+
"""
35+
_session: ClassVar[ContextVar[Optional[Session]]] = ContextVar("session", default=None)
36+
@classmethod
37+
def get_or_create_session(cls) -> Session:
38+
optional_session = cls._session.get()
39+
if optional_session is None:
40+
session = PySpringModel.create_session()
41+
cls._session.set(session)
42+
return session
43+
return optional_session
44+
@classmethod
45+
def clear_session(cls):
46+
optional_session = cls._session.get()
47+
if optional_session is None:
48+
return
49+
optional_session.close()
50+
cls._session.set(None)

py_spring_model/repository/crud_repository.py

Lines changed: 77 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from sqlmodel.sql.expression import Select, SelectOfScalar
1717

1818
from py_spring_model.core.model import PySpringModel
19+
from py_spring_model.core.session_context_holder import SessionContextHolder, Transactional
1920
from py_spring_model.repository.repository_base import RepositoryBase
2021

2122
T = TypeVar("T", bound=PySpringModel)
@@ -61,120 +62,124 @@ def _get_model_id_type_with_class(cls) -> tuple[Type[ID], Type[T]]:
6162
def _find_by_statement(
6263
self,
6364
statement: Union[Select, SelectOfScalar],
64-
session: Optional[Session] = None,
65-
) -> tuple[Session, Optional[T]]:
66-
session = session or self._create_session()
65+
) -> Optional[T]:
66+
session = SessionContextHolder.get_or_create_session()
67+
68+
return session.exec(statement).first()
6769

68-
return session, session.exec(statement).first()
6970

7071
def _find_by_query(
7172
self,
7273
query_by: dict[str, Any],
73-
session: Optional[Session] = None,
74-
) -> tuple[Session, Optional[T]]:
75-
session = session or self._create_session()
74+
) -> Optional[T]:
75+
session = SessionContextHolder.get_or_create_session()
7676
statement = select(self.model_class).filter_by(**query_by)
77-
return session, session.exec(statement).first()
77+
return session.exec(statement).first()
7878

7979
def _find_all_by_query(
8080
self,
8181
query_by: dict[str, Any],
82-
session: Optional[Session] = None,
8382
) -> tuple[Session, list[T]]:
84-
session = session or self._create_session()
83+
session = SessionContextHolder.get_or_create_session()
8584
statement = select(self.model_class).filter_by(**query_by)
8685
return session, list(session.exec(statement).fetchall())
8786

8887
def _find_all_by_statement(
8988
self,
9089
statement: Union[Select, SelectOfScalar],
91-
session: Optional[Session] = None,
92-
) -> tuple[Session, list[T]]:
93-
session = session or self._create_session()
94-
return session, list(session.exec(statement).fetchall())
90+
) -> list[T]:
91+
session = SessionContextHolder.get_or_create_session()
92+
return list(session.exec(statement).fetchall())
9593

9694
def find_by_id(self, id: ID) -> Optional[T]:
97-
with self.create_managed_session() as session:
98-
statement = select(self.model_class).where(self.model_class.id == id) # type: ignore
99-
optional_entity = session.exec(statement).first()
100-
if optional_entity is None:
101-
return
102-
103-
return optional_entity.clone() # type: ignore
95+
session = SessionContextHolder.get_or_create_session()
96+
statement = select(self.model_class).where(self.model_class.id == id) # type: ignore
97+
optional_entity = session.exec(statement).first()
98+
if optional_entity is None:
99+
return
104100

101+
return optional_entity
105102
def find_all_by_ids(self, ids: list[ID]) -> list[T]:
106-
with self.create_managed_session() as session:
107-
statement = select(self.model_class).where(self.model_class.id.in_(ids)) # type: ignore
108-
return [entity.clone() for entity in session.exec(statement).all()] # type: ignore
103+
session = SessionContextHolder.get_or_create_session()
104+
statement = select(self.model_class).where(self.model_class.id.in_(ids)) # type: ignore
105+
return [entity for entity in session.exec(statement).all()] # type: ignore
109106

110107
def find_all(self) -> list[T]:
111-
with self.create_managed_session() as session:
112-
statement = select(self.model_class) # type: ignore
113-
return [entity.clone() for entity in session.exec(statement).all()] # type: ignore
108+
session = SessionContextHolder.get_or_create_session()
109+
statement = select(self.model_class) # type: ignore
110+
return [entity for entity in session.exec(statement).all()] # type: ignore
114111

112+
@Transactional
115113
def save(self, entity: T) -> T:
116-
with self.create_managed_session() as session:
117-
session.add(entity)
118-
return entity.clone() # type: ignore
114+
session = SessionContextHolder.get_or_create_session()
115+
session.add(entity)
116+
return entity
119117

118+
@Transactional
120119
def save_all(
121120
self,
122121
entities: Iterable[T],
123122
) -> bool:
124-
with self.create_managed_session() as session:
125-
session.add_all(entities)
123+
session = SessionContextHolder.get_or_create_session()
124+
session.add_all(entities)
126125
return True
127126

127+
@Transactional
128128
def delete(self, entity: T) -> bool:
129-
with self.create_managed_session() as session:
130-
_, optional_intance = self._find_by_query(entity.model_dump(), session)
131-
if optional_intance is None:
132-
return False
133-
session.delete(optional_intance)
129+
session = SessionContextHolder.get_or_create_session()
130+
optional_intance = self._find_by_query(entity.model_dump())
131+
if optional_intance is None:
132+
return False
133+
session.delete(optional_intance)
134134
return True
135135

136+
@Transactional
136137
def delete_all(self, entities: Iterable[T]) -> bool:
137-
with self.create_managed_session() as session:
138-
ids = [entity.id for entity in entities] # type: ignore
139-
140-
statement = select(self.model_class).where(self.model_class.id.in_(ids)) # type: ignore
141-
_, deleted_entities = self._find_all_by_statement(statement, session)
142-
if deleted_entities is None:
143-
return False
144-
145-
for entity in deleted_entities:
146-
session.delete(entity)
138+
session = SessionContextHolder.get_or_create_session()
139+
ids = [entity.id for entity in entities] # type: ignore
140+
141+
statement = select(self.model_class).where(self.model_class.id.in_(ids)) # type: ignore
142+
deleted_entities = self._find_all_by_statement(statement)
143+
if deleted_entities is None:
144+
return False
145+
146+
for entity in deleted_entities:
147+
session.delete(entity)
147148

148149
return True
149150

151+
152+
@Transactional
150153
def delete_by_id(self, _id: ID) -> bool:
151-
with self.create_managed_session() as session:
152-
_, entity = self._find_by_query({"id": _id}, session)
153-
if entity is None:
154-
return False
155-
session.delete(entity)
154+
session = SessionContextHolder.get_or_create_session()
155+
entity = self._find_by_query({"id": _id})
156+
if entity is None:
157+
return False
158+
session.delete(entity)
156159
return True
157160

161+
@Transactional
158162
def delete_all_by_ids(self, ids: list[ID]) -> bool:
159-
with self.create_managed_session() as session:
160-
statement = select(self.model_class).where(self.model_class.id.in_(ids)) # type: ignore
161-
_, deleted_entities = self._find_all_by_statement(statement, session)
162-
if deleted_entities is None:
163-
return False
164-
for entity in deleted_entities:
165-
session.delete(entity)
166-
return True
163+
session = SessionContextHolder.get_or_create_session()
164+
statement = select(self.model_class).where(self.model_class.id.in_(ids)) # type: ignore
165+
deleted_entities = self._find_all_by_statement(statement)
166+
if deleted_entities is None:
167+
return False
168+
for entity in deleted_entities:
169+
session.delete(entity)
170+
return True
167171

172+
@Transactional
168173
def upsert(self, entity: T, query_by: dict[str, Any]) -> T:
169-
with self.create_managed_session() as session:
170-
statement = select(self.model_class).filter_by(**query_by) # type: ignore
171-
_, existing_entity = self._find_by_statement(statement, session)
172-
if existing_entity is not None:
173-
# If the entity exists, update its attributes
174-
for key, value in entity.model_dump().items():
175-
setattr(existing_entity, key, value)
176-
session.add(existing_entity)
177-
else:
178-
# If the entity does not exist, insert it
179-
session.add(entity)
180-
return entity
174+
session = SessionContextHolder.get_or_create_session()
175+
statement = select(self.model_class).filter_by(**query_by) # type: ignore
176+
existing_entity = self._find_by_statement(statement)
177+
if existing_entity is not None:
178+
# If the entity exists, update its attributes
179+
for key, value in entity.model_dump().items():
180+
setattr(existing_entity, key, value)
181+
session.add(existing_entity)
182+
else:
183+
# If the entity does not exist, insert it
184+
session.add(entity)
185+
return entity

tests/test_crud_repository.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,12 +57,12 @@ def test_find_all(self, user_repository: UserRepository):
5757

5858
def test_find_by_query(self, user_repository: UserRepository):
5959
self.create_test_user(user_repository)
60-
_, user = user_repository._find_by_query({"name": "John Doe"})
60+
user = user_repository._find_by_query({"name": "John Doe"})
6161
assert user is not None
6262
assert user.id == 1
6363
assert user.name == "John Doe"
6464

65-
_, email_user = user_repository._find_by_query({"email": "john@example.com"})
65+
email_user = user_repository._find_by_query({"email": "john@example.com"})
6666
assert email_user is not None
6767
assert email_user.id == 1
6868
assert email_user.email == "john@example.com"

0 commit comments

Comments
 (0)