@@ -93,33 +93,128 @@ def _implemenmt_query(self, repository_type: Type[CrudRepository]) -> None:
93
93
if RETURN_KEY in copy_annotations :
94
94
copy_annotations .pop (RETURN_KEY )
95
95
96
- if len (copy_annotations ) != len (query .required_fields ) or set (
97
- copy_annotations .keys ()
98
- ) != set (query .required_fields ):
96
+ # Create parameter to field mapping for better API design
97
+ param_to_field_mapping = self ._create_parameter_field_mapping (
98
+ list (copy_annotations .keys ()), query .required_fields
99
+ )
100
+
101
+ if len (copy_annotations ) != len (query .required_fields ):
99
102
raise ValueError (
100
103
f"Invalid number of annotations. Expected { query .required_fields } , received { list (copy_annotations .keys ())} ."
101
104
)
105
+
102
106
# Create a wrapper for the current method and query
103
- wrapped_method = self .create_implementation_wrapper (query , model_type , copy_annotations )
107
+ wrapped_method = self .create_implementation_wrapper (query , model_type , copy_annotations , param_to_field_mapping )
104
108
logger .info (
105
109
f"Binding method: { method } to { repository_type } , with query: { query } "
106
110
)
107
111
setattr (repository_type , method , wrapped_method )
108
112
109
- def create_implementation_wrapper (self , query : _Query , model_type : Type [PySpringModel ], original_func_annotations : dict [str , Any ]) -> Callable [..., Any ]:
113
+ def _create_parameter_field_mapping (self , param_names : list [str ], field_names : list [str ]) -> dict [str , str ]:
114
+ """
115
+ Create a mapping between parameter names and field names.
116
+ This allows for more readable API design where parameter names can be plural
117
+ while still mapping to singular field names.
118
+
119
+ The method validates that parameter names correspond to field names and provides
120
+ clear error messages for mismatches.
121
+
122
+ Examples:
123
+ - param_names: ['names'], field_names: ['name'] -> {'names': 'name'}
124
+ - param_names: ['ages'], field_names: ['age'] -> {'ages': 'age'}
125
+ - param_names: ['name', 'age'], field_names: ['name', 'age'] -> {'name': 'name', 'age': 'age'}
126
+ """
127
+ if len (param_names ) != len (field_names ):
128
+ raise ValueError (
129
+ f"Parameter count mismatch. Expected { len (field_names )} parameters for fields { field_names } , "
130
+ f"but got { len (param_names )} parameters: { param_names } "
131
+ )
132
+
133
+ mapping = {}
134
+ unmatched_params = []
135
+ unmatched_fields = []
136
+
137
+ # Create a set of field names for efficient lookup
138
+ field_set = set (field_names )
139
+
140
+ for param_name in param_names :
141
+ # Try exact match first
142
+ if param_name in field_set :
143
+ mapping [param_name ] = param_name
144
+ continue
145
+
146
+ # Try singular/plural variations
147
+ singular_match = None
148
+ plural_match = None
149
+
150
+ # Check if param_name is plural and field_name is singular
151
+ if param_name .endswith ('s' ) and len (param_name ) > 1 :
152
+ singular_candidate = param_name [:- 1 ]
153
+ if singular_candidate in field_set :
154
+ singular_match = singular_candidate
155
+
156
+ # Check if param_name is singular and field_name is plural
157
+ elif not param_name .endswith ('s' ):
158
+ plural_candidate = param_name + 's'
159
+ if plural_candidate in field_set :
160
+ plural_match = plural_candidate
161
+
162
+ # Use the best match found
163
+ if singular_match :
164
+ mapping [param_name ] = singular_match
165
+ elif plural_match :
166
+ mapping [param_name ] = plural_match
167
+ else :
168
+ unmatched_params .append (param_name )
169
+
170
+ # Check for unmatched fields
171
+ mapped_fields = set (mapping .values ())
172
+ for field_name in field_names :
173
+ if field_name not in mapped_fields :
174
+ unmatched_fields .append (field_name )
175
+
176
+ # Report any mismatches
177
+ if unmatched_params or unmatched_fields :
178
+ error_msg = "Parameter to field mapping failed:\n "
179
+ if unmatched_params :
180
+ error_msg += f" Unmatched parameters: { unmatched_params } \n "
181
+ if unmatched_fields :
182
+ error_msg += f" Unmatched fields: { unmatched_fields } \n "
183
+ error_msg += f" Available fields: { field_names } \n "
184
+ error_msg += f" Provided parameters: { param_names } "
185
+ raise ValueError (error_msg )
186
+
187
+ return mapping
188
+
189
+ def create_implementation_wrapper (self , query : _Query , model_type : Type [PySpringModel ], original_func_annotations : dict [str , Any ], param_to_field_mapping : dict [str , str ]) -> Callable [..., Any ]:
110
190
def wrapper (* args , ** kwargs ) -> Any :
111
191
if len (query .required_fields ) > 0 :
112
- # Check if all required fields are present in kwargs
113
- if set (query .required_fields ) != set (kwargs .keys ()):
192
+ # Map parameter names to field names
193
+ field_kwargs = {}
194
+ for param_name , value in kwargs .items ():
195
+ if param_name in param_to_field_mapping :
196
+ field_name = param_to_field_mapping [param_name ]
197
+ field_kwargs [field_name ] = value
198
+ else :
199
+ # Fallback: use parameter name as field name
200
+ field_kwargs [param_name ] = value
201
+
202
+ # Check if all required fields are present
203
+ if set (query .required_fields ) != set (field_kwargs .keys ()):
114
204
raise ValueError (
115
- f"Invalid number of keyword arguments. Expected { query .required_fields } , received { kwargs } ."
205
+ f"Invalid number of keyword arguments. Expected { query .required_fields } , received { list ( kwargs . keys ()) } ."
116
206
)
117
207
118
- # Execute the query
119
- sql_statement = self ._get_sql_statement (model_type , query , kwargs )
120
- result = self ._session_execute (sql_statement , query .is_one_result )
121
- logger .info (f"Executing query with params: { kwargs } " )
122
- return result
208
+ # Execute the query with mapped field names
209
+ sql_statement = self ._get_sql_statement (model_type , query , field_kwargs )
210
+ result = self ._session_execute (sql_statement , query .is_one_result )
211
+ logger .info (f"Executing query with params: { kwargs } " )
212
+ return result
213
+ else :
214
+ # No required fields, execute without parameters
215
+ sql_statement = self ._get_sql_statement (model_type , query , {})
216
+ result = self ._session_execute (sql_statement , query .is_one_result )
217
+ return result
123
218
124
219
wrapper .__annotations__ = original_func_annotations
125
220
return wrapper
0 commit comments