Skip to content
68 changes: 62 additions & 6 deletions aws_lambda_powertools/event_handler/appsync.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import logging
from typing import Any, Callable, Optional
from typing import Any, Callable, Optional, Type, TypeVar

from aws_lambda_powertools.utilities.data_classes import AppSyncResolverEvent
from aws_lambda_powertools.utilities.typing import LambdaContext

logger = logging.getLogger(__name__)

AppSyncResolverEventT = TypeVar("AppSyncResolverEventT", bound=AppSyncResolverEvent)


class AppSyncResolver:
"""
Expand Down Expand Up @@ -38,7 +40,7 @@ def common_field() -> str:
return str(uuid.uuid4())
"""

current_event: AppSyncResolverEvent
current_event: AppSyncResolverEventT # type: ignore[valid-type]
lambda_context: LambdaContext

def __init__(self):
Expand All @@ -62,7 +64,9 @@ def register_resolver(func):

return register_resolver

def resolve(self, event: dict, context: LambdaContext) -> Any:
def resolve(
self, event: dict, context: LambdaContext, data_model: Type[AppSyncResolverEvent] = AppSyncResolverEvent
) -> Any:
"""Resolve field_name

Parameters
Expand All @@ -71,6 +75,56 @@ def resolve(self, event: dict, context: LambdaContext) -> Any:
Lambda event
context : LambdaContext
Lambda context
data_model:
Your data data_model to decode AppSync event, by default AppSyncResolverEvent

Example
-------

```python
from aws_lambda_powertools.event_handler import AppSyncResolver
from aws_lambda_powertools.utilities.typing import LambdaContext

@app.resolver(field_name="createSomething")
def create_something(id: str): # noqa AA03 VNE003
return id

def handler(event, context: LambdaContext):
return app.resolve(event, context)
```

**Bringing custom models**

```python
from aws_lambda_powertools import Logger, Tracer

from aws_lambda_powertools.logging import correlation_paths
from aws_lambda_powertools.event_handler import AppSyncResolver

tracer = Tracer(service="sample_resolver")
logger = Logger(service="sample_resolver")
app = AppSyncResolver()


class MyCustomModel(AppSyncResolverEvent):
@property
def country_viewer(self) -> str:
return self.request_headers.get("cloudfront-viewer-country")


@app.resolver(field_name="listLocations")
@app.resolver(field_name="locations")
def get_locations(name: str, description: str = ""):
if app.current_event.country_viewer == "US":
...
return name + description


@logger.inject_lambda_context(correlation_id_path=correlation_paths.APPSYNC_RESOLVER)
@tracer.capture_lambda_handler
def lambda_handler(event, context):
return app.resolve(event, context, data_model=MyCustomModel)
```

Returns
-------
Expand All @@ -82,7 +136,7 @@ def resolve(self, event: dict, context: LambdaContext) -> Any:
ValueError
If we could not find a field resolver
"""
self.current_event = AppSyncResolverEvent(event)
self.current_event = data_model(event)
self.lambda_context = context
resolver = self._get_resolver(self.current_event.type_name, self.current_event.field_name)
return resolver(**self.current_event.arguments)
Expand All @@ -108,6 +162,8 @@ def _get_resolver(self, type_name: str, field_name: str) -> Callable:
raise ValueError(f"No resolver found for '{full_name}'")
return resolver["func"]

def __call__(self, event, context) -> Any:
def __call__(
self, event: dict, context: LambdaContext, data_model: Type[AppSyncResolverEvent] = AppSyncResolverEvent
) -> Any:
"""Implicit lambda handler which internally calls `resolve`"""
return self.resolve(event, context)
return self.resolve(event, context, data_model)
112 changes: 112 additions & 0 deletions docs/core/event_handler/appsync.md
Original file line number Diff line number Diff line change
Expand Up @@ -598,6 +598,118 @@ Use the following code for `merchantInfo` and `searchMerchant` functions respect
}
```

### Custom data models

You can subclass `AppSyncResolverEvent` to bring your own set of methods to handle incoming events, by using `data_model` param in the `resolve` method.


=== "custom_model.py"

```python hl_lines="11-14 19 26"
from aws_lambda_powertools import Logger, Tracer

from aws_lambda_powertools.logging import correlation_paths
from aws_lambda_powertools.event_handler import AppSyncResolver

tracer = Tracer(service="sample_resolver")
logger = Logger(service="sample_resolver")
app = AppSyncResolver()


class MyCustomModel(AppSyncResolverEvent):
@property
def country_viewer(self) -> str:
return self.request_headers.get("cloudfront-viewer-country")

@app.resolver(field_name="listLocations")
@app.resolver(field_name="locations")
def get_locations(name: str, description: str = ""):
if app.current_event.country_viewer == "US":
...
return name + description

@logger.inject_lambda_context(correlation_id_path=correlation_paths.APPSYNC_RESOLVER)
@tracer.capture_lambda_handler
def lambda_handler(event, context):
return app.resolve(event, context, data_model=MyCustomModel)
```

=== "schema.graphql"

```typescript hl_lines="6 20"
schema {
query: Query
}

type Query {
listLocations: [Location]
}

type Location {
id: ID!
name: String!
description: String
address: String
}

type Merchant {
id: String!
name: String!
description: String
locations: [Location]
}
```

=== "listLocations_event.json"

```json
{
"arguments": {},
"identity": null,
"source": null,
"request": {
"headers": {
"x-forwarded-for": "1.2.3.4, 5.6.7.8",
"accept-encoding": "gzip, deflate, br",
"cloudfront-viewer-country": "NL",
"cloudfront-is-tablet-viewer": "false",
"referer": "https://eu-west-1.console.aws.amazon.com/appsync/home?region=eu-west-1",
"via": "2.0 9fce949f3749407c8e6a75087e168b47.cloudfront.net (CloudFront)",
"cloudfront-forwarded-proto": "https",
"origin": "https://eu-west-1.console.aws.amazon.com",
"x-api-key": "da1-c33ullkbkze3jg5hf5ddgcs4fq",
"content-type": "application/json",
"x-amzn-trace-id": "Root=1-606eb2f2-1babc433453a332c43fb4494",
"x-amz-cf-id": "SJw16ZOPuMZMINx5Xcxa9pB84oMPSGCzNOfrbJLvd80sPa0waCXzYQ==",
"content-length": "114",
"x-amz-user-agent": "AWS-Console-AppSync/",
"x-forwarded-proto": "https",
"host": "ldcvmkdnd5az3lm3gnf5ixvcyy.appsync-api.eu-west-1.amazonaws.com",
"accept-language": "en-US,en;q=0.5",
"user-agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10.15; rv:78.0) Gecko/20100101 Firefox/78.0",
"cloudfront-is-desktop-viewer": "true",
"cloudfront-is-mobile-viewer": "false",
"accept": "*/*",
"x-forwarded-port": "443",
"cloudfront-is-smarttv-viewer": "false"
}
},
"prev": null,
"info": {
"parentTypeName": "Query",
"selectionSetList": [
"id",
"name",
"description"
],
"selectionSetGraphQL": "{\n id\n name\n description\n}",
"fieldName": "listLocations",
"variables": {}
},
"stash": {}
}
```

## Testing your code

You can test your resolvers by passing a mocked or actual AppSync Lambda event that you're expecting.
Expand Down
23 changes: 23 additions & 0 deletions tests/functional/event_handler/test_appsync.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,3 +138,26 @@ async def get_async():

# THEN
assert asyncio.run(result) == "value"


def test_resolve_custom_data_model():
# Check whether we can handle an example appsync direct resolver
mock_event = load_event("appSyncDirectResolver.json")

class MyCustomModel(AppSyncResolverEvent):
@property
def country_viewer(self):
return self.request_headers.get("cloudfront-viewer-country")

app = AppSyncResolver()

@app.resolver(field_name="createSomething")
def create_something(id: str): # noqa AA03 VNE003
return id

# Call the implicit handler
result = app(event=mock_event, context=LambdaContext(), data_model=MyCustomModel)

assert result == "my identifier"

assert app.current_event.country_viewer == "US"