Skip to content

Commit e7656db

Browse files
refactor: Simplify PySpringModelProvider by removing unused methods and enhancing initialization logic for model classes and table creation.
1 parent 8132a37 commit e7656db

File tree

2 files changed

+19
-92
lines changed

2 files changed

+19
-92
lines changed

py_spring_model/core/commons.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,9 @@ class PySpringModelProperties(Properties):
2323
This class defines properties specific to the PySpring Model, including:
2424
2525
- `__key__`: The key used to identify this set of properties.
26-
- `model_file_postfix_patterns`: A set of strings representing file name patterns for model files.
2726
- `sqlalchemy_database_uri`: The SQLAlchemy database URI used for the model.
2827
"""
2928

3029
__key__ = "py_spring_model"
31-
model_file_postfix_patterns: set[str]
3230
sqlalchemy_database_uri: str
3331
create_all_tables: bool = True

py_spring_model/py_spring_model_provider.py

Lines changed: 19 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,15 @@
1-
import inspect
2-
from typing import Iterable, Type, cast
1+
from typing import Type, cast
32

4-
import py_spring_core.core.utils as core_utils
53
from loguru import logger
64
from py_spring_core import Component, EntityProvider, ApplicationContextRequired
75
from sqlalchemy import create_engine
8-
from sqlalchemy.exc import InvalidRequestError as SqlAlehemyInvalidRequestError
96
from sqlmodel import SQLModel
107

11-
from py_spring_model.core.commons import ApplicationFileGroups, PySpringModelProperties
8+
from py_spring_model.core.commons import PySpringModelProperties
129
from py_spring_model.core.model import PySpringModel
1310
from py_spring_model.py_spring_model_rest.controller.session_controller import SessionController
1411
from py_spring_model.repository.repository_base import RepositoryBase
1512
from py_spring_model.py_spring_model_rest import PySpringModelRestService
16-
from py_spring_model.py_spring_model_rest.controller.py_spring_model_rest_controller import (
17-
PySpringModelRestController,
18-
)
1913
from py_spring_model.py_spring_model_rest.service.curd_repository_implementation_service.crud_repository_implementation_service import (
2014
CrudRepositoryImplementationService,
2115
)
@@ -43,94 +37,17 @@ def _get_props(self) -> PySpringModelProperties:
4337
assert props is not None
4438
return props
4539

46-
def _group_file_paths(self, files: Iterable[str]) -> ApplicationFileGroups:
47-
props = self._get_props()
48-
49-
class_files: set[str] = set()
50-
model_files: set[str] = set()
51-
52-
for file in files:
53-
py_file_name = self._get_file_base_name(file)
54-
if py_file_name in props.model_file_postfix_patterns:
55-
model_files.add(file)
56-
if file not in model_files:
57-
class_files.add(file)
58-
return ApplicationFileGroups(class_files=class_files, model_files=model_files)
59-
60-
def _import_model_modules(self) -> None:
61-
logger.info(
62-
f"[SQLMODEL TABLE MODEL IMPORT] Import all models: {self.app_file_groups.model_files}"
63-
)
64-
65-
def import_func_wrapper() -> set[type[object]]:
66-
return core_utils.dynamically_import_modules(
67-
self.app_file_groups.model_files,
68-
is_ignore_error=False,
69-
target_subclasses=[PySpringModel, SQLModel],
70-
)
71-
72-
try:
73-
self._model_classes = import_func_wrapper()
74-
except SqlAlehemyInvalidRequestError as error:
75-
logger.warning(
76-
f"[ERROR ADVISE] Encounter {error.__class__.__name__} when importing model classes."
77-
)
78-
logger.error(
79-
f"[SQLMODEL TABLE MODEL IMPORT FAILED] Failed to import model modules: {error}"
80-
)
81-
self._model_classes = self._get_pyspring_model_inheritors()
82-
83-
def _is_from_model_file(self, cls: Type[object]) -> bool:
84-
props = self._get_props()
85-
try:
86-
source_file_name = inspect.getsourcefile(cls)
87-
except TypeError as error:
88-
logger.warning(
89-
f"[CHECK MODEL FILE] Failed to get source file name for class: {cls.__name__}, largely due to built-in classes.\n Actual error: {error}"
90-
)
91-
return False
92-
if source_file_name is None:
93-
return False
94-
py_file_name = self._get_file_base_name(source_file_name) # e.g., models.py
95-
return py_file_name in props.model_file_postfix_patterns
96-
97-
def _get_file_base_name(self, file_path: str) -> str:
98-
return file_path.split("/")[-1]
99-
10040
def _get_pyspring_model_inheritors(self) -> set[Type[object]]:
10141
# use dict to store all models, use a session to check if all models are mapped
10242
class_name_with_class_map: dict[str, Type[object]] = {}
10343
for _cls in set(PySpringModel.__subclasses__()):
10444
if _cls.__name__ in class_name_with_class_map:
10545
continue
106-
if not self._is_from_model_file(_cls):
107-
logger.warning(
108-
f"[SQLMODEL TABLE MODEL IMPORT] {_cls.__name__} is not from model file, skip it."
109-
)
110-
continue
111-
11246
class_name_with_class_map[_cls.__name__] = _cls
11347

11448
return set(class_name_with_class_map.values())
115-
49+
11650
def _create_all_tables(self) -> None:
117-
props = self._get_props()
118-
119-
120-
PySpringModel.set_engine(self.sql_engine)
121-
PySpringModel.set_models(
122-
cast(list[Type[PySpringModel]], list(self._model_classes))
123-
)
124-
PySpringModel.set_metadata(SQLModel.metadata)
125-
RepositoryBase.engine = self.sql_engine
126-
RepositoryBase.connection = self.sql_engine.connect()
127-
128-
if not props.create_all_tables:
129-
logger.info("[SQLMODEL TABLE CREATION] Skip creating all tables, set create_all_tables to True to enable.")
130-
return
131-
if not props.create_all_tables:
132-
return
133-
13451
table_names = SQLModel.metadata.tables.keys()
13552
logger.success(
13653
f"[SQLMODEL TABLE CREATION] Create all SQLModel tables, engine url: {self.sql_engine.url}, tables: {', '.join(table_names)}"
@@ -140,18 +57,30 @@ def _create_all_tables(self) -> None:
14057
f"[SQLMODEL TABLE MODEL IMPORT] Get model classes from PySpringModel inheritors: {', '.join([_cls.__name__ for _cls in self._model_classes])}"
14158
)
14259

60+
61+
def _init_pyspring_model(self) -> None:
62+
self._model_classes = self._get_pyspring_model_inheritors()
63+
PySpringModel.set_engine(self.sql_engine)
64+
PySpringModel.set_models(
65+
cast(list[Type[PySpringModel]], list(self._model_classes))
66+
)
67+
PySpringModel.set_metadata(SQLModel.metadata)
68+
RepositoryBase.engine = self.sql_engine
69+
RepositoryBase.connection = self.sql_engine.connect()
70+
14371
def provider_init(self) -> None:
14472
props = self._get_props()
14573
logger.info(
14674
f"[PYSPRING MODEL PROVIDER INIT] Initialize PySpringModelProvider with app context: {self.app_context}"
14775
)
148-
app_context = self.get_application_context()
149-
150-
self.app_file_groups = self._group_file_paths(app_context.all_file_paths)
15176
self.sql_engine = create_engine(
15277
url=props.sqlalchemy_database_uri, echo=True
15378
)
154-
self._import_model_modules()
79+
self._init_pyspring_model()
80+
props = self._get_props()
81+
if not props.create_all_tables:
82+
logger.info("[SQLMODEL TABLE CREATION] Skip creating all tables, set create_all_tables to True to enable.")
83+
return
15584
self._create_all_tables()
15685

15786

0 commit comments

Comments
 (0)