| 
16 | 16 | from sqlmodel.sql.expression import Select, SelectOfScalar  | 
17 | 17 | 
 
  | 
18 | 18 | from py_spring_model.core.model import PySpringModel  | 
 | 19 | +from py_spring_model.core.session_context_holder import SessionContextHolder, Transactional  | 
19 | 20 | from py_spring_model.repository.repository_base import RepositoryBase  | 
20 | 21 | 
 
  | 
21 | 22 | T = TypeVar("T", bound=PySpringModel)  | 
@@ -61,120 +62,124 @@ def _get_model_id_type_with_class(cls) -> tuple[Type[ID], Type[T]]:  | 
61 | 62 |     def _find_by_statement(  | 
62 | 63 |         self,  | 
63 | 64 |         statement: Union[Select, SelectOfScalar],  | 
64 |  | -        session: Optional[Session] = None,  | 
65 |  | -    ) -> tuple[Session, Optional[T]]:  | 
66 |  | -        session = session or self._create_session()  | 
 | 65 | +    ) -> Optional[T]:  | 
 | 66 | +        session = SessionContextHolder.get_or_create_session()  | 
 | 67 | + | 
 | 68 | +        return session.exec(statement).first()  | 
67 | 69 | 
 
  | 
68 |  | -        return session, session.exec(statement).first()  | 
69 | 70 | 
 
  | 
70 | 71 |     def _find_by_query(  | 
71 | 72 |         self,  | 
72 | 73 |         query_by: dict[str, Any],  | 
73 |  | -        session: Optional[Session] = None,  | 
74 |  | -    ) -> tuple[Session, Optional[T]]:  | 
75 |  | -        session = session or self._create_session()  | 
 | 74 | +    ) -> Optional[T]:  | 
 | 75 | +        session = SessionContextHolder.get_or_create_session()  | 
76 | 76 |         statement = select(self.model_class).filter_by(**query_by)  | 
77 |  | -        return session, session.exec(statement).first()  | 
 | 77 | +        return session.exec(statement).first()  | 
78 | 78 | 
 
  | 
79 | 79 |     def _find_all_by_query(  | 
80 | 80 |         self,  | 
81 | 81 |         query_by: dict[str, Any],  | 
82 |  | -        session: Optional[Session] = None,  | 
83 | 82 |     ) -> tuple[Session, list[T]]:  | 
84 |  | -        session = session or self._create_session()  | 
 | 83 | +        session = SessionContextHolder.get_or_create_session()  | 
85 | 84 |         statement = select(self.model_class).filter_by(**query_by)  | 
86 | 85 |         return session, list(session.exec(statement).fetchall())  | 
87 | 86 | 
 
  | 
88 | 87 |     def _find_all_by_statement(  | 
89 | 88 |         self,  | 
90 | 89 |         statement: Union[Select, SelectOfScalar],  | 
91 |  | -        session: Optional[Session] = None,  | 
92 |  | -    ) -> tuple[Session, list[T]]:  | 
93 |  | -        session = session or self._create_session()  | 
94 |  | -        return session, list(session.exec(statement).fetchall())  | 
 | 90 | +    ) -> list[T]:  | 
 | 91 | +        session = SessionContextHolder.get_or_create_session()  | 
 | 92 | +        return list(session.exec(statement).fetchall())  | 
95 | 93 | 
 
  | 
96 | 94 |     def find_by_id(self, id: ID) -> Optional[T]:  | 
97 |  | -        with self.create_managed_session() as session:  | 
98 |  | -            statement = select(self.model_class).where(self.model_class.id == id)  # type: ignore  | 
99 |  | -            optional_entity = session.exec(statement).first()  | 
100 |  | -            if optional_entity is None:  | 
101 |  | -                return  | 
102 |  | - | 
103 |  | -            return optional_entity.clone()  # type: ignore  | 
 | 95 | +        session = SessionContextHolder.get_or_create_session()  | 
 | 96 | +        statement = select(self.model_class).where(self.model_class.id == id)  # type: ignore  | 
 | 97 | +        optional_entity = session.exec(statement).first()  | 
 | 98 | +        if optional_entity is None:  | 
 | 99 | +            return  | 
104 | 100 | 
 
  | 
 | 101 | +        return optional_entity  | 
105 | 102 |     def find_all_by_ids(self, ids: list[ID]) -> list[T]:  | 
106 |  | -        with self.create_managed_session() as session:  | 
107 |  | -            statement = select(self.model_class).where(self.model_class.id.in_(ids))  # type: ignore  | 
108 |  | -            return [entity.clone() for entity in session.exec(statement).all()]  # type: ignore  | 
 | 103 | +        session = SessionContextHolder.get_or_create_session()  | 
 | 104 | +        statement = select(self.model_class).where(self.model_class.id.in_(ids))  # type: ignore  | 
 | 105 | +        return [entity for entity in session.exec(statement).all()]  # type: ignore  | 
109 | 106 | 
 
  | 
110 | 107 |     def find_all(self) -> list[T]:  | 
111 |  | -        with self.create_managed_session() as session:  | 
112 |  | -            statement = select(self.model_class)  # type: ignore  | 
113 |  | -            return [entity.clone() for entity in session.exec(statement).all()]  # type: ignore  | 
 | 108 | +        session = SessionContextHolder.get_or_create_session()  | 
 | 109 | +        statement = select(self.model_class)  # type: ignore  | 
 | 110 | +        return [entity for entity in session.exec(statement).all()]  # type: ignore  | 
114 | 111 | 
 
  | 
 | 112 | +    @Transactional  | 
115 | 113 |     def save(self, entity: T) -> T:  | 
116 |  | -        with self.create_managed_session() as session:  | 
117 |  | -            session.add(entity)  | 
118 |  | -        return entity.clone()  # type: ignore  | 
 | 114 | +        session = SessionContextHolder.get_or_create_session()  | 
 | 115 | +        session.add(entity)  | 
 | 116 | +        return entity  | 
119 | 117 | 
 
  | 
 | 118 | +    @Transactional  | 
120 | 119 |     def save_all(  | 
121 | 120 |         self,  | 
122 | 121 |         entities: Iterable[T],  | 
123 | 122 |     ) -> bool:  | 
124 |  | -        with self.create_managed_session() as session:  | 
125 |  | -            session.add_all(entities)  | 
 | 123 | +        session = SessionContextHolder.get_or_create_session()  | 
 | 124 | +        session.add_all(entities)  | 
126 | 125 |         return True  | 
127 | 126 | 
 
  | 
 | 127 | +    @Transactional  | 
128 | 128 |     def delete(self, entity: T) -> bool:  | 
129 |  | -        with self.create_managed_session() as session:  | 
130 |  | -            _, optional_intance = self._find_by_query(entity.model_dump(), session)  | 
131 |  | -            if optional_intance is None:  | 
132 |  | -                return False  | 
133 |  | -            session.delete(optional_intance)  | 
 | 129 | +        session = SessionContextHolder.get_or_create_session()  | 
 | 130 | +        optional_intance = self._find_by_query(entity.model_dump())  | 
 | 131 | +        if optional_intance is None:  | 
 | 132 | +            return False  | 
 | 133 | +        session.delete(optional_intance)  | 
134 | 134 |         return True  | 
135 | 135 | 
 
  | 
 | 136 | +    @Transactional  | 
136 | 137 |     def delete_all(self, entities: Iterable[T]) -> bool:  | 
137 |  | -        with self.create_managed_session() as session:  | 
138 |  | -            ids = [entity.id for entity in entities] # type: ignore  | 
139 |  | -              | 
140 |  | -            statement = select(self.model_class).where(self.model_class.id.in_(ids))  # type: ignore  | 
141 |  | -            _, deleted_entities = self._find_all_by_statement(statement, session)  | 
142 |  | -            if deleted_entities is None:  | 
143 |  | -                return False  | 
144 |  | -              | 
145 |  | -            for entity in deleted_entities:  | 
146 |  | -                session.delete(entity)  | 
 | 138 | +        session = SessionContextHolder.get_or_create_session()  | 
 | 139 | +        ids = [entity.id for entity in entities] # type: ignore  | 
 | 140 | +          | 
 | 141 | +        statement = select(self.model_class).where(self.model_class.id.in_(ids))  # type: ignore  | 
 | 142 | +        deleted_entities = self._find_all_by_statement(statement)  | 
 | 143 | +        if deleted_entities is None:  | 
 | 144 | +            return False  | 
 | 145 | +          | 
 | 146 | +        for entity in deleted_entities:  | 
 | 147 | +            session.delete(entity)  | 
147 | 148 | 
 
  | 
148 | 149 |         return True  | 
149 | 150 | 
 
  | 
 | 151 | + | 
 | 152 | +    @Transactional  | 
150 | 153 |     def delete_by_id(self, _id: ID) -> bool:  | 
151 |  | -        with self.create_managed_session() as session:  | 
152 |  | -            _, entity = self._find_by_query({"id": _id}, session)  | 
153 |  | -            if entity is None:  | 
154 |  | -                return False  | 
155 |  | -            session.delete(entity)  | 
 | 154 | +        session = SessionContextHolder.get_or_create_session()  | 
 | 155 | +        entity = self._find_by_query({"id": _id})  | 
 | 156 | +        if entity is None:  | 
 | 157 | +            return False  | 
 | 158 | +        session.delete(entity)  | 
156 | 159 |         return True  | 
157 | 160 | 
 
  | 
 | 161 | +    @Transactional  | 
158 | 162 |     def delete_all_by_ids(self, ids: list[ID]) -> bool:  | 
159 |  | -        with self.create_managed_session() as session:  | 
160 |  | -            statement = select(self.model_class).where(self.model_class.id.in_(ids))  # type: ignore  | 
161 |  | -            _, deleted_entities = self._find_all_by_statement(statement, session)  | 
162 |  | -            if deleted_entities is None:  | 
163 |  | -                return False  | 
164 |  | -            for entity in deleted_entities:  | 
165 |  | -                session.delete(entity)  | 
166 |  | -            return True  | 
 | 163 | +        session = SessionContextHolder.get_or_create_session()  | 
 | 164 | +        statement = select(self.model_class).where(self.model_class.id.in_(ids))  # type: ignore  | 
 | 165 | +        deleted_entities = self._find_all_by_statement(statement)  | 
 | 166 | +        if deleted_entities is None:  | 
 | 167 | +            return False  | 
 | 168 | +        for entity in deleted_entities:  | 
 | 169 | +            session.delete(entity)  | 
 | 170 | +        return True  | 
167 | 171 | 
 
  | 
 | 172 | +    @Transactional  | 
168 | 173 |     def upsert(self, entity: T, query_by: dict[str, Any]) -> T:  | 
169 |  | -        with self.create_managed_session() as session:  | 
170 |  | -            statement = select(self.model_class).filter_by(**query_by)  # type: ignore  | 
171 |  | -            _, existing_entity = self._find_by_statement(statement, session)  | 
172 |  | -            if existing_entity is not None:  | 
173 |  | -                # If the entity exists, update its attributes  | 
174 |  | -                for key, value in entity.model_dump().items():  | 
175 |  | -                    setattr(existing_entity, key, value)  | 
176 |  | -                session.add(existing_entity)  | 
177 |  | -            else:  | 
178 |  | -                # If the entity does not exist, insert it  | 
179 |  | -                session.add(entity)  | 
180 |  | -            return entity  | 
 | 174 | +        session = SessionContextHolder.get_or_create_session()  | 
 | 175 | +        statement = select(self.model_class).filter_by(**query_by)  # type: ignore  | 
 | 176 | +        existing_entity = self._find_by_statement(statement)  | 
 | 177 | +        if existing_entity is not None:  | 
 | 178 | +            # If the entity exists, update its attributes  | 
 | 179 | +            for key, value in entity.model_dump().items():  | 
 | 180 | +                setattr(existing_entity, key, value)  | 
 | 181 | +            session.add(existing_entity)  | 
 | 182 | +        else:  | 
 | 183 | +            # If the entity does not exist, insert it  | 
 | 184 | +            session.add(entity)  | 
 | 185 | +        return entity  | 
0 commit comments