Skip to content
150 changes: 113 additions & 37 deletions aws_lambda_powertools/event_handler/api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def to_dict(self, origin: Optional[str]) -> Dict[str, str]:
# The origin matched an allowed origin, so return the CORS headers
headers = {
"Access-Control-Allow-Origin": origin,
"Access-Control-Allow-Headers": ",".join(sorted(self.allow_headers)),
"Access-Control-Allow-Headers": CORSConfig.build_allow_methods(self.allow_headers),
}

if self.expose_headers:
Expand All @@ -222,6 +222,23 @@ def to_dict(self, origin: Optional[str]) -> Dict[str, str]:
headers["Access-Control-Allow-Credentials"] = "true"
return headers

@staticmethod
def build_allow_methods(methods: Set[str]) -> str:
"""Build sorted comma delimited methods for Access-Control-Allow-Methods header

Parameters
----------
methods : set[str]
Set of HTTP Methods

Returns
-------
set[str]
Formatted string with all HTTP Methods allowed for CORS e.g., `GET, OPTIONS`

"""
return ",".join(sorted(methods))


class Response(Generic[ResponseT]):
"""Response data class that provides greater control over what is returned from the proxy event"""
Expand Down Expand Up @@ -282,16 +299,16 @@ def __init__(
func: Callable,
cors: bool,
compress: bool,
cache_control: Optional[str],
summary: Optional[str],
description: Optional[str],
responses: Optional[Dict[int, OpenAPIResponse]],
response_description: Optional[str],
tags: Optional[List[str]],
operation_id: Optional[str],
include_in_schema: bool,
security: Optional[List[Dict[str, List[str]]]],
middlewares: Optional[List[Callable[..., Response]]],
cache_control: Optional[str] = None,
summary: Optional[str] = None,
description: Optional[str] = None,
responses: Optional[Dict[int, OpenAPIResponse]] = None,
response_description: Optional[str] = None,
tags: Optional[List[str]] = None,
operation_id: Optional[str] = None,
include_in_schema: bool = True,
security: Optional[List[Dict[str, List[str]]]] = None,
middlewares: Optional[List[Callable[..., Response]]] = None,
):
"""

Expand Down Expand Up @@ -1406,7 +1423,6 @@ def _registered_api_adapter(
"""
route_args: Dict = app.context.get("_route_args", {})
logger.debug(f"Calling API Route Handler: {route_args}")

return app._to_response(next_middleware(**route_args))


Expand Down Expand Up @@ -1967,6 +1983,36 @@ def register_resolver(func: Callable):
def resolve(self, event, context) -> Dict[str, Any]:
"""Resolves the response based on the provide event and decorator routes

## Internals

Request processing chain is triggered by a Route object being called _(`_call_route` -> `__call__`)_:

1. **When a route is matched**
1.1. Exception handlers _(if any exception bubbled up and caught)_
1.2. Global middlewares _(before, and after on the way back)_
1.3. Path level middleware _(before, and after on the way back)_
1.4. Middleware adapter to ensure Response is homogenous (_registered_api_adapter)
1.5. Run actual route
2. **When a route is NOT matched**
2.1. Exception handlers _(if any exception bubbled up and caught)_
2.2. Global middlewares _(before, and after on the way back)_
2.3. Path level middleware _(before, and after on the way back)_
2.4. Middleware adapter to ensure Response is homogenous (_registered_api_adapter)
2.5. Run 404 route handler
3. **When a route is a pre-flight CORS (often not matched)**
3.1. Exception handlers _(if any exception bubbled up and caught)_
3.2. Global middlewares _(before, and after on the way back)_
3.3. Path level middleware _(before, and after on the way back)_
3.4. Middleware adapter to ensure Response is homogenous (_registered_api_adapter)
3.5. Return 204 with appropriate CORS headers
4. **When a route is matched with Data Validation enabled**
4.1. Exception handlers _(if any exception bubbled up and caught)_
4.2. Data Validation middleware _(before, and after on the way back)_
4.3. Global middlewares _(before, and after on the way back)_
4.4. Path level middleware _(before, and after on the way back)_
4.5. Middleware adapter to ensure Response is homogenous (_registered_api_adapter)
4.6. Run actual route

Parameters
----------
event: Dict[str, Any]
Expand Down Expand Up @@ -2090,7 +2136,9 @@ def _resolve(self) -> ResponseBuilder:
method = self.current_event.http_method.upper()
path = self._remove_prefix(self.current_event.path)

for route in self._static_routes + self._dynamic_routes:
registered_routes = self._static_routes + self._dynamic_routes

for route in registered_routes:
if method != route.method:
continue
match_results: Optional[Match] = route.rule.match(path)
Expand All @@ -2102,8 +2150,7 @@ def _resolve(self) -> ResponseBuilder:
route_keys = self._convert_matches_into_route_keys(match_results)
return self._call_route(route, route_keys) # pass fn args

logger.debug(f"No match found for path {path} and method {method}")
return self._not_found(method)
return self._handle_not_found(method=method, path=path)

def _remove_prefix(self, path: str) -> str:
"""Remove the configured prefix from the path"""
Expand Down Expand Up @@ -2141,36 +2188,65 @@ def _path_starts_with(path: str, prefix: str):

return path.startswith(prefix + "/")

def _not_found(self, method: str) -> ResponseBuilder:
def _handle_not_found(self, method: str, path: str) -> ResponseBuilder:
"""Called when no matching route was found and includes support for the cors preflight response"""
headers = {}
if self._cors:
logger.debug("CORS is enabled, updating headers.")
extracted_origin_header = extract_origin_header(self.current_event.resolved_headers_field)
headers.update(self._cors.to_dict(extracted_origin_header))

if method == "OPTIONS":
logger.debug("Pre-flight request detected. Returning CORS with null response")
headers["Access-Control-Allow-Methods"] = ",".join(sorted(self._cors_methods))
return ResponseBuilder(
response=Response(status_code=204, content_type=None, headers=headers, body=""),
serializer=self._serializer,
)
logger.debug(f"No match found for path {path} and method {method}")

handler = self._lookup_exception_handler(NotFoundError)
if handler:
return self._response_builder_class(response=handler(NotFoundError()), serializer=self._serializer)
def not_found_handler():
"""Route handler for 404s

It handles in the following order:

1. Pre-flight CORS requests (OPTIONS)
2. Detects and calls custom HTTP 404 handler
3. Returns standard 404 along with CORS headers

return self._response_builder_class(
response=Response(
Returns
-------
Response
HTTP 404 response
"""
_headers: Dict[str, Any] = {}

# Pre-flight request? Return immediately to avoid browser error
if self._cors and method == "OPTIONS":
logger.debug("Pre-flight request detected. Returning CORS with empty response")
_headers["Access-Control-Allow-Methods"] = CORSConfig.build_allow_methods(self._cors_methods)

return Response(status_code=204, content_type=None, headers=_headers, body="")

# Customer registered 404 route? Call it.
custom_not_found_handler = self._lookup_exception_handler(NotFoundError)
if custom_not_found_handler:
return custom_not_found_handler(NotFoundError())

# No CORS and no custom 404 fn? Default response
return Response(
status_code=HTTPStatus.NOT_FOUND.value,
content_type=content_types.APPLICATION_JSON,
headers=headers,
headers=_headers,
body={"statusCode": HTTPStatus.NOT_FOUND.value, "message": "Not found"},
),
serializer=self._serializer,
)

# We create a route to trigger entire request chain (middleware+exception handlers)
route = Route(
rule=self._compile_regex(r".*"),
method=method,
path=path,
func=not_found_handler,
cors=self._cors_enabled,
compress=False,
)

# Add matched Route reference into the Resolver context
self.append_context(_route=route, _path=path)

# Kick-off request chain:
# -> exception_handlers()
# --> middlewares()
# ---> not_found_route()
return self._call_route(route=route, route_arguments={})

def _call_route(self, route: Route, route_arguments: Dict[str, str]) -> ResponseBuilder:
"""Actually call the matching route with any provided keyword arguments."""
try:
Expand Down
58 changes: 58 additions & 0 deletions tests/functional/event_handler/test_api_middlewares.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
APIGatewayHttpResolver,
ApiGatewayResolver,
APIGatewayRestResolver,
CORSConfig,
ProxyEventType,
Response,
Router,
Expand Down Expand Up @@ -506,3 +507,60 @@ def post_lambda():
result = resolver(event, {})
assert result["statusCode"] == 200
assert result["multiValueHeaders"]["X-Correlation-Id"][0] == resolver.current_event.request_context.request_id # type: ignore[attr-defined] # noqa: E501


@pytest.mark.parametrize(
"app, event",
[
(ApiGatewayResolver(proxy_type=ProxyEventType.APIGatewayProxyEvent), API_REST_EVENT),
(APIGatewayRestResolver(), API_REST_EVENT),
(APIGatewayHttpResolver(), API_RESTV2_EVENT),
],
)
def test_global_middleware_not_found(app: ApiGatewayResolver, event):
# GIVEN global middleware is registered

def middleware(app: ApiGatewayResolver, next_middleware: NextMiddleware):
# add additional data to Router Context
ret = next_middleware(app)
ret.body = "middleware works"
return ret

app.use(middlewares=[middleware])

@app.get("/this/path/does/not/exist")
def nope() -> dict: ...

# WHEN calling the event handler for an unregistered route /my/path
result = app(event, {})

# THEN process event correctly as HTTP 404
# AND ensure middlewares are called
assert result["statusCode"] == 404
assert result["body"] == "middleware works"


def test_global_middleware_not_found_preflight():
# GIVEN global middleware is registered

app = ApiGatewayResolver(cors=CORSConfig(), proxy_type=ProxyEventType.APIGatewayProxyEvent)
event = {**API_REST_EVENT, "httpMethod": "OPTIONS"}

def middleware(app: ApiGatewayResolver, next_middleware: NextMiddleware):
# add additional data to Router Context
ret = next_middleware(app)
ret.body = "middleware works"
return ret

app.use(middlewares=[middleware])

@app.get("/this/path/does/not/exist")
def nope() -> dict: ...

# WHEN calling the event handler for an unregistered route /my/path OPTIONS
result = app(event, {})

# THEN process event correctly as HTTP 204 (not 404)
# AND ensure middlewares are called
assert result["statusCode"] == 204
assert result["body"] == "middleware works"