Skip to content

Commit da4ccde

Browse files
authored
Add Modifiable Query Support with Commit Control (#9)
1 parent b8ad9dc commit da4ccde

File tree

3 files changed

+723
-6
lines changed

3 files changed

+723
-6
lines changed

py_spring_model/core/model.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def create_session(cls) -> PySpringSession:
9191

9292
@classmethod
9393
@contextlib.contextmanager
94-
def create_managed_session(cls) -> Iterator[PySpringSession]:
94+
def create_managed_session(cls, should_commit: bool = True) -> Iterator[PySpringSession]:
9595
"""
9696
Creates a managed session context that will automatically close the session when the context is exited.
9797
## Example Syntax:
@@ -103,7 +103,8 @@ def create_managed_session(cls) -> Iterator[PySpringSession]:
103103
session = cls.create_session()
104104
yield session
105105
logger.debug("[MANAGED SESSION COMMIT] Session committing...")
106-
session.commit()
106+
if should_commit:
107+
session.commit()
107108
logger.debug(
108109
"[MANAGED SESSION COMMIT] Session committed, refreshing instances..."
109110
)

py_spring_model/py_spring_model_rest/service/query_service/query.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,12 @@
2323

2424
class QueryExecutionService:
2525
@classmethod
26-
def execute_query(cls, query_template: str, func: Callable[P, RT], kwargs: dict) -> RT:
26+
def execute_query(cls,
27+
query_template: str,
28+
func: Callable[P, RT],
29+
kwargs: dict,
30+
is_modifying: bool
31+
) -> RT:
2732
RETURN = "return"
2833

2934
annotations = func.__annotations__
@@ -42,7 +47,7 @@ def execute_query(cls, query_template: str, func: Callable[P, RT], kwargs: dict)
4247
processed_kwargs = cls._process_kwargs(kwargs)
4348

4449
sql = query_template.format(**processed_kwargs)
45-
with PySpringModel.create_session() as session:
50+
with PySpringModel.create_managed_session(should_commit=is_modifying) as session:
4651
reutrn_origin = get_origin(return_type)
4752
return_args = get_args(return_type)
4853

@@ -95,7 +100,7 @@ def _validate_return_type(cls, actual_type, return_type):
95100
def _process_single_result(cls, result: Row, actual_type: Type[BaseModel]) -> Optional[BaseModel]:
96101
return actual_type.model_validate(result._asdict())
97102

98-
def Query(query_template: str) -> Callable[[Callable[P, RT]], Callable[P, RT]]:
103+
def Query(query_template: str, is_modifying: bool = False) -> Callable[[Callable[P, RT]], Callable[P, RT]]:
99104
"""
100105
Decorator to mark a method as a query method.
101106
The method will be implemented automatically by the `CrudRepositoryImplementationService`.
@@ -130,6 +135,6 @@ def decorator(func: Callable[P, RT]) -> Callable[P, RT]:
130135
@functools.wraps(func)
131136
def wrapper(*args: P.args, **kwargs: P.kwargs) -> RT:
132137
nonlocal query_template
133-
return QueryExecutionService.execute_query(query_template, func, kwargs)
138+
return QueryExecutionService.execute_query(query_template, func, kwargs, is_modifying)
134139
return wrapper
135140
return decorator

0 commit comments

Comments
 (0)