|
16 | 16 | from sqlmodel.sql.expression import Select, SelectOfScalar |
17 | 17 |
|
18 | 18 | from py_spring_model.core.model import PySpringModel |
| 19 | +from py_spring_model.core.session_context_holder import SessionContextHolder, Transactional |
19 | 20 | from py_spring_model.repository.repository_base import RepositoryBase |
20 | 21 |
|
21 | 22 | T = TypeVar("T", bound=PySpringModel) |
@@ -61,120 +62,124 @@ def _get_model_id_type_with_class(cls) -> tuple[Type[ID], Type[T]]: |
61 | 62 | def _find_by_statement( |
62 | 63 | self, |
63 | 64 | statement: Union[Select, SelectOfScalar], |
64 | | - session: Optional[Session] = None, |
65 | | - ) -> tuple[Session, Optional[T]]: |
66 | | - session = session or self._create_session() |
| 65 | + ) -> Optional[T]: |
| 66 | + session = SessionContextHolder.get_or_create_session() |
| 67 | + |
| 68 | + return session.exec(statement).first() |
67 | 69 |
|
68 | | - return session, session.exec(statement).first() |
69 | 70 |
|
70 | 71 | def _find_by_query( |
71 | 72 | self, |
72 | 73 | query_by: dict[str, Any], |
73 | | - session: Optional[Session] = None, |
74 | | - ) -> tuple[Session, Optional[T]]: |
75 | | - session = session or self._create_session() |
| 74 | + ) -> Optional[T]: |
| 75 | + session = SessionContextHolder.get_or_create_session() |
76 | 76 | statement = select(self.model_class).filter_by(**query_by) |
77 | | - return session, session.exec(statement).first() |
| 77 | + return session.exec(statement).first() |
78 | 78 |
|
79 | 79 | def _find_all_by_query( |
80 | 80 | self, |
81 | 81 | query_by: dict[str, Any], |
82 | | - session: Optional[Session] = None, |
83 | 82 | ) -> tuple[Session, list[T]]: |
84 | | - session = session or self._create_session() |
| 83 | + session = SessionContextHolder.get_or_create_session() |
85 | 84 | statement = select(self.model_class).filter_by(**query_by) |
86 | 85 | return session, list(session.exec(statement).fetchall()) |
87 | 86 |
|
88 | 87 | def _find_all_by_statement( |
89 | 88 | self, |
90 | 89 | statement: Union[Select, SelectOfScalar], |
91 | | - session: Optional[Session] = None, |
92 | | - ) -> tuple[Session, list[T]]: |
93 | | - session = session or self._create_session() |
94 | | - return session, list(session.exec(statement).fetchall()) |
| 90 | + ) -> list[T]: |
| 91 | + session = SessionContextHolder.get_or_create_session() |
| 92 | + return list(session.exec(statement).fetchall()) |
95 | 93 |
|
96 | 94 | def find_by_id(self, id: ID) -> Optional[T]: |
97 | | - with self.create_managed_session() as session: |
98 | | - statement = select(self.model_class).where(self.model_class.id == id) # type: ignore |
99 | | - optional_entity = session.exec(statement).first() |
100 | | - if optional_entity is None: |
101 | | - return |
102 | | - |
103 | | - return optional_entity.clone() # type: ignore |
| 95 | + session = SessionContextHolder.get_or_create_session() |
| 96 | + statement = select(self.model_class).where(self.model_class.id == id) # type: ignore |
| 97 | + optional_entity = session.exec(statement).first() |
| 98 | + if optional_entity is None: |
| 99 | + return |
104 | 100 |
|
| 101 | + return optional_entity |
105 | 102 | def find_all_by_ids(self, ids: list[ID]) -> list[T]: |
106 | | - with self.create_managed_session() as session: |
107 | | - statement = select(self.model_class).where(self.model_class.id.in_(ids)) # type: ignore |
108 | | - return [entity.clone() for entity in session.exec(statement).all()] # type: ignore |
| 103 | + session = SessionContextHolder.get_or_create_session() |
| 104 | + statement = select(self.model_class).where(self.model_class.id.in_(ids)) # type: ignore |
| 105 | + return [entity for entity in session.exec(statement).all()] # type: ignore |
109 | 106 |
|
110 | 107 | def find_all(self) -> list[T]: |
111 | | - with self.create_managed_session() as session: |
112 | | - statement = select(self.model_class) # type: ignore |
113 | | - return [entity.clone() for entity in session.exec(statement).all()] # type: ignore |
| 108 | + session = SessionContextHolder.get_or_create_session() |
| 109 | + statement = select(self.model_class) # type: ignore |
| 110 | + return [entity for entity in session.exec(statement).all()] # type: ignore |
114 | 111 |
|
| 112 | + @Transactional |
115 | 113 | def save(self, entity: T) -> T: |
116 | | - with self.create_managed_session() as session: |
117 | | - session.add(entity) |
118 | | - return entity.clone() # type: ignore |
| 114 | + session = SessionContextHolder.get_or_create_session() |
| 115 | + session.add(entity) |
| 116 | + return entity |
119 | 117 |
|
| 118 | + @Transactional |
120 | 119 | def save_all( |
121 | 120 | self, |
122 | 121 | entities: Iterable[T], |
123 | 122 | ) -> bool: |
124 | | - with self.create_managed_session() as session: |
125 | | - session.add_all(entities) |
| 123 | + session = SessionContextHolder.get_or_create_session() |
| 124 | + session.add_all(entities) |
126 | 125 | return True |
127 | 126 |
|
| 127 | + @Transactional |
128 | 128 | def delete(self, entity: T) -> bool: |
129 | | - with self.create_managed_session() as session: |
130 | | - _, optional_intance = self._find_by_query(entity.model_dump(), session) |
131 | | - if optional_intance is None: |
132 | | - return False |
133 | | - session.delete(optional_intance) |
| 129 | + session = SessionContextHolder.get_or_create_session() |
| 130 | + optional_intance = self._find_by_query(entity.model_dump()) |
| 131 | + if optional_intance is None: |
| 132 | + return False |
| 133 | + session.delete(optional_intance) |
134 | 134 | return True |
135 | 135 |
|
| 136 | + @Transactional |
136 | 137 | def delete_all(self, entities: Iterable[T]) -> bool: |
137 | | - with self.create_managed_session() as session: |
138 | | - ids = [entity.id for entity in entities] # type: ignore |
139 | | - |
140 | | - statement = select(self.model_class).where(self.model_class.id.in_(ids)) # type: ignore |
141 | | - _, deleted_entities = self._find_all_by_statement(statement, session) |
142 | | - if deleted_entities is None: |
143 | | - return False |
144 | | - |
145 | | - for entity in deleted_entities: |
146 | | - session.delete(entity) |
| 138 | + session = SessionContextHolder.get_or_create_session() |
| 139 | + ids = [entity.id for entity in entities] # type: ignore |
| 140 | + |
| 141 | + statement = select(self.model_class).where(self.model_class.id.in_(ids)) # type: ignore |
| 142 | + deleted_entities = self._find_all_by_statement(statement) |
| 143 | + if deleted_entities is None: |
| 144 | + return False |
| 145 | + |
| 146 | + for entity in deleted_entities: |
| 147 | + session.delete(entity) |
147 | 148 |
|
148 | 149 | return True |
149 | 150 |
|
| 151 | + |
| 152 | + @Transactional |
150 | 153 | def delete_by_id(self, _id: ID) -> bool: |
151 | | - with self.create_managed_session() as session: |
152 | | - _, entity = self._find_by_query({"id": _id}, session) |
153 | | - if entity is None: |
154 | | - return False |
155 | | - session.delete(entity) |
| 154 | + session = SessionContextHolder.get_or_create_session() |
| 155 | + entity = self._find_by_query({"id": _id}) |
| 156 | + if entity is None: |
| 157 | + return False |
| 158 | + session.delete(entity) |
156 | 159 | return True |
157 | 160 |
|
| 161 | + @Transactional |
158 | 162 | def delete_all_by_ids(self, ids: list[ID]) -> bool: |
159 | | - with self.create_managed_session() as session: |
160 | | - statement = select(self.model_class).where(self.model_class.id.in_(ids)) # type: ignore |
161 | | - _, deleted_entities = self._find_all_by_statement(statement, session) |
162 | | - if deleted_entities is None: |
163 | | - return False |
164 | | - for entity in deleted_entities: |
165 | | - session.delete(entity) |
166 | | - return True |
| 163 | + session = SessionContextHolder.get_or_create_session() |
| 164 | + statement = select(self.model_class).where(self.model_class.id.in_(ids)) # type: ignore |
| 165 | + deleted_entities = self._find_all_by_statement(statement) |
| 166 | + if deleted_entities is None: |
| 167 | + return False |
| 168 | + for entity in deleted_entities: |
| 169 | + session.delete(entity) |
| 170 | + return True |
167 | 171 |
|
| 172 | + @Transactional |
168 | 173 | def upsert(self, entity: T, query_by: dict[str, Any]) -> T: |
169 | | - with self.create_managed_session() as session: |
170 | | - statement = select(self.model_class).filter_by(**query_by) # type: ignore |
171 | | - _, existing_entity = self._find_by_statement(statement, session) |
172 | | - if existing_entity is not None: |
173 | | - # If the entity exists, update its attributes |
174 | | - for key, value in entity.model_dump().items(): |
175 | | - setattr(existing_entity, key, value) |
176 | | - session.add(existing_entity) |
177 | | - else: |
178 | | - # If the entity does not exist, insert it |
179 | | - session.add(entity) |
180 | | - return entity |
| 174 | + session = SessionContextHolder.get_or_create_session() |
| 175 | + statement = select(self.model_class).filter_by(**query_by) # type: ignore |
| 176 | + existing_entity = self._find_by_statement(statement) |
| 177 | + if existing_entity is not None: |
| 178 | + # If the entity exists, update its attributes |
| 179 | + for key, value in entity.model_dump().items(): |
| 180 | + setattr(existing_entity, key, value) |
| 181 | + session.add(existing_entity) |
| 182 | + else: |
| 183 | + # If the entity does not exist, insert it |
| 184 | + session.add(entity) |
| 185 | + return entity |
0 commit comments