@@ -93,33 +93,128 @@ def _implemenmt_query(self, repository_type: Type[CrudRepository]) -> None:
9393 if RETURN_KEY in copy_annotations :
9494 copy_annotations .pop (RETURN_KEY )
9595
96- if len (copy_annotations ) != len (query .required_fields ) or set (
97- copy_annotations .keys ()
98- ) != set (query .required_fields ):
96+ # Create parameter to field mapping for better API design
97+ param_to_field_mapping = self ._create_parameter_field_mapping (
98+ list (copy_annotations .keys ()), query .required_fields
99+ )
100+
101+ if len (copy_annotations ) != len (query .required_fields ):
99102 raise ValueError (
100103 f"Invalid number of annotations. Expected { query .required_fields } , received { list (copy_annotations .keys ())} ."
101104 )
105+
102106 # Create a wrapper for the current method and query
103- wrapped_method = self .create_implementation_wrapper (query , model_type , copy_annotations )
107+ wrapped_method = self .create_implementation_wrapper (query , model_type , copy_annotations , param_to_field_mapping )
104108 logger .info (
105109 f"Binding method: { method } to { repository_type } , with query: { query } "
106110 )
107111 setattr (repository_type , method , wrapped_method )
108112
109- def create_implementation_wrapper (self , query : _Query , model_type : Type [PySpringModel ], original_func_annotations : dict [str , Any ]) -> Callable [..., Any ]:
113+ def _create_parameter_field_mapping (self , param_names : list [str ], field_names : list [str ]) -> dict [str , str ]:
114+ """
115+ Create a mapping between parameter names and field names.
116+ This allows for more readable API design where parameter names can be plural
117+ while still mapping to singular field names.
118+
119+ The method validates that parameter names correspond to field names and provides
120+ clear error messages for mismatches.
121+
122+ Examples:
123+ - param_names: ['names'], field_names: ['name'] -> {'names': 'name'}
124+ - param_names: ['ages'], field_names: ['age'] -> {'ages': 'age'}
125+ - param_names: ['name', 'age'], field_names: ['name', 'age'] -> {'name': 'name', 'age': 'age'}
126+ """
127+ if len (param_names ) != len (field_names ):
128+ raise ValueError (
129+ f"Parameter count mismatch. Expected { len (field_names )} parameters for fields { field_names } , "
130+ f"but got { len (param_names )} parameters: { param_names } "
131+ )
132+
133+ mapping = {}
134+ unmatched_params = []
135+ unmatched_fields = []
136+
137+ # Create a set of field names for efficient lookup
138+ field_set = set (field_names )
139+
140+ for param_name in param_names :
141+ # Try exact match first
142+ if param_name in field_set :
143+ mapping [param_name ] = param_name
144+ continue
145+
146+ # Try singular/plural variations
147+ singular_match = None
148+ plural_match = None
149+
150+ # Check if param_name is plural and field_name is singular
151+ if param_name .endswith ('s' ) and len (param_name ) > 1 :
152+ singular_candidate = param_name [:- 1 ]
153+ if singular_candidate in field_set :
154+ singular_match = singular_candidate
155+
156+ # Check if param_name is singular and field_name is plural
157+ elif not param_name .endswith ('s' ):
158+ plural_candidate = param_name + 's'
159+ if plural_candidate in field_set :
160+ plural_match = plural_candidate
161+
162+ # Use the best match found
163+ if singular_match :
164+ mapping [param_name ] = singular_match
165+ elif plural_match :
166+ mapping [param_name ] = plural_match
167+ else :
168+ unmatched_params .append (param_name )
169+
170+ # Check for unmatched fields
171+ mapped_fields = set (mapping .values ())
172+ for field_name in field_names :
173+ if field_name not in mapped_fields :
174+ unmatched_fields .append (field_name )
175+
176+ # Report any mismatches
177+ if unmatched_params or unmatched_fields :
178+ error_msg = "Parameter to field mapping failed:\n "
179+ if unmatched_params :
180+ error_msg += f" Unmatched parameters: { unmatched_params } \n "
181+ if unmatched_fields :
182+ error_msg += f" Unmatched fields: { unmatched_fields } \n "
183+ error_msg += f" Available fields: { field_names } \n "
184+ error_msg += f" Provided parameters: { param_names } "
185+ raise ValueError (error_msg )
186+
187+ return mapping
188+
189+ def create_implementation_wrapper (self , query : _Query , model_type : Type [PySpringModel ], original_func_annotations : dict [str , Any ], param_to_field_mapping : dict [str , str ]) -> Callable [..., Any ]:
110190 def wrapper (* args , ** kwargs ) -> Any :
111191 if len (query .required_fields ) > 0 :
112- # Check if all required fields are present in kwargs
113- if set (query .required_fields ) != set (kwargs .keys ()):
192+ # Map parameter names to field names
193+ field_kwargs = {}
194+ for param_name , value in kwargs .items ():
195+ if param_name in param_to_field_mapping :
196+ field_name = param_to_field_mapping [param_name ]
197+ field_kwargs [field_name ] = value
198+ else :
199+ # Fallback: use parameter name as field name
200+ field_kwargs [param_name ] = value
201+
202+ # Check if all required fields are present
203+ if set (query .required_fields ) != set (field_kwargs .keys ()):
114204 raise ValueError (
115- f"Invalid number of keyword arguments. Expected { query .required_fields } , received { kwargs } ."
205+ f"Invalid number of keyword arguments. Expected { query .required_fields } , received { list ( kwargs . keys ()) } ."
116206 )
117207
118- # Execute the query
119- sql_statement = self ._get_sql_statement (model_type , query , kwargs )
120- result = self ._session_execute (sql_statement , query .is_one_result )
121- logger .info (f"Executing query with params: { kwargs } " )
122- return result
208+ # Execute the query with mapped field names
209+ sql_statement = self ._get_sql_statement (model_type , query , field_kwargs )
210+ result = self ._session_execute (sql_statement , query .is_one_result )
211+ logger .info (f"Executing query with params: { kwargs } " )
212+ return result
213+ else :
214+ # No required fields, execute without parameters
215+ sql_statement = self ._get_sql_statement (model_type , query , {})
216+ result = self ._session_execute (sql_statement , query .is_one_result )
217+ return result
123218
124219 wrapper .__annotations__ = original_func_annotations
125220 return wrapper
0 commit comments