1- import inspect
2- from typing import Iterable , Type , cast
1+ from typing import Type , cast
32
4- import py_spring_core .core .utils as core_utils
53from loguru import logger
64from py_spring_core import Component , EntityProvider , ApplicationContextRequired
75from sqlalchemy import create_engine
8- from sqlalchemy .exc import InvalidRequestError as SqlAlehemyInvalidRequestError
96from sqlmodel import SQLModel
107
11- from py_spring_model .core .commons import ApplicationFileGroups , PySpringModelProperties
8+ from py_spring_model .core .commons import PySpringModelProperties
129from py_spring_model .core .model import PySpringModel
1310from py_spring_model .py_spring_model_rest .controller .session_controller import SessionController
1411from py_spring_model .repository .repository_base import RepositoryBase
1512from py_spring_model .py_spring_model_rest import PySpringModelRestService
16- from py_spring_model .py_spring_model_rest .controller .py_spring_model_rest_controller import (
17- PySpringModelRestController ,
18- )
1913from py_spring_model .py_spring_model_rest .service .curd_repository_implementation_service .crud_repository_implementation_service import (
2014 CrudRepositoryImplementationService ,
2115)
@@ -43,94 +37,17 @@ def _get_props(self) -> PySpringModelProperties:
4337 assert props is not None
4438 return props
4539
46- def _group_file_paths (self , files : Iterable [str ]) -> ApplicationFileGroups :
47- props = self ._get_props ()
48-
49- class_files : set [str ] = set ()
50- model_files : set [str ] = set ()
51-
52- for file in files :
53- py_file_name = self ._get_file_base_name (file )
54- if py_file_name in props .model_file_postfix_patterns :
55- model_files .add (file )
56- if file not in model_files :
57- class_files .add (file )
58- return ApplicationFileGroups (class_files = class_files , model_files = model_files )
59-
60- def _import_model_modules (self ) -> None :
61- logger .info (
62- f"[SQLMODEL TABLE MODEL IMPORT] Import all models: { self .app_file_groups .model_files } "
63- )
64-
65- def import_func_wrapper () -> set [type [object ]]:
66- return core_utils .dynamically_import_modules (
67- self .app_file_groups .model_files ,
68- is_ignore_error = False ,
69- target_subclasses = [PySpringModel , SQLModel ],
70- )
71-
72- try :
73- self ._model_classes = import_func_wrapper ()
74- except SqlAlehemyInvalidRequestError as error :
75- logger .warning (
76- f"[ERROR ADVISE] Encounter { error .__class__ .__name__ } when importing model classes."
77- )
78- logger .error (
79- f"[SQLMODEL TABLE MODEL IMPORT FAILED] Failed to import model modules: { error } "
80- )
81- self ._model_classes = self ._get_pyspring_model_inheritors ()
82-
83- def _is_from_model_file (self , cls : Type [object ]) -> bool :
84- props = self ._get_props ()
85- try :
86- source_file_name = inspect .getsourcefile (cls )
87- except TypeError as error :
88- logger .warning (
89- f"[CHECK MODEL FILE] Failed to get source file name for class: { cls .__name__ } , largely due to built-in classes.\n Actual error: { error } "
90- )
91- return False
92- if source_file_name is None :
93- return False
94- py_file_name = self ._get_file_base_name (source_file_name ) # e.g., models.py
95- return py_file_name in props .model_file_postfix_patterns
96-
97- def _get_file_base_name (self , file_path : str ) -> str :
98- return file_path .split ("/" )[- 1 ]
99-
10040 def _get_pyspring_model_inheritors (self ) -> set [Type [object ]]:
10141 # use dict to store all models, use a session to check if all models are mapped
10242 class_name_with_class_map : dict [str , Type [object ]] = {}
10343 for _cls in set (PySpringModel .__subclasses__ ()):
10444 if _cls .__name__ in class_name_with_class_map :
10545 continue
106- if not self ._is_from_model_file (_cls ):
107- logger .warning (
108- f"[SQLMODEL TABLE MODEL IMPORT] { _cls .__name__ } is not from model file, skip it."
109- )
110- continue
111-
11246 class_name_with_class_map [_cls .__name__ ] = _cls
11347
11448 return set (class_name_with_class_map .values ())
115-
49+
11650 def _create_all_tables (self ) -> None :
117- props = self ._get_props ()
118-
119-
120- PySpringModel .set_engine (self .sql_engine )
121- PySpringModel .set_models (
122- cast (list [Type [PySpringModel ]], list (self ._model_classes ))
123- )
124- PySpringModel .set_metadata (SQLModel .metadata )
125- RepositoryBase .engine = self .sql_engine
126- RepositoryBase .connection = self .sql_engine .connect ()
127-
128- if not props .create_all_tables :
129- logger .info ("[SQLMODEL TABLE CREATION] Skip creating all tables, set create_all_tables to True to enable." )
130- return
131- if not props .create_all_tables :
132- return
133-
13451 table_names = SQLModel .metadata .tables .keys ()
13552 logger .success (
13653 f"[SQLMODEL TABLE CREATION] Create all SQLModel tables, engine url: { self .sql_engine .url } , tables: { ', ' .join (table_names )} "
@@ -140,18 +57,30 @@ def _create_all_tables(self) -> None:
14057 f"[SQLMODEL TABLE MODEL IMPORT] Get model classes from PySpringModel inheritors: { ', ' .join ([_cls .__name__ for _cls in self ._model_classes ])} "
14158 )
14259
60+
61+ def _init_pyspring_model (self ) -> None :
62+ self ._model_classes = self ._get_pyspring_model_inheritors ()
63+ PySpringModel .set_engine (self .sql_engine )
64+ PySpringModel .set_models (
65+ cast (list [Type [PySpringModel ]], list (self ._model_classes ))
66+ )
67+ PySpringModel .set_metadata (SQLModel .metadata )
68+ RepositoryBase .engine = self .sql_engine
69+ RepositoryBase .connection = self .sql_engine .connect ()
70+
14371 def provider_init (self ) -> None :
14472 props = self ._get_props ()
14573 logger .info (
14674 f"[PYSPRING MODEL PROVIDER INIT] Initialize PySpringModelProvider with app context: { self .app_context } "
14775 )
148- app_context = self .get_application_context ()
149-
150- self .app_file_groups = self ._group_file_paths (app_context .all_file_paths )
15176 self .sql_engine = create_engine (
15277 url = props .sqlalchemy_database_uri , echo = True
15378 )
154- self ._import_model_modules ()
79+ self ._init_pyspring_model ()
80+ props = self ._get_props ()
81+ if not props .create_all_tables :
82+ logger .info ("[SQLMODEL TABLE CREATION] Skip creating all tables, set create_all_tables to True to enable." )
83+ return
15584 self ._create_all_tables ()
15685
15786
0 commit comments