Skip to content

Commit f9141de

Browse files
committed
support multi value headers and queries
1 parent 34beb4d commit f9141de

File tree

5 files changed

+215
-133
lines changed

5 files changed

+215
-133
lines changed

aws_lambda_powertools/event_handler/api_gateway.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -839,10 +839,6 @@ def _openapi_operation_parameters(
839839
# Create individual parameter for each model field
840840
param_name = field_def.alias or field_name
841841

842-
# Convert snake_case to kebab-case for headers (HTTP convention)
843-
if isinstance(field_info, Header):
844-
param_name = param_name.replace("_", "-")
845-
846842
individual_param = {
847843
"name": param_name,
848844
"in": field_info.in_.value,

aws_lambda_powertools/event_handler/middlewares/openapi_validation.py

Lines changed: 46 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import json
55
import logging
66
from copy import deepcopy
7-
from typing import TYPE_CHECKING, Any, Callable, Mapping, MutableMapping, Sequence
7+
from typing import TYPE_CHECKING, Any, Callable, Mapping, MutableMapping, Sequence, get_origin
88
from urllib.parse import parse_qs
99

1010
from pydantic import BaseModel
@@ -15,11 +15,12 @@
1515
_normalize_errors,
1616
_regenerate_error_with_loc,
1717
get_missing_field_error,
18+
lenient_issubclass,
1819
)
1920
from aws_lambda_powertools.event_handler.openapi.dependant import is_scalar_field
2021
from aws_lambda_powertools.event_handler.openapi.encoders import jsonable_encoder
2122
from aws_lambda_powertools.event_handler.openapi.exceptions import RequestValidationError, ResponseValidationError
22-
from aws_lambda_powertools.event_handler.openapi.params import Header, Param, Query
23+
from aws_lambda_powertools.event_handler.openapi.params import Param
2324

2425
if TYPE_CHECKING:
2526
from aws_lambda_powertools.event_handler import Response
@@ -64,7 +65,7 @@ def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) ->
6465
)
6566

6667
# Normalize query values before validate this
67-
query_string = _normalize_multi_query_string_with_param(
68+
query_string = _normalize_multi_params(
6869
app.current_event.resolved_query_string_parameters,
6970
route.dependant.query_params,
7071
)
@@ -76,7 +77,7 @@ def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) ->
7677
)
7778

7879
# Normalize header values before validate this
79-
headers = _normalize_multi_header_values_with_param(
80+
headers = _normalize_multi_params(
8081
app.current_event.resolved_headers_field,
8182
route.dependant.header_params,
8283
)
@@ -316,38 +317,33 @@ def _request_params_to_args(
316317
received_params: Mapping[str, Any],
317318
) -> tuple[dict[str, Any], list[Any]]:
318319
"""
319-
Convert request params to a dictionary of values with Pydantic model support.
320+
Convert the request params to a dictionary of values using validation, and returns a list of errors.
320321
"""
321322
values = {}
322323
errors = []
323324

324325
for field in required_params:
325326
field_info = field.field_info
326327

327-
# Check if this is a Pydantic model in Query/Header
328-
from pydantic import BaseModel
329-
330-
from aws_lambda_powertools.event_handler.openapi.compat import lenient_issubclass
331-
332-
if isinstance(field_info, (Query, Header)) and lenient_issubclass(field_info.annotation, BaseModel):
333-
pass
334-
elif isinstance(field_info, Param):
335-
pass
336-
else:
328+
# To ensure early failure, we check if it's not an instance of Param.
329+
if not isinstance(field_info, Param):
337330
raise AssertionError(f"Expected Param field_info, got {field_info}")
338331

339332
value = received_params.get(field.alias)
333+
340334
loc = (field_info.in_.value, field.alias)
341335

336+
# If we don't have a value, see if it's required or has a default
342337
if value is None:
343338
if field.required:
344339
errors.append(get_missing_field_error(loc=loc))
345340
else:
346341
values[field.name] = deepcopy(field.default)
347342
continue
348343

349-
# Use _validate_field like _request_body_to_args does
344+
# Finally, validate the value
350345
values[field.name] = _validate_field(field=field, value=value, loc=loc, existing_errors=errors)
346+
351347
return values, errors
352348

353349

@@ -439,116 +435,53 @@ def _get_embed_body(
439435
return received_body, field_alias_omitted
440436

441437

442-
def _normalize_multi_query_string_with_param(
443-
query_string: dict[str, list[str]],
438+
def _normalize_multi_params(
439+
input_dict: MutableMapping[str, Any],
444440
params: Sequence[ModelField],
445-
) -> dict[str, Any]:
441+
) -> MutableMapping[str, Any]:
446442
"""
447-
Extract and normalize resolved_query_string_parameters with Pydantic model support
443+
Extract and normalize query string or header parameters with Pydantic model support.
448444
449445
Parameters
450446
----------
451-
query_string: dict
452-
A dictionary containing the initial query string parameters.
447+
input_dict: MutableMapping[str, Any]
448+
A dictionary containing the initial query string or header parameters.
453449
params: Sequence[ModelField]
454450
A sequence of ModelField objects representing parameters.
455451
456452
Returns
457453
-------
458-
A dictionary containing the processed multi_query_string_parameters.
454+
MutableMapping[str, Any]
455+
A dictionary containing the processed parameters with normalized values.
459456
"""
460-
resolved_query_string: dict[str, Any] = query_string
461-
462457
for param in params:
463-
# Handle scalar fields (existing logic)
464458
if is_scalar_field(param):
465459
try:
466-
resolved_query_string[param.alias] = query_string[param.alias][0]
460+
val = input_dict[param.alias]
461+
if isinstance(val, list) and len(val) == 1:
462+
input_dict[param.alias] = val[0]
463+
elif isinstance(val, list):
464+
pass # leave as list for multi-value
465+
# If it's a string, leave as is
467466
except KeyError:
468467
pass
469-
# Handle Pydantic models
470-
elif isinstance(param.field_info, Query) and hasattr(param.field_info, "annotation"):
471-
from pydantic import BaseModel
472-
473-
from aws_lambda_powertools.event_handler.openapi.compat import lenient_issubclass
474-
475-
if lenient_issubclass(param.field_info.annotation, BaseModel):
476-
model_class = param.field_info.annotation
477-
model_data = {}
478-
479-
# Collect all fields for the Pydantic model
480-
for field_name, field_def in model_class.model_fields.items():
481-
field_alias = field_def.alias or field_name
482-
try:
483-
model_data[field_alias] = query_string[field_alias][0]
484-
except KeyError:
485-
if model_class.model_config.get("validate_by_name") or model_class.model_config.get(
486-
"populate_by_name",
487-
):
488-
try:
489-
model_data[field_alias] = query_string[field_name][0]
490-
except KeyError:
491-
pass
492-
493-
# Store the collected data under the param alias
494-
resolved_query_string[param.alias] = model_data
495-
496-
return resolved_query_string
497-
498-
499-
def _normalize_multi_header_values_with_param(headers: MutableMapping[str, Any], params: Sequence[ModelField]):
500-
"""
501-
Extract and normalize resolved_headers_field with Pydantic model support
502-
503-
Parameters
504-
----------
505-
headers: MutableMapping[str, Any]
506-
A dictionary containing the initial header parameters.
507-
params: Sequence[ModelField]
508-
A sequence of ModelField objects representing parameters.
509-
510-
Returns
511-
-------
512-
A dictionary containing the processed headers.
513-
"""
514-
if headers:
515-
for param in params:
516-
# Handle scalar fields (existing logic)
517-
if is_scalar_field(param):
518-
try:
519-
if len(headers[param.alias]) == 1:
520-
headers[param.alias] = headers[param.alias][0]
521-
except KeyError:
522-
pass
523-
# Handle Pydantic models
524-
elif isinstance(param.field_info, Header) and hasattr(param.field_info, "annotation"):
525-
from pydantic import BaseModel
526-
527-
from aws_lambda_powertools.event_handler.openapi.compat import lenient_issubclass
528-
529-
if lenient_issubclass(param.field_info.annotation, BaseModel):
530-
model_class = param.field_info.annotation
531-
model_data = {}
532-
533-
# Collect all fields for the Pydantic model
534-
for field_name, field_def in model_class.model_fields.items():
535-
field_alias = field_def.alias or field_name
536-
537-
# Convert snake_case to kebab-case for headers (HTTP convention)
538-
header_key = field_alias.replace("_", "-")
539-
540-
try:
541-
header_value = headers[header_key]
542-
if isinstance(header_value, list):
543-
if len(header_value) == 1:
544-
model_data[field_alias] = header_value[0]
545-
else:
546-
model_data[field_alias] = header_value
547-
else:
548-
model_data[field_alias] = header_value
549-
except KeyError:
550-
pass
551-
552-
# Store the collected data under the param alias
553-
headers[param.alias] = model_data
554-
return headers
468+
elif lenient_issubclass(param.field_info.annotation, BaseModel):
469+
model_class = param.field_info.annotation
470+
model_data = {}
471+
472+
for field_name, field_def in model_class.model_fields.items():
473+
field_alias = field_def.alias or field_name
474+
value = input_dict.get(field_alias)
475+
if value is None and (
476+
model_class.model_config.get("validate_by_name") or model_class.model_config.get("populate_by_name")
477+
):
478+
value = input_dict.get(field_name)
479+
if value is not None:
480+
if get_origin(field_def.annotation) is list:
481+
model_data[field_alias] = value
482+
elif isinstance(value, list):
483+
model_data[field_alias] = value[0]
484+
else:
485+
model_data[field_alias] = value
486+
input_dict[param.alias] = model_data
487+
return input_dict

aws_lambda_powertools/event_handler/openapi/dependant.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
create_body_model,
1010
evaluate_forwardref,
1111
is_scalar_field,
12-
is_scalar_sequence_field,
1312
)
1413
from aws_lambda_powertools.event_handler.openapi.params import (
1514
Body,
@@ -275,10 +274,7 @@ def is_body_param(*, param_field: ModelField, is_path_param: bool) -> bool:
275274
return False
276275
elif is_scalar_field(field=param_field):
277276
return False
278-
elif isinstance(param_field.field_info, (Query, Header)) and is_scalar_sequence_field(param_field):
279-
return False
280277
elif isinstance(param_field.field_info, (Query, Header)):
281-
# Allow Pydantic models in Query, Header, and Form parameters when explicitly annotated
282278
return False
283279
else:
284280
if not isinstance(param_field.field_info, Body):

aws_lambda_powertools/event_handler/openapi/params.py

Lines changed: 42 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from enum import Enum
55
from typing import TYPE_CHECKING, Any, Literal
66

7-
from pydantic import BaseConfig
7+
from pydantic import BaseConfig, BaseModel, create_model
88
from pydantic.fields import FieldInfo
99
from typing_extensions import Annotated, get_args, get_origin
1010

@@ -17,6 +17,7 @@
1717
copy_field_info,
1818
field_annotation_is_scalar,
1919
get_annotation_from_field_info,
20+
lenient_issubclass,
2021
)
2122

2223
if TYPE_CHECKING:
@@ -1094,6 +1095,42 @@ def create_response_field(
10941095
return ModelField(**kwargs) # type: ignore[arg-type]
10951096

10961097

1098+
def _apply_header_underscore_conversion(
1099+
field_info: FieldInfo,
1100+
type_annotation: Any,
1101+
param_name: str,
1102+
) -> tuple[FieldInfo, Any]:
1103+
"""
1104+
Apply underscore-to-dash conversion for Header parameters.
1105+
1106+
For BaseModel: Creates new model with underscore-to-dash alias generator.
1107+
Note: If the BaseModel already has an alias generator, it will be replaced
1108+
with dash-case conversion since HTTP headers should use dash-case.
1109+
For all Header fields: Sets the parameter alias if convert_underscores is True
1110+
"""
1111+
if not isinstance(field_info, Header) or not field_info.convert_underscores:
1112+
return field_info, type_annotation
1113+
1114+
# Always set the parameter alias for Header fields (if not already set)
1115+
if not field_info.alias:
1116+
field_info.alias = param_name.replace("_", "-")
1117+
1118+
# Handle BaseModel case - create new model with dash-case alias generator
1119+
if lenient_issubclass(type_annotation, BaseModel):
1120+
# For HTTP headers, we should use dash-case regardless of existing alias generator
1121+
# This ensures consistent header naming conventions
1122+
header_aliased_model = create_model(
1123+
f"{type_annotation.__name__}WithHeaderAliases",
1124+
__base__=type_annotation,
1125+
__config__={"alias_generator": lambda name: name.replace("_", "-")},
1126+
)
1127+
1128+
type_annotation = header_aliased_model
1129+
field_info.annotation = type_annotation
1130+
1131+
return field_info, type_annotation
1132+
1133+
10971134
def _create_model_field(
10981135
field_info: FieldInfo | None,
10991136
type_annotation: Any,
@@ -1112,21 +1149,17 @@ def _create_model_field(
11121149
elif isinstance(field_info, Param) and getattr(field_info, "in_", None) is None:
11131150
field_info.in_ = ParamTypes.query
11141151

1152+
# Apply header underscore conversion
1153+
field_info, type_annotation = _apply_header_underscore_conversion(field_info, type_annotation, param_name)
1154+
11151155
# If the field_info is a Param, we use the `in_` attribute to determine the type annotation
11161156
use_annotation = get_annotation_from_field_info(type_annotation, field_info, param_name)
11171157

1118-
# If the field doesn't have a defined alias, we use the param name
1119-
if not field_info.alias and getattr(field_info, "convert_underscores", None):
1120-
alias = param_name.replace("_", "-")
1121-
else:
1122-
alias = field_info.alias or param_name
1123-
field_info.alias = alias
1124-
11251158
return create_response_field(
11261159
name=param_name,
11271160
type_=use_annotation,
11281161
default=field_info.default,
1129-
alias=alias,
1162+
alias=field_info.alias,
11301163
required=field_info.default in (Required, Undefined),
11311164
field_info=field_info,
11321165
)

0 commit comments

Comments
 (0)