diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 326c95dadfb..eabc7a65a10 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,5 +1,5 @@ # We use poetry to run formatting and linting before commit/push -# Longers checks such as tests, security and complexity baseline +# Longer checks such as tests, security and complexity baseline # are run as part of CI to prevent slower feedback loop # All checks can be run locally via `make pr` diff --git a/CHANGELOG.md b/CHANGELOG.md index dbbfb6ad2b8..e1128ce44e9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,18 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [1.5.0] - 2020-09-04 + +### Added +- **Logger**: Add `xray_trace_id` to log output to improve integration with CloudWatch Service Lens +- **Logger**: Allow reordering of logged output +- **Utilities**: Add new `SQS batch processing` utility to handle partial failures in processing message batches +- **Utilities**: Add typing utility providing static type for lambda context object +- **Utilities**: Add `transform=auto` in parameters utility to deserialize parameter values based on the key name + +### Fixed +- **Logger**: The value of `json_default` formatter is no longer written to logs + ## [1.4.0] - 2020-08-25 ### Added diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 45c5e93da78..1e330afa47b 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -69,6 +69,19 @@ opensource-codeofconduct@amazon.com with any additional questions or comments. ## Security issue notifications If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public github issue. +## Troubleshooting + +### API reference documentation + +When you are working on the codebase and you use the local API reference documentation to preview your changes, you might see the following message: `Module aws_lambda_powertools not found`. + +This happens when: + +* You did not install the local dev environment yet + - You can install dev deps with `make dev` command +* The code in the repository is raising an exception while the `pdoc` is scanning the codebase + - Unfortunately, this exception is not shown to you, but if you run, `poetry run pdoc --pdf aws_lambda_powertools`, the exception is shown and you can prevent the exception from being raised + - Once resolved the documentation should load correctly again ## Licensing diff --git a/README.md b/README.md index 68090913dbe..10b5a18488b 100644 --- a/README.md +++ b/README.md @@ -3,10 +3,12 @@ ![Build](https://github.com/awslabs/aws-lambda-powertools/workflows/Powertools%20Python/badge.svg?branch=master) ![PythonSupport](https://img.shields.io/static/v1?label=python&message=3.6%20|%203.7|%203.8&color=blue?style=flat-square&logo=python) ![PyPI version](https://badge.fury.io/py/aws-lambda-powertools.svg) ![PyPi monthly downloads](https://img.shields.io/pypi/dm/aws-lambda-powertools) -A suite of utilities for AWS Lambda Functions that makes tracing with AWS X-Ray, structured logging and creating custom metrics asynchronously easier. +A suite of utilities for AWS Lambda functions that makes tracing with AWS X-Ray, structured logging and creating custom metrics asynchronously easier. 
**[📜Documentation](https://awslabs.github.io/aws-lambda-powertools-python/)** | **[API Docs](https://awslabs.github.io/aws-lambda-powertools-python/api/)** | **[🐍PyPi](https://pypi.org/project/aws-lambda-powertools/)** | **[Feature request](https://github.com/awslabs/aws-lambda-powertools-python/issues/new?assignees=&labels=feature-request%2C+triage&template=feature_request.md&title=)** | **[🐛Bug Report](https://github.com/awslabs/aws-lambda-powertools-python/issues/new?assignees=&labels=bug%2C+triage&template=bug_report.md&title=)** | **[Kitchen sink example](https://github.com/awslabs/aws-lambda-powertools-python/tree/develop/example)** | **[Detailed blog post](https://aws.amazon.com/blogs/opensource/simplifying-serverless-best-practices-with-lambda-powertools/)** +> **Join us on the AWS Developers Slack at `#lambda-powertools`** - **[Invite, if you don't have an account](https://join.slack.com/t/awsdevelopers/shared_invite/zt-gu30gquv-EhwIYq3kHhhysaZ2aIX7ew)** + ## Features * **[Tracing](https://awslabs.github.io/aws-lambda-powertools-python/core/tracer/)** - Decorators and utilities to trace Lambda function handlers, and both synchronous and asynchronous functions diff --git a/aws_lambda_powertools/logging/formatter.py b/aws_lambda_powertools/logging/formatter.py index cb3bb397348..647abf33a8a 100644 --- a/aws_lambda_powertools/logging/formatter.py +++ b/aws_lambda_powertools/logging/formatter.py @@ -1,27 +1,6 @@ import json import logging -from typing import Any - - -def json_formatter(unserialized_value: Any): - """JSON custom serializer to cast unserialisable values to strings. - - Example - ------- - - **Serialize unserialisable value to string** - - class X: pass - value = {"x": X()} - - json.dumps(value, default=json_formatter) - - Parameters - ---------- - unserialized_value: Any - Python object unserializable by JSON - """ - return str(unserialized_value) +import os class JsonFormatter(logging.Formatter): @@ -39,22 +18,40 @@ def __init__(self, **kwargs): """Return a JsonFormatter instance. The `json_default` kwarg is used to specify a formatter for otherwise - unserialisable values. It must not throw. Defaults to a function that + unserializable values. It must not throw. Defaults to a function that coerces the value to a string. + The `log_record_order` kwarg is used to specify the order of the keys used in + the structured json logs. By default the order is: "level", "location", "message", "timestamp", + "service" and "sampling_rate". + Other kwargs are used to specify log field format strings. """ - datefmt = kwargs.pop("datefmt", None) - - super(JsonFormatter, self).__init__(datefmt=datefmt) + # Set the default unserializable function, by default values will be cast as str. 
+ self.default_json_formatter = kwargs.pop("json_default", str) + # Set the insertion order for the log messages + self.format_dict = dict.fromkeys(kwargs.pop("log_record_order", ["level", "location", "message", "timestamp"])) self.reserved_keys = ["timestamp", "level", "location"] - self.format_dict = { - "timestamp": "%(asctime)s", + # Set the date format used by `asctime` + super(JsonFormatter, self).__init__(datefmt=kwargs.pop("datefmt", None)) + + self.format_dict.update(self._build_root_keys(**kwargs)) + + @staticmethod + def _build_root_keys(**kwargs): + return { "level": "%(levelname)s", "location": "%(funcName)s:%(lineno)d", + "timestamp": "%(asctime)s", + **kwargs, } - self.format_dict.update(kwargs) - self.default_json_formatter = kwargs.pop("json_default", json_formatter) + + @staticmethod + def _get_latest_trace_id(): + xray_trace_id = os.getenv("_X_AMZN_TRACE_ID") + trace_id = xray_trace_id.split(";")[0].replace("Root=", "") if xray_trace_id else None + + return trace_id def update_formatter(self, **kwargs): self.format_dict.update(kwargs) @@ -64,6 +61,7 @@ def format(self, record): # noqa: A003 record_dict["asctime"] = self.formatTime(record, self.datefmt) log_dict = {} + for key, value in self.format_dict.items(): if value and key in self.reserved_keys: # converts default logging expr to its record value @@ -84,19 +82,19 @@ def format(self, record): # noqa: A003 except (json.decoder.JSONDecodeError, TypeError, ValueError): pass - if record.exc_info: + if record.exc_info and not record.exc_text: # Cache the traceback text to avoid converting it multiple times # (it's constant anyway) # from logging.Formatter:format - if not record.exc_text: # pragma: no cover - record.exc_text = self.formatException(record.exc_info) + record.exc_text = self.formatException(record.exc_info) if record.exc_text: log_dict["exception"] = record.exc_text - json_record = json.dumps(log_dict, default=self.default_json_formatter) + # fetch latest X-Ray Trace ID, if any + log_dict.update({"xray_trace_id": self._get_latest_trace_id()}) - if hasattr(json_record, "decode"): # pragma: no cover - json_record = json_record.decode("utf-8") + # Filter out top level key with values that are None + log_dict = {k: v for k, v in log_dict.items() if v is not None} - return json_record + return json.dumps(log_dict, default=self.default_json_formatter) diff --git a/aws_lambda_powertools/logging/logger.py b/aws_lambda_powertools/logging/logger.py index b566ee83a83..3b188bdd9c4 100644 --- a/aws_lambda_powertools/logging/logger.py +++ b/aws_lambda_powertools/logging/logger.py @@ -136,16 +136,6 @@ def __getattr__(self, name): # https://github.com/awslabs/aws-lambda-powertools-python/issues/97 return getattr(self._logger, name) - def _get_log_level(self, level: Union[str, int]) -> Union[str, int]: - """ Returns preferred log level set by the customer in upper case """ - if isinstance(level, int): - return level - - log_level: str = level or os.getenv("LOG_LEVEL") - log_level = log_level.upper() if log_level is not None else logging.INFO - - return log_level - def _get_logger(self): """ Returns a Logger named {self.service}, or {self.service.filename} for child loggers""" logger_name = self.service @@ -154,17 +144,6 @@ def _get_logger(self): return logging.getLogger(logger_name) - def _get_caller_filename(self): - """ Return caller filename by finding the caller frame """ - # Current frame => _get_logger() - # Previous frame => logger.py - # Before previous frame => Caller - frame = inspect.currentframe() - caller_frame = 
frame.f_back.f_back.f_back - filename = caller_frame.f_globals["__name__"] - - return filename - def _init_logger(self, **kwargs): """Configures new logger""" @@ -207,6 +186,8 @@ def inject_lambda_context(self, lambda_handler: Callable[[Dict, Any], Any] = Non Parameters ---------- + lambda_handler : Callable + Method to inject the lambda context log_event : bool, optional Instructs logger to log Lambda Event, by default False @@ -254,14 +235,14 @@ def handler(event, context): @functools.wraps(lambda_handler) def decorate(event, context): + lambda_context = build_lambda_context_model(context) + cold_start = _is_cold_start() + self.structure_logs(append=True, cold_start=cold_start, **lambda_context.__dict__) + if log_event: logger.debug("Event received") self.info(event) - lambda_context = build_lambda_context_model(context) - cold_start = _is_cold_start() - - self.structure_logs(append=True, cold_start=cold_start, **lambda_context.__dict__) return lambda_handler(event, context) return decorate @@ -291,6 +272,29 @@ def structure_logs(self, append: bool = False, **kwargs): # Set a new formatter for a logger handler handler.setFormatter(JsonFormatter(**self._default_log_keys, **kwargs)) + @staticmethod + def _get_log_level(level: Union[str, int]) -> Union[str, int]: + """ Returns preferred log level set by the customer in upper case """ + if isinstance(level, int): + return level + + log_level: str = level or os.getenv("LOG_LEVEL") + log_level = log_level.upper() if log_level is not None else logging.INFO + + return log_level + + @staticmethod + def _get_caller_filename(): + """ Return caller filename by finding the caller frame """ + # Current frame => _get_logger() + # Previous frame => logger.py + # Before previous frame => Caller + frame = inspect.currentframe() + caller_frame = frame.f_back.f_back.f_back + filename = caller_frame.f_globals["__name__"] + + return filename + def set_package_logger( level: Union[str, int] = logging.DEBUG, stream: sys.stdout = None, formatter: logging.Formatter = None diff --git a/aws_lambda_powertools/metrics/base.py b/aws_lambda_powertools/metrics/base.py index 175097d9c10..bff0c84e03f 100644 --- a/aws_lambda_powertools/metrics/base.py +++ b/aws_lambda_powertools/metrics/base.py @@ -112,7 +112,7 @@ def add_metric(self, name: str, unit: MetricUnit, value: Union[float, int]): Metric name unit : MetricUnit `aws_lambda_powertools.helper.models.MetricUnit` - value : float + value : Union[float, int] Metric value Raises @@ -146,6 +146,8 @@ def serialize_metric_set(self, metrics: Dict = None, dimensions: Dict = None, me Dictionary of metrics to serialize, by default None dimensions : Dict, optional Dictionary of dimensions to serialize, by default None + metadata: Dict, optional + Dictionary of metadata to serialize, by default None Example ------- @@ -183,7 +185,7 @@ def serialize_metric_set(self, metrics: Dict = None, dimensions: Dict = None, me metric_names_and_values: Dict[str, str] = {} # { "metric_name": 1.0 } for metric_name in metrics: - metric: str = metrics[metric_name] + metric: dict = metrics[metric_name] metric_value: int = metric.get("Value", 0) metric_unit: str = metric.get("Unit", "") @@ -257,7 +259,7 @@ def add_metadata(self, key: str, value: Any): Parameters ---------- - name : str + key : str Metadata key value : any Metadata value diff --git a/aws_lambda_powertools/utilities/batch/__init__.py b/aws_lambda_powertools/utilities/batch/__init__.py new file mode 100644 index 00000000000..d308a56abda --- /dev/null +++ 
b/aws_lambda_powertools/utilities/batch/__init__.py @@ -0,0 +1,10 @@ +# -*- coding: utf-8 -*- + +""" +Batch processing utility +""" + +from .base import BasePartialProcessor, batch_processor +from .sqs import PartialSQSProcessor, sqs_batch_processor + +__all__ = ("BasePartialProcessor", "PartialSQSProcessor", "batch_processor", "sqs_batch_processor") diff --git a/aws_lambda_powertools/utilities/batch/base.py b/aws_lambda_powertools/utilities/batch/base.py new file mode 100644 index 00000000000..38fc1ca75fc --- /dev/null +++ b/aws_lambda_powertools/utilities/batch/base.py @@ -0,0 +1,146 @@ +# -*- coding: utf-8 -*- + +""" +Batch processing utilities +""" + +import logging +from abc import ABC, abstractmethod +from typing import Any, Callable, Dict, List, Tuple + +from aws_lambda_powertools.middleware_factory import lambda_handler_decorator + +logger = logging.getLogger(__name__) + + +class BasePartialProcessor(ABC): + """ + Abstract class for batch processors. + """ + + def __init__(self): + self.success_messages: List = [] + self.fail_messages: List = [] + self.exceptions: List = [] + + @abstractmethod + def _prepare(self): + """ + Prepare context manager. + """ + raise NotImplementedError() + + @abstractmethod + def _clean(self): + """ + Clear context manager. + """ + raise NotImplementedError() + + @abstractmethod + def _process_record(self, record: Any): + """ + Process record with handler. + """ + raise NotImplementedError() + + def process(self) -> List[Tuple]: + """ + Call instance's handler for each record. + """ + return [self._process_record(record) for record in self.records] + + def __enter__(self): + self._prepare() + return self + + def __exit__(self, exception_type, exception_value, traceback): + self._clean() + + def __call__(self, records: List[Any], handler: Callable): + """ + Set instance attributes before execution + + Parameters + ---------- + records: List[Any] + List with objects to be processed. + handler: Callable + Callable to process "records" entries. 
+ """ + self.records = records + self.handler = handler + return self + + def success_handler(self, record: Any, result: Any): + """ + Success callback + + Returns + ------- + tuple + "success", result, original record + """ + entry = ("success", result, record) + self.success_messages.append(record) + return entry + + def failure_handler(self, record: Any, exception: Exception): + """ + Failure callback + + Returns + ------- + tuple + "fail", exceptions args, original record + """ + entry = ("fail", exception.args, record) + logger.debug(f"Record processing exception: {exception}") + self.exceptions.append(exception) + self.fail_messages.append(record) + return entry + + +@lambda_handler_decorator +def batch_processor( + handler: Callable, event: Dict, context: Dict, record_handler: Callable, processor: BasePartialProcessor = None +): + """ + Middleware to handle batch event processing + + Parameters + ---------- + handler: Callable + Lambda's handler + event: Dict + Lambda's Event + context: Dict + Lambda's Context + record_handler: Callable + Callable to process each record from the batch + processor: PartialSQSProcessor + Batch Processor to handle partial failure cases + + Examples + -------- + **Processes Lambda's event with PartialSQSProcessor** + >>> from aws_lambda_powertools.utilities.batch import batch_processor, PartialSQSProcessor + >>> + >>> def record_handler(record): + >>> return record["body"] + >>> + >>> @batch_processor(record_handler=record_handler, processor=PartialSQSProcessor()) + >>> def handler(event, context): + >>> return {"StatusCode": 200} + + Limitations + ----------- + * Async batch processors + + """ + records = event["Records"] + + with processor(records, record_handler): + processor.process() + + return handler(event, context) diff --git a/aws_lambda_powertools/utilities/batch/exceptions.py b/aws_lambda_powertools/utilities/batch/exceptions.py new file mode 100644 index 00000000000..3e456eacec4 --- /dev/null +++ b/aws_lambda_powertools/utilities/batch/exceptions.py @@ -0,0 +1,7 @@ +""" +Batch processing exceptions +""" + + +class SQSBatchProcessingError(Exception): + """When at least one message within a batch could not be processed""" diff --git a/aws_lambda_powertools/utilities/batch/sqs.py b/aws_lambda_powertools/utilities/batch/sqs.py new file mode 100644 index 00000000000..4a4aa9c98b1 --- /dev/null +++ b/aws_lambda_powertools/utilities/batch/sqs.py @@ -0,0 +1,181 @@ +# -*- coding: utf-8 -*- + +""" +Batch SQS utilities +""" +import logging +from typing import Callable, Dict, List, Optional, Tuple + +import boto3 +from botocore.config import Config + +from ...middleware_factory import lambda_handler_decorator +from .base import BasePartialProcessor +from .exceptions import SQSBatchProcessingError + +logger = logging.getLogger(__name__) + + +class PartialSQSProcessor(BasePartialProcessor): + """ + Amazon SQS batch processor to delete successes from the Queue. + + The whole batch will be processed, even if failures occur. After all records are processed, + SQSBatchProcessingError will be raised if there were any failures, causing messages to + be returned to the SQS queue. This behaviour can be disabled by passing suppress_exception. 
+ + Parameters + ---------- + config: Config + botocore config object + suppress_exception: bool, optional + Supress exception raised if any messages fail processing, by default False + + + Example + ------- + **Process batch triggered by SQS** + + >>> from aws_lambda_powertools.utilities.batch import PartialSQSProcessor + >>> + >>> def record_handler(record): + >>> return record["body"] + >>> + >>> def handler(event, context): + >>> records = event["Records"] + >>> processor = PartialSQSProcessor() + >>> + >>> with processor(records=records, handler=record_handler): + >>> result = processor.process() + >>> + >>> # Case a partial failure occurred, all successful executions + >>> # have been deleted from the queue after context's exit. + >>> + >>> return result + """ + + def __init__(self, config: Optional[Config] = None, suppress_exception: bool = False): + """ + Initializes sqs client. + """ + config = config or Config() + self.client = boto3.client("sqs", config=config) + self.suppress_exception = suppress_exception + + super().__init__() + + def _get_queue_url(self) -> Optional[str]: + """ + Format QueueUrl from first records entry + """ + if not getattr(self, "records", None): + return + + *_, account_id, queue_name = self.records[0]["eventSourceARN"].split(":") + return f"{self.client._endpoint.host}/{account_id}/{queue_name}" + + def _get_entries_to_clean(self) -> List: + """ + Format messages to use in batch deletion + """ + return [{"Id": msg["messageId"], "ReceiptHandle": msg["receiptHandle"]} for msg in self.success_messages] + + def _process_record(self, record) -> Tuple: + """ + Process a record with instance's handler + + Parameters + ---------- + record: Any + An object to be processed. + """ + try: + result = self.handler(record) + return self.success_handler(record, result) + except Exception as exc: + return self.failure_handler(record, exc) + + def _prepare(self): + """ + Remove results from previous execution. + """ + self.success_messages.clear() + self.fail_messages.clear() + + def _clean(self): + """ + Delete messages from Queue in case of partial failure. 
+ """ + # If all messages were successful, fall back to the default SQS - + # Lambda behaviour which deletes messages if Lambda responds successfully + if not self.fail_messages: + logger.debug(f"All {len(self.success_messages)} records successfully processed") + return + + queue_url = self._get_queue_url() + entries_to_remove = self._get_entries_to_clean() + + delete_message_response = self.client.delete_message_batch(QueueUrl=queue_url, Entries=entries_to_remove) + + if self.suppress_exception: + logger.debug(f"{len(self.fail_messages)} records failed processing, but exceptions are suppressed") + else: + logger.debug(f"{len(self.fail_messages)} records failed processing, raising exception") + raise SQSBatchProcessingError(list(self.exceptions)) + + return delete_message_response + + +@lambda_handler_decorator +def sqs_batch_processor( + handler: Callable, + event: Dict, + context: Dict, + record_handler: Callable, + config: Optional[Config] = None, + suppress_exception: bool = False, +): + """ + Middleware to handle SQS batch event processing + + Parameters + ---------- + handler: Callable + Lambda's handler + event: Dict + Lambda's Event + context: Dict + Lambda's Context + record_handler: Callable + Callable to process each record from the batch + config: Config + botocore config object + suppress_exception: bool, optional + Supress exception raised if any messages fail processing, by default False + + Examples + -------- + **Processes Lambda's event with PartialSQSProcessor** + >>> from aws_lambda_powertools.utilities.batch import sqs_batch_processor + >>> + >>> def record_handler(record): + >>> return record["body"] + >>> + >>> @sqs_batch_processor(record_handler=record_handler) + >>> def handler(event, context): + >>> return {"StatusCode": 200} + + Limitations + ----------- + * Async batch processors + + """ + config = config or Config() + processor = PartialSQSProcessor(config=config, suppress_exception=suppress_exception) + + records = event["Records"] + + with processor(records, record_handler): + processor.process() + + return handler(event, context) diff --git a/aws_lambda_powertools/utilities/parameters/base.py b/aws_lambda_powertools/utilities/parameters/base.py index 274cd96aace..7ce0c9e4d2e 100644 --- a/aws_lambda_powertools/utilities/parameters/base.py +++ b/aws_lambda_powertools/utilities/parameters/base.py @@ -15,6 +15,9 @@ ExpirableValue = namedtuple("ExpirableValue", ["value", "ttl"]) # These providers will be dynamically initialized on first use of the helper functions DEFAULT_PROVIDERS = {} +TRANSFORM_METHOD_JSON = "json" +TRANSFORM_METHOD_BINARY = "binary" +SUPPORTED_TRANSFORM_METHODS = [TRANSFORM_METHOD_JSON, TRANSFORM_METHOD_BINARY] class BaseProvider(ABC): @@ -115,8 +118,8 @@ def get_multiple( Maximum age of the cached value transform: str, optional Optional transformation of the parameter value. Supported values - are "json" for JSON strings and "binary" for base 64 encoded - values. + are "json" for JSON strings, "binary" for base 64 encoded + values or "auto" which looks at the attribute key to determine the type. 
raise_on_transform_error: bool, optional Raises an exception if any transform fails, otherwise this will return a None value for each transform that failed @@ -145,7 +148,11 @@ def get_multiple( if transform is not None: for (key, value) in values.items(): - values[key] = transform_value(value, transform, raise_on_transform_error) + _transform = get_transform_method(key, transform) + if _transform is None: + continue + + values[key] = transform_value(value, _transform, raise_on_transform_error) self.store[key] = ExpirableValue(values, datetime.now() + timedelta(seconds=max_age),) @@ -159,6 +166,45 @@ def _get_multiple(self, path: str, **sdk_options) -> Dict[str, str]: raise NotImplementedError() +def get_transform_method(key: str, transform: Optional[str] = None) -> Optional[str]: + """ + Determine the transform method + + Examples + ------- + >>> get_transform_method("key", "any_other_value") + 'any_other_value' + >>> get_transform_method("key.json", "auto") + 'json' + >>> get_transform_method("key.binary", "auto") + 'binary' + >>> get_transform_method("key", "auto") + None + >>> get_transform_method("key", None) + None + + Parameters + --------- + key: str + Only used when the tranform is "auto". + transform: str, optional + Original transform method, only "auto" will try to detect the transform method by the key + + Returns + ------ + Optional[str]: + The transform method either when transform is "auto" then None, "json" or "binary" is returned + or the original transform method + """ + if transform != "auto": + return transform + + for transform_method in SUPPORTED_TRANSFORM_METHODS: + if key.endswith("." + transform_method): + return transform_method + return None + + def transform_value(value: str, transform: str, raise_on_transform_error: bool = True) -> Union[dict, bytes, None]: """ Apply a transform to a value @@ -180,9 +226,9 @@ def transform_value(value: str, transform: str, raise_on_transform_error: bool = """ try: - if transform == "json": + if transform == TRANSFORM_METHOD_JSON: return json.loads(value) - elif transform == "binary": + elif transform == TRANSFORM_METHOD_BINARY: return base64.b64decode(value) else: raise ValueError(f"Invalid transform type '{transform}'") diff --git a/aws_lambda_powertools/utilities/parameters/secrets.py b/aws_lambda_powertools/utilities/parameters/secrets.py index 67cb94c340b..e3981d22bcc 100644 --- a/aws_lambda_powertools/utilities/parameters/secrets.py +++ b/aws_lambda_powertools/utilities/parameters/secrets.py @@ -27,7 +27,7 @@ class SecretsProvider(BaseProvider): >>> from aws_lambda_powertools.utilities.parameters import SecretsProvider >>> secrets_provider = SecretsProvider() >>> - >>> value secrets_provider.get("my-parameter") + >>> value = secrets_provider.get("my-parameter") >>> >>> print(value) My parameter value diff --git a/aws_lambda_powertools/utilities/typing/__init__.py b/aws_lambda_powertools/utilities/typing/__init__.py new file mode 100644 index 00000000000..626a0fd6fcf --- /dev/null +++ b/aws_lambda_powertools/utilities/typing/__init__.py @@ -0,0 +1,11 @@ +# -*- coding: utf-8 -*- + +""" +Typing for developer ease in the IDE +""" + +from .lambda_context import LambdaContext + +__all__ = [ + "LambdaContext", +] diff --git a/aws_lambda_powertools/utilities/typing/lambda_client_context.py b/aws_lambda_powertools/utilities/typing/lambda_client_context.py new file mode 100644 index 00000000000..5b9e9506b4c --- /dev/null +++ b/aws_lambda_powertools/utilities/typing/lambda_client_context.py @@ -0,0 +1,25 @@ +# -*- coding: utf-8 -*- 
+from typing import Any, Dict + +from aws_lambda_powertools.utilities.typing.lambda_client_context_mobile_client import LambdaClientContextMobileClient + + +class LambdaClientContext(object): + _client: LambdaClientContextMobileClient + _custom: Dict[str, Any] + _env: Dict[str, Any] + + @property + def client(self) -> LambdaClientContextMobileClient: + """Client context that's provided to Lambda by the client application.""" + return self._client + + @property + def custom(self) -> Dict[str, Any]: + """A dict of custom values set by the mobile client application.""" + return self._custom + + @property + def env(self) -> Dict[str, Any]: + """A dict of environment information provided by the AWS SDK.""" + return self._env diff --git a/aws_lambda_powertools/utilities/typing/lambda_client_context_mobile_client.py b/aws_lambda_powertools/utilities/typing/lambda_client_context_mobile_client.py new file mode 100644 index 00000000000..bd204891d2b --- /dev/null +++ b/aws_lambda_powertools/utilities/typing/lambda_client_context_mobile_client.py @@ -0,0 +1,31 @@ +# -*- coding: utf-8 -*- + + +class LambdaClientContextMobileClient(object): + """Mobile Client context that's provided to Lambda by the client application.""" + + _installation_id: str + _app_title: str + _app_version_name: str + _app_version_code: str + _app_package_name: str + + @property + def installation_id(self) -> str: + return self._installation_id + + @property + def app_title(self) -> str: + return self._app_title + + @property + def app_version_name(self) -> str: + return self._app_version_name + + @property + def app_version_code(self) -> str: + return self._app_version_code + + @property + def app_package_name(self) -> str: + return self._app_package_name diff --git a/aws_lambda_powertools/utilities/typing/lambda_cognito_identity.py b/aws_lambda_powertools/utilities/typing/lambda_cognito_identity.py new file mode 100644 index 00000000000..06679269330 --- /dev/null +++ b/aws_lambda_powertools/utilities/typing/lambda_cognito_identity.py @@ -0,0 +1,18 @@ +# -*- coding: utf-8 -*- + + +class LambdaCognitoIdentity(object): + """Information about the Amazon Cognito identity that authorized the request.""" + + _cognito_identity_id: str + _cognito_identity_pool_id: str + + @property + def cognito_identity_id(self) -> str: + """The authenticated Amazon Cognito identity.""" + return self._cognito_identity_id + + @property + def cognito_identity_pool_id(self) -> str: + """The Amazon Cognito identity pool that authorized the invocation.""" + return self._cognito_identity_pool_id diff --git a/aws_lambda_powertools/utilities/typing/lambda_context.py b/aws_lambda_powertools/utilities/typing/lambda_context.py new file mode 100644 index 00000000000..b132fe413bc --- /dev/null +++ b/aws_lambda_powertools/utilities/typing/lambda_context.py @@ -0,0 +1,81 @@ +# -*- coding: utf-8 -*- +from aws_lambda_powertools.utilities.typing.lambda_client_context import LambdaClientContext +from aws_lambda_powertools.utilities.typing.lambda_cognito_identity import LambdaCognitoIdentity + + +class LambdaContext(object): + """The LambdaContext static object can be used to ease the development by providing the IDE type hints. 
+ + Example + ------- + **A Lambda function using LambdaContext** + + >>> from typing import Any, Dict + >>> from aws_lambda_powertools.utilities.typing import LambdaContext + >>> + >>> def handler(event: Dict[str, Any], context: LambdaContext) -> Dict[str, Any]: + >>> # Insert business logic + >>> return event + + """ + + _function_name: str + _function_version: str + _invoked_function_arn: str + _memory_limit_in_mb: int + _aws_request_id: str + _log_group_name: str + _log_stream_name: str + _identity: LambdaCognitoIdentity + _client_context: LambdaClientContext + + @property + def function_name(self) -> str: + """The name of the Lambda function.""" + return self._function_name + + @property + def function_version(self) -> str: + """The version of the function.""" + return self._function_version + + @property + def invoked_function_arn(self) -> str: + """The Amazon Resource Name (ARN) that's used to invoke the function. Indicates if the invoker specified a + version number or alias.""" + return self._invoked_function_arn + + @property + def memory_limit_in_mb(self) -> int: + """The amount of memory that's allocated for the function.""" + return self._memory_limit_in_mb + + @property + def aws_request_id(self) -> str: + """The identifier of the invocation request.""" + return self._aws_request_id + + @property + def log_group_name(self) -> str: + """The log group for the function.""" + return self._log_group_name + + @property + def log_stream_name(self) -> str: + """The log stream for the function instance.""" + return self._log_stream_name + + @property + def identity(self) -> LambdaCognitoIdentity: + """(mobile apps) Information about the Amazon Cognito identity that authorized the request.""" + return self._identity + + @property + def client_context(self) -> LambdaClientContext: + """(mobile apps) Client context that's provided to Lambda by the client application.""" + return self._client_context + + @staticmethod + def get_remaining_time_in_millis() -> int: + """Returns the number of milliseconds left before the execution times out.""" + return 0 diff --git a/docs/content/core/logger.mdx b/docs/content/core/logger.mdx index 618d7a78435..0f8bb7fa9b9 100644 --- a/docs/content/core/logger.mdx +++ b/docs/content/core/logger.mdx @@ -47,7 +47,7 @@ Logger(service="payment", level="INFO") ## Standard structured keys -Your Logger will always include the following keys to your structured logging: +Your Logger will include the following keys to your structured logging, by default: Key | Type | Example | Description ------------------------------------------------- | ------------------------------------------------- | --------------------------------------------------------------------------------- | ------------------------------------------------- @@ -55,8 +55,9 @@ Key | Type | Example | Description **level** | str | "INFO" | Logging level **location** | str | "collect.handler:1" | Source code location where statement was executed **service** | str | "payment" | Service name defined. "service_undefined" will be used if unknown -**sampling_rate** | int | 0.1 | Debug logging sampling rate in percentage e.g. 1% in this case +**sampling_rate** | int | 0.1 | Debug logging sampling rate in percentage e.g. 10% in this case **message** | any | "Collecting payment" | Log statement value. 
Unserializable JSON values will be casted to string +**xray_trace_id** | str | "1-5759e988-bd862e3fe1be46a994272793" | X-Ray Trace ID when Lambda function has enabled Tracing ## Capturing Lambda context info @@ -90,7 +91,7 @@ Key | Type | Example **function_request_id**| str | "899856cb-83d1-40d7-8611-9e78f15f32f4"
-Exerpt output in CloudWatch Logs +Excerpt output in CloudWatch Logs ```json:title=cloudwatch_logs.json { @@ -164,9 +165,9 @@ def handler(event, context): ```
-Exerpt output in CloudWatch Logs +Excerpt output in CloudWatch Logs -```json:title=cloudwatch_logs.jsonn +```json:title=cloudwatch_logs.json { "timestamp": "2020-05-24 18:17:33,774", "level": "INFO", @@ -232,7 +233,7 @@ Sampling calculation happens at the Logger class initialization. This means, whe ```python:title=collect.py from aws_lambda_powertools import Logger -# Sample 1% of debug logs e.g. 0.1 +# Sample 10% of debug logs e.g. 0.1 logger = Logger(sample_rate=0.1) # highlight-line def handler(event, context): @@ -242,7 +243,7 @@ def handler(event, context): ```
-Exerpt output in CloudWatch Logs +Excerpt output in CloudWatch Logs ```json:title=cloudwatch_logs.json { @@ -260,3 +261,100 @@ def handler(event, context): } ```
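For reference, a sketch of how a structured log line might look with the new `xray_trace_id` key (see the standard structured keys table above) when the function runs with X-Ray tracing enabled - the values below are illustrative, not taken from a real invocation:

```json:title=cloudwatch_logs.json
{
    "level": "INFO",
    "location": "collect.handler:1",
    "message": "Collecting payment",
    "timestamp": "2020-09-04 18:17:33,774",
    "service": "payment",
    "sampling_rate": 0.0,
    "xray_trace_id": "1-5759e988-bd862e3fe1be46a994272793"
}
```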
+
+
+## Migrating from other Loggers
+
+If you're migrating from other Loggers, there are a few key points to be aware of: **Service parameter**, **Inheriting Loggers**, **Overriding Log records**, and **Logging exceptions**.
+
+### The service parameter
+
+Service is what defines what the function is responsible for, or is part of (e.g. payment service), and it is also the name of the Logger.
+
+For Logger, the `service` is the logging key customers can use to search log operations for one or more functions - for example, **search for all errors, or messages like X, where service is payment**.
+
+### Inheriting Loggers
+
+> Python Logging hierarchy happens via the dot notation: `service`, `service.child`, `service.child_2`.
+
+For inheritance, Logger uses a `child=True` parameter along with `service` being the same value across Loggers.
+
+For child Loggers, we introspect the name of the module where `Logger(child=True, service="name")` is called, and we name your Logger as **{service}.{filename}**.
+
+A common issue when migrating from other Loggers is that `service` might be defined in the parent Logger (no child param), and not defined in the child Logger:
+
+```python:title=incorrect_logger_inheritance.py
+# app.py
+import my_module
+from aws_lambda_powertools import Logger
+
+logger = Logger(service="payment") # highlight-line
+...
+
+# my_module.py
+from aws_lambda_powertools import Logger
+
+logger = Logger(child=True) # highlight-line
+```
+
+In this case, Logger will register a Logger named `payment`, and a Logger named `service_undefined`. The latter isn't inheriting from the parent, and will have no handler, so no message is logged to standard output.
+
+This can be fixed by either ensuring both have the `service` value set to `payment`, or by simply using the environment variable `POWERTOOLS_SERVICE_NAME` to ensure the service value is the same across all Loggers when not explicitly set.
+
+### Overriding Log records
+
+You might want to continue to use the same date formatting style, or override `location` to display the `package.function_name:line_number` as you previously had.
+
+Logger allows you to either change the format or suppress the following keys altogether at initialization: `location`, `timestamp`, `level`, and `datefmt`.
+
+```python
+from sys import stdout
+from aws_lambda_powertools import Logger
+
+# override default values for location and timestamp format
+logger = Logger(stream=stdout, location="[%(funcName)s] %(module)s", datefmt="fake-datefmt") # highlight-line
+
+# suppress location key
+logger = Logger(stream=stdout, location=None) # highlight-line
+```
+
+Alternatively, you can change the order of the following log record keys via the `log_record_order` parameter: `level`, `location`, `message`, and `timestamp`.
+
+```python
+from sys import stdout
+from aws_lambda_powertools import Logger
+
+# make message the first key
+logger = Logger(stream=stdout, log_record_order=["message"]) # highlight-line
+
+# default key sorting order
+logger = Logger(stream=stdout, log_record_order=["level", "location", "message", "timestamp"]) # highlight-line
+```
+
+### Logging exceptions
+
+When logging exceptions, Logger will add a new key named `exception`, and will serialize the full traceback as a string.
+
+```python:title=logging_an_exception.py
+from aws_lambda_powertools import Logger
+
+logger = Logger()
+
+try:
+    raise ValueError("something went wrong")
+except Exception:
+    logger.exception("Received an exception") # highlight-line
+```
+
+
+Excerpt output in CloudWatch Logs + +```json:title=cloudwatch_logs.json +{ + "level": "ERROR", + "location": ":4", + "message": "Received an exception", + "timestamp": "2020-08-28 18:11:38,886", + "service": "service_undefined", + "sampling_rate": 0.0, + "exception": "Traceback (most recent call last):\n File \"\", line 2, in \nValueError: something went wrong" +} +``` +
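One behaviour change from this release's fixes is worth illustrating: the `json_default` callable is still used for unserializable values, but its own value is no longer written to the log output. A minimal sketch of supplying a custom serializer, mirroring the new functional test (the class and lambda below are made up for illustration):

```python
from aws_lambda_powertools import Logger

class Unserializable:
    pass

# values json can't handle are passed to json_default;
# the "json_default" key itself no longer appears in the emitted log
logger = Logger(service="payment", json_default=lambda o: f"<non-serializable: {type(o).__name__}>")
logger.info({"x": Unserializable()})
```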
diff --git a/docs/content/core/metrics.mdx b/docs/content/core/metrics.mdx index 9b750c28622..be1b9feaa5b 100644 --- a/docs/content/core/metrics.mdx +++ b/docs/content/core/metrics.mdx @@ -113,7 +113,7 @@ metrics.add_metadata(key="booking_id", value="booking_uuid") # highlight-line This will be available in CloudWatch Logs to ease operations on high cardinal data.
-Exerpt output in CloudWatch Logs +Excerpt output in CloudWatch Logs ```json:title=cloudwatch_logs.json { diff --git a/docs/content/index.mdx b/docs/content/index.mdx index 26ab367ba4c..ec2dd862e38 100644 --- a/docs/content/index.mdx +++ b/docs/content/index.mdx @@ -5,7 +5,7 @@ description: AWS Lambda Powertools Python import Note from "../src/components/Note" -Powertools is a suite of utilities for AWS Lambda Functions that makes tracing with AWS X-Ray, structured logging and creating custom metrics asynchronously easier. +Powertools is a suite of utilities for AWS Lambda functions that makes tracing with AWS X-Ray, structured logging and creating custom metrics asynchronously easier. Looking for a quick run through of the core utilities?

@@ -24,6 +24,12 @@ Powertools is available in PyPi. You can use your favourite dependency managemen ```bash:title=hello_world.sh sam init --location https://github.com/aws-samples/cookiecutter-aws-sam-python ``` +* [Tracing](./core/tracer) - Decorators and utilities to trace Lambda function handlers, and both synchronous and asynchronous functions +* [Logging](./core/logger) - Structured logging made easier, and decorator to enrich structured logging with key Lambda context details +* [Metrics](./core/metrics) - Custom Metrics created asynchronously via CloudWatch Embedded Metric Format (EMF) +* [Bring your own middleware](./utilities/middleware_factory) - Decorator factory to create your own middleware to run logic before, and after each Lambda invocation +* [Parameters utility](./utilities/parameters) - Retrieve parameter values from AWS Systems Manager Parameter Store, AWS Secrets Manager, or Amazon DynamoDB, and cache them for a specific amount of time +* [Batch utility](./utilities/batch) - Batch processing for AWS SQS, handles partial failure. ### Lambda Layer @@ -44,6 +50,19 @@ If using SAM, you can include this SAR App as part of your shared Layers stack, SemanticVersion: 1.3.1 # change to latest semantic version available in SAR ``` +This will add a nested app stack with an output parameter `LayerVersionArn`, that you can reference inside your Lambda function definition: + +```yaml + Layers: + - !GetAtt AwsLambdaPowertoolsPythonLayer.Outputs.LayerVersionArn +``` + +You can fetch the available versions via the API with: + + ```bash + aws serverlessrepo list-application-versions --application-id arn:aws:serverlessrepo:eu-west-1:057560766410:applications/aws-lambda-powertools-python-layer + ``` + ## Features Utility | Description @@ -53,6 +72,7 @@ Utility | Description [Metrics](./core/metrics) | Custom Metrics created asynchronously via CloudWatch Embedded Metric Format (EMF) [Bring your own middleware](.//utilities/middleware_factory) | Decorator factory to create your own middleware to run logic before, and after each Lambda invocation [Parameters utility](./utilities/parameters) | Retrieve parameter values from AWS Systems Manager Parameter Store, AWS Secrets Manager, or Amazon DynamoDB, and cache them for a specific amount of time +[Typing utility](./utilities/typing) | Static typing classes to speedup development in your IDE ## Environment variables diff --git a/docs/content/media/utilities_typing.png b/docs/content/media/utilities_typing.png new file mode 100644 index 00000000000..0f293abb6ec Binary files /dev/null and b/docs/content/media/utilities_typing.png differ diff --git a/docs/content/utilities/batch.mdx b/docs/content/utilities/batch.mdx new file mode 100644 index 00000000000..e8fe73dc4ff --- /dev/null +++ b/docs/content/utilities/batch.mdx @@ -0,0 +1,257 @@ +--- +title: SQS Batch Processing +description: Utility +--- + +import Note from "../../src/components/Note" + +The SQS batch processing utility provides a way to handle partial failures when processing batches of messages from SQS. + +**Key Features** + +* Prevent successfully processed messages being returned to SQS +* Simple interface for individually processing messages from a batch +* Build your own batch processor using the base classes + +**Background** + +When using SQS as a Lambda event source mapping, Lambda functions are triggered with a batch of messages from SQS. 
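For context, a trimmed sketch of the event shape a handler receives from the SQS event source mapping - the field values are illustrative and mirror the fixture used in the functional tests:

```json
{
    "Records": [
        {
            "messageId": "059f36b4-87a3-44ab-83d2-661975830a7d",
            "receiptHandle": "AQEBwJnKyrHigUMZj6rYigCgxlaS3SLy0a",
            "body": "hello world",
            "eventSource": "aws:sqs",
            "eventSourceARN": "arn:aws:sqs:us-east-2:123456789012:my-queue",
            "awsRegion": "us-east-1"
        }
    ]
}
```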
+ +If your function fails to process any message from the batch, the entire batch returns to your SQS queue, and your Lambda function is triggered with the same batch one more time. + +With this utility, messages within a batch are handled individually - only messages that were not successfully processed +are returned to the queue. + + + While this utility lowers the chance of processing messages more than once, it is not guaranteed. We recommend implementing processing logic in an idempotent manner wherever possible. +

+ More details on how Lambda works with SQS can be found in the AWS documentation +
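As a rough illustration of the idempotency advice above, one option is to key a conditional write on the SQS `messageId` before doing any work - the table name, schema and `do_something_with` function below are assumptions made for this sketch, not part of the utility:

```python
import boto3
from botocore.exceptions import ClientError

# hypothetical DynamoDB table with messageId as its partition key
table = boto3.resource("dynamodb").Table("processed-messages")

def record_handler(record):
    try:
        # claim the message; the conditional write fails if it was already processed
        table.put_item(
            Item={"messageId": record["messageId"]},
            ConditionExpression="attribute_not_exists(messageId)",
        )
    except ClientError as exc:
        if exc.response["Error"]["Code"] == "ConditionalCheckFailedException":
            return  # duplicate delivery - treat as already processed
        raise
    return do_something_with(record["body"])  # hypothetical business logic
```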

+
+
+**IAM Permissions**
+
+This utility requires additional permissions to work as expected. Lambda functions using this utility require the `sqs:DeleteMessageBatch` permission.
+
+## Processing messages from SQS
+
+You can use either the **[sqs_batch_processor](#sqs_batch_processor-decorator)** decorator, or **[PartialSQSProcessor](#partialsqsprocessor-context-manager)** as a context manager.
+
+They have nearly the same behaviour when it comes to processing messages from the batch:
+
+* **Entire batch has been successfully processed**, where your Lambda handler returned successfully, we will let SQS delete the batch to optimize your cost
+* **Entire batch has been partially processed successfully**, where exceptions were raised within your `record_handler`, we will:
+    - **1)** Delete successfully processed messages from the queue by directly calling `sqs:DeleteMessageBatch`
+    - **2)** Raise `SQSBatchProcessingError` to ensure failed messages return to your SQS queue
+
+The only difference is that **PartialSQSProcessor** will give you access to processed messages if you need them.
+
+## Record Handler
+
+Both the decorator and the context manager require an explicit function to process the batch of messages - namely the `record_handler` parameter.
+
+This function is responsible for processing each individual message from the batch, and for raising an exception if it is unable to process any of the messages sent.
+
+**Any non-exception/successful return from your record handler function** will instruct both decorator and context manager to queue up each individual message for deletion.
+
+### sqs_batch_processor decorator
+
+When using this decorator, you need to provide a function via the `record_handler` param that will process individual messages from the batch - it should raise an exception if it is unable to process the record.
+
+All records in the batch will be passed to this handler for processing, even if exceptions are thrown - here's the behaviour after completing the batch:
+
+* **Any successfully processed messages**, we will delete them from the queue via `sqs:DeleteMessageBatch`
+* **Any unprocessed messages detected**, we will raise `SQSBatchProcessingError` to ensure failed messages return to your SQS queue
+
+
+  You will not have access to the processed messages within the Lambda handler - all processing logic will and should be performed by the record_handler function.
+
+ +```python:title=app.py +from aws_lambda_powertools.utilities.batch import sqs_batch_processor + +def record_handler(record): + # This will be called for each individual message from a batch + # It should raise an exception if the message was not processed successfully + return_value = do_something_with(record["body"]) + return return_value + +@sqs_batch_processor(record_handler=record_handler) +def lambda_handler(event, context): + return {"statusCode": 200} +``` + +### PartialSQSProcessor context manager + +If you require access to the result of processed messages, you can use this context manager. + +The result from calling `process()` on the context manager will be a list of all the return values from your `record_handler` function. + +```python:title=app.py +from aws_lambda_powertools.utilities.batch import PartialSQSProcessor + +def record_handler(record): + # This will be called for each individual message from a batch + # It should raise an exception if the message was not processed successfully + return_value = do_something_with(record["body"]) + return return_value + + +def lambda_handler(event, context): + records = event["Records"] + + processor = PartialSQSProcessor() + + with processor(records, record_handler) as proc: + result = proc.process() # Returns a list of all results from record_handler + + return result +``` + +## Passing custom boto3 config + +If you need to pass custom configuration such as region to the SDK, you can pass your own [botocore config object](https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html) to +the `sqs_batch_processor` decorator: + +```python:title=app.py +from aws_lambda_powertools.utilities.batch import sqs_batch_processor +from botocore.config import Config + +config = Config(region_name="us-east-1") # highlight-line + +def record_handler(record): + # This will be called for each individual message from a batch + # It should raise an exception if the message was not processed successfully + return_value = do_something_with(record["body"]) + return return_value + +@sqs_batch_processor(record_handler=record_handler, config=config) # highlight-line +def lambda_handler(event, context): + return {"statusCode": 200} +``` + +Or to the `PartialSQSProcessor` class: +```python:title=app.py +from aws_lambda_powertools.utilities.batch import PartialSQSProcessor + +from botocore.config import Config + +config = Config(region_name="us-east-1") # highlight-line + +def record_handler(record): + # This will be called for each individual message from a batch + # It should raise an exception if the message was not processed successfully + return_value = do_something_with(record["body"]) + return return_value + + +def lambda_handler(event, context): + records = event["Records"] + + processor = PartialSQSProcessor(config=config) # highlight-line + + with processor(records, record_handler): + result = processor.process() + + return result +``` + + +## Suppressing exceptions + +If you want to disable the default behavior where `SQSBatchProcessingError` is raised if there are any errors, you can pass the `suppress_exception` boolean argument. + +**Within the decorator** + +```python:title=app.py +... 
+@sqs_batch_processor(record_handler=record_handler, config=config, suppress_exception=True) # highlight-line
+def lambda_handler(event, context):
+    return {"statusCode": 200}
+```
+
+**Within the context manager**
+
+```python:title=app.py
+processor = PartialSQSProcessor(config=config, suppress_exception=True) # highlight-line
+
+with processor(records, record_handler):
+    result = processor.process()
+```
+
+## Create your own partial processor
+
+You can create your own partial batch processor by inheriting the `BasePartialProcessor` class, and implementing `_prepare()`, `_clean()` and `_process_record()`.
+
+* **`_process_record()`** - Handles all processing logic for each individual message of a batch, including calling the `record_handler` (`self.handler`)
+* **`_prepare()`** - Called once as part of the processor initialization
+* **`_clean()`** - Teardown logic called once after `_process_record` completes
+
+You can then use this class as a context manager, or pass it to `batch_processor` to use as a decorator on your Lambda handler function.
+
+**Example:**
+
+```python:title=custom_processor.py
+import os
+from random import randint
+
+import boto3
+
+from aws_lambda_powertools.utilities.batch import BasePartialProcessor, batch_processor
+
+table_name = os.getenv("TABLE_NAME", "table_not_found")
+
+class MyPartialProcessor(BasePartialProcessor):
+    """
+    Process a record and store successful results in an Amazon DynamoDB table
+
+    Parameters
+    ----------
+    table_name: str
+        DynamoDB table name to write results to
+    """
+
+    def __init__(self, table_name: str):
+        self.table_name = table_name
+
+        super().__init__()
+
+    def _prepare(self):
+        # Called once, *before* processing
+        # Creates the table resource and clears previous results
+        self.ddb_table = boto3.resource("dynamodb").Table(self.table_name)
+        self.success_messages.clear()
+
+    def _clean(self):
+        # Called once, *after* processing all records (when the context manager exits)
+        # Sends all successful messages to the DynamoDB table in one batch
+        with self.ddb_table.batch_writer() as batch:
+            for result in self.success_messages:
+                batch.put_item(Item=result)
+
+    def _process_record(self, record):
+        # Handles how each record is processed and keeps the status of each run,
+        # where self.handler is the record_handler function passed as an argument
+        try:
+            result = self.handler(record)  # record_handler passed to decorator/context manager
+            return self.success_handler(record, result)
+        except Exception as exc:
+            return self.failure_handler(record, exc)
+
+    def success_handler(self, record, result):
+        entry = ("success", result, record)
+        message = {"age": result}
+        self.success_messages.append(message)
+        return entry
+
+
+def record_handler(record):
+    return randint(0, 100)
+
+@batch_processor(record_handler=record_handler, processor=MyPartialProcessor(table_name))
+def lambda_handler(event, context):
+    return {"statusCode": 200}
+```
diff --git a/docs/content/utilities/typing.mdx b/docs/content/utilities/typing.mdx
new file mode 100644
index 00000000000..9192b095887
--- /dev/null
+++ b/docs/content/utilities/typing.mdx
@@ -0,0 +1,25 @@
+---
+title: Typing
+description: Utility
+---
+
+import Note from "../../src/components/Note"
+
+This typing utility provides static typing classes that can be used to ease development by providing type hints to your IDE.
+ +![Utilities Typing](../media/utilities_typing.png) + +## LambdaContext + +The `LambdaContext` typing is typically used in the handler method for the Lambda function. + +```python:title=index.py +from typing import Any, Dict +from aws_lambda_powertools.utilities.typing import LambdaContext + +# highlight-start +def handler(event: Dict[str, Any], context: LambdaContext) -> Dict[str, Any]: +# highlight-end + # Insert business logic + return event +``` diff --git a/docs/gatsby-config.js b/docs/gatsby-config.js index d518ee8e715..a4286e0d55f 100644 --- a/docs/gatsby-config.js +++ b/docs/gatsby-config.js @@ -4,7 +4,7 @@ module.exports = { pathPrefix: '/aws-lambda-powertools-python', siteMetadata: { title: 'AWS Lambda Powertools Python', - description: 'A suite of utilities for AWS Lambda Functions that makes tracing with AWS X-Ray, structured logging and creating custom metrics asynchronously easier', + description: 'A suite of utilities for AWS Lambda functions that makes tracing with AWS X-Ray, structured logging and creating custom metrics asynchronously easier', author: `Amazon Web Services`, siteName: 'AWS Lambda Powertools Python', siteUrl: `${docsWebsite}` @@ -32,6 +32,8 @@ module.exports = { 'Utilities': [ 'utilities/middleware_factory', 'utilities/parameters', + 'utilities/batch', + 'utilities/typing', ], }, navConfig: { diff --git a/docs/package.json b/docs/package.json index b91de202b8e..592587b2de8 100644 --- a/docs/package.json +++ b/docs/package.json @@ -20,6 +20,6 @@ "license": "MIT-0", "repository": "https://github.com/awslabs/aws-lambda-powertools-python", "name": "aws-lambda-powertools-python", - "description": "Powertools is a suite of utilities for AWS Lambda Functions that makes tracing with AWS X-Ray, structured logging and creating custom metrics asynchronously easier.", + "description": "Powertools is a suite of utilities for AWS Lambda functions that makes tracing with AWS X-Ray, structured logging and creating custom metrics asynchronously easier.", "devDependencies": {} } diff --git a/example/README.md b/example/README.md index a54ab6ba3c3..b19f98f14f1 100644 --- a/example/README.md +++ b/example/README.md @@ -55,7 +55,7 @@ The first command will build the source of your application. The second command * **Stack Name**: The name of the stack to deploy to CloudFormation. This should be unique to your account and region, and a good starting point would be something matching your project name. * **AWS Region**: The AWS region you want to deploy your app to. * **Confirm changes before deploy**: If set to yes, any change sets will be shown to you before execution for manual review. If set to no, the AWS SAM CLI will automatically deploy application changes. -* **Allow SAM CLI IAM role creation**: Many AWS SAM templates, including this example, create AWS IAM roles required for the AWS Lambda function(s) included to access AWS services. By default, these are scoped down to minimum required permissions. To deploy an AWS CloudFormation stack which creates or modified IAM roles, the `CAPABILITY_IAM` value for `capabilities` must be provided. If permission isn't provided through this prompt, to deploy this example you must explicitly pass `--capabilities CAPABILITY_IAM` to the `sam deploy` command. +* **Allow SAM CLI IAM role creation**: Many AWS SAM templates, including this example, create AWS IAM roles required for the AWS Lambda function(s) included to access AWS services. By default, these are scoped down to minimum required permissions. 
To deploy an AWS CloudFormation stack which creates or modified IAM roles, the `CAPABILITY_IAM` value for `capabilities` must be provided. If permission isn't provided through this prompt, to deploy this example you must explicitly pass `--capabilities CAPABILITY_IAM` to the `sam deploy` command. If you are using `AWS::Serverless::Application` as a layer, you need also to pass `CAPABILITY_AUTO_EXPAND` during `sam deploy`, because it will create a nested stack for the layer. * **Save arguments to samconfig.toml**: If set to yes, your choices will be saved to a configuration file inside the project, so that in the future you can just re-run `sam deploy` without parameters to deploy changes to your application. You can find your API Gateway Endpoint URL in the output values displayed after deployment. diff --git a/example/hello_world/app.py b/example/hello_world/app.py index 3d47646c389..618e9a65764 100644 --- a/example/hello_world/app.py +++ b/example/hello_world/app.py @@ -14,7 +14,7 @@ set_package_logger() # Enable package diagnostics (DEBUG log) # tracer = Tracer() # patches all available modules # noqa: E800 -tracer = Tracer(patch_modules=("aioboto3", "boto3", "requests")) # ~90-100ms faster in perf depending on set of libs +tracer = Tracer(patch_modules=["aioboto3", "boto3", "requests"]) # ~90-100ms faster in perf depending on set of libs logger = Logger() metrics = Metrics() @@ -114,13 +114,13 @@ def lambda_handler(event, context): try: ip = requests.get("http://checkip.amazonaws.com/") - metrics.add_metric(name="SuccessfulLocations", unit="Count", value=1) + metrics.add_metric(name="SuccessfulLocations", unit=MetricUnit.Count, value=1) except requests.RequestException as e: # Send some context about this error to Lambda Logs logger.exception(e) raise - with single_metric(name="UniqueMetricDimension", unit="Seconds", value=1) as metric: + with single_metric(name="UniqueMetricDimension", unit=MetricUnit.Seconds, value=1) as metric: metric.add_dimension(name="unique_dimension", value="for_unique_metric") resp = {"message": "hello world", "location": ip.text.replace("\n", ""), "async_http": async_http_ret} diff --git a/example/template.yaml b/example/template.yaml index 47267d729f5..809e981edc0 100644 --- a/example/template.yaml +++ b/example/template.yaml @@ -17,6 +17,8 @@ Resources: CodeUri: hello_world/ Handler: app.lambda_handler Runtime: python3.8 + Layers: + - !GetAtt AwsLambdaPowertoolsPythonLayer.Outputs.LayerVersionArn Tracing: Active # enables X-Ray tracing Environment: Variables: @@ -33,6 +35,13 @@ Resources: Path: /hello Method: get + AwsLambdaPowertoolsPythonLayer: + Type: AWS::Serverless::Application + Properties: + Location: + ApplicationId: arn:aws:serverlessrepo:eu-west-1:057560766410:applications/aws-lambda-powertools-python-layer + SemanticVersion: 1.3.1 # change to latest semantic version available in SAR + Outputs: # ServerlessRestApi is an implicit API created out of Events key under Serverless::Function # Find out more about other implicit resources you can reference within SAM diff --git a/pyproject.toml b/pyproject.toml index 240ae4ed84d..1fc075e8382 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "aws_lambda_powertools" -version = "1.4.0" +version = "1.5.0" description = "Python utilities for AWS Lambda functions including but not limited to tracing, logging and custom metric" authors = ["Amazon Web Services"] classifiers=[ diff --git a/tests/functional/test_aws_lambda_logging.py b/tests/functional/test_aws_lambda_logging.py 
index cf4782d1d2a..79667f907fc 100644 --- a/tests/functional/test_aws_lambda_logging.py +++ b/tests/functional/test_aws_lambda_logging.py @@ -38,16 +38,17 @@ def test_setup_with_valid_log_levels(stdout, level): def test_logging_exception_traceback(stdout): - logger = Logger(level="DEBUG", stream=stdout, request_id="request id!", another="value") + logger = Logger(level="DEBUG", stream=stdout) try: - raise Exception("Boom") - except Exception: - logger.exception("This is a test") + raise ValueError("Boom") + except ValueError: + logger.exception("A value error occurred") log_dict = json.loads(stdout.getvalue()) check_log_dict(log_dict) + assert "ERROR" == log_dict["level"] assert "exception" in log_dict @@ -86,15 +87,143 @@ def test_with_json_message(stdout): assert msg == log_dict["message"] -def test_with_unserialisable_value_in_message(stdout): +def test_with_unserializable_value_in_message(stdout): logger = Logger(level="DEBUG", stream=stdout) - class X: + class Unserializable: pass - msg = {"x": X()} + msg = {"x": Unserializable()} logger.debug(msg) log_dict = json.loads(stdout.getvalue()) assert log_dict["message"]["x"].startswith("<") + + +def test_with_unserializable_value_in_message_custom(stdout): + class Unserializable: + pass + + # GIVEN a custom json_default + logger = Logger(level="DEBUG", stream=stdout, json_default=lambda o: f"<non-serializable: {type(o).__name__}>") + + # WHEN we log a message + logger.debug({"x": Unserializable()}) + + log_dict = json.loads(stdout.getvalue()) + + # THEN json_default should not be in the log message and the custom unserializable handler should be used + assert log_dict["message"]["x"] == "<non-serializable: Unserializable>" + assert "json_default" not in log_dict + + +def test_log_dict_key_seq(stdout): + # GIVEN the default logger configuration + logger = Logger(stream=stdout) + + # WHEN logging a message + logger.info("Message") + + log_dict: dict = json.loads(stdout.getvalue()) + + # THEN the beginning key sequence must be `level,location,message,timestamp` + assert ",".join(list(log_dict.keys())[:4]) == "level,location,message,timestamp" + + +def test_log_dict_key_custom_seq(stdout): + # GIVEN a logger configuration with log_record_order set to ["message"] + logger = Logger(stream=stdout, log_record_order=["message"]) + + # WHEN logging a message + logger.info("Message") + + log_dict: dict = json.loads(stdout.getvalue()) + + # THEN the first key should be "message" + assert list(log_dict.keys())[0] == "message" + + +def test_log_custom_formatting(stdout): + # GIVEN a logger where we have a custom `location` and `datefmt` format + logger = Logger(stream=stdout, location="[%(funcName)s] %(module)s", datefmt="fake-datefmt") + + # WHEN logging a message + logger.info("foo") + + log_dict: dict = json.loads(stdout.getvalue()) + + # THEN the `location` and `timestamp` should match the formatting + assert log_dict["location"] == "[test_log_custom_formatting] test_aws_lambda_logging" + assert log_dict["timestamp"] == "fake-datefmt" + + +def test_log_dict_key_strip_nones(stdout): + # GIVEN a logger configuration where we set `location` and `timestamp` to None + # Note: level, sampling_rate and service can not be suppressed + logger = Logger(stream=stdout, level=None, location=None, timestamp=None, sampling_rate=None, service=None) + + # WHEN logging a message + logger.info("foo") + + log_dict: dict = json.loads(stdout.getvalue()) + + # THEN the keys should only include `level`, `message`, `service`, `sampling_rate` + assert sorted(log_dict.keys()) == ["level", "message", "sampling_rate", "service"] + + +def
test_log_dict_xray_is_present_when_tracing_is_enabled(stdout, monkeypatch): + # GIVEN a logger is initialized within a Lambda function with X-Ray enabled + trace_id = "1-5759e988-bd862e3fe1be46a994272793" + trace_header = f"Root={trace_id};Parent=53995c3f42cd8ad8;Sampled=1" + monkeypatch.setenv(name="_X_AMZN_TRACE_ID", value=trace_header) + logger = Logger(stream=stdout) + + # WHEN logging a message + logger.info("foo") + + log_dict: dict = json.loads(stdout.getvalue()) + + # THEN the `xray_trace_id` key should be present + assert log_dict["xray_trace_id"] == trace_id + + monkeypatch.delenv(name="_X_AMZN_TRACE_ID") + + +def test_log_dict_xray_is_not_present_when_tracing_is_disabled(stdout, monkeypatch): + # GIVEN a logger is initialized within a Lambda function with X-Ray disabled (default) + logger = Logger(stream=stdout) + + # WHEN logging a message + logger.info("foo") + + log_dict: dict = json.loads(stdout.getvalue()) + + # THEN the `xray_trace_id` key should not be present + assert "xray_trace_id" not in log_dict + + +def test_log_dict_xray_is_updated_when_tracing_id_changes(stdout, monkeypatch): + # GIVEN a logger is initialized within a Lambda function with X-Ray enabled + trace_id = "1-5759e988-bd862e3fe1be46a994272793" + trace_header = f"Root={trace_id};Parent=53995c3f42cd8ad8;Sampled=1" + monkeypatch.setenv(name="_X_AMZN_TRACE_ID", value=trace_header) + logger = Logger(stream=stdout) + + # WHEN logging a message + logger.info("foo") + + # and Trace ID changes to mimic a new invocation + trace_id_2 = "1-5759e988-bd862e3fe1be46a949393982437" + trace_header_2 = f"Root={trace_id_2};Parent=53995c3f42cd8ad8;Sampled=1" + monkeypatch.setenv(name="_X_AMZN_TRACE_ID", value=trace_header_2) + + logger.info("foo bar") + + log_dict, log_dict_2 = [json.loads(line.strip()) for line in stdout.getvalue().split("\n") if line] + + # THEN the `xray_trace_id` key should be different in both invocations + assert log_dict["xray_trace_id"] == trace_id + assert log_dict_2["xray_trace_id"] == trace_id_2 + + monkeypatch.delenv(name="_X_AMZN_TRACE_ID") diff --git a/tests/functional/test_utilities_batch.py b/tests/functional/test_utilities_batch.py new file mode 100644 index 00000000000..f56a172637a --- /dev/null +++ b/tests/functional/test_utilities_batch.py @@ -0,0 +1,277 @@ +from typing import Callable +from unittest.mock import patch + +import pytest +from botocore.config import Config +from botocore.stub import Stubber + +from aws_lambda_powertools.utilities.batch import PartialSQSProcessor, batch_processor, sqs_batch_processor +from aws_lambda_powertools.utilities.batch.exceptions import SQSBatchProcessingError + + +@pytest.fixture(scope="module") +def sqs_event_factory() -> Callable: + def factory(body: str): + return { + "messageId": "059f36b4-87a3-44ab-83d2-661975830a7d", + "receiptHandle": "AQEBwJnKyrHigUMZj6rYigCgxlaS3SLy0a", + "body": body, + "attributes": {}, + "messageAttributes": {}, + "md5OfBody": "e4e68fb7bd0e697a0ae8f1bb342846b3", + "eventSource": "aws:sqs", + "eventSourceARN": "arn:aws:sqs:us-east-2:123456789012:my-queue", + "awsRegion": "us-east-1", + } + + return factory + + +@pytest.fixture(scope="module") +def record_handler() -> Callable: + def handler(record): + body = record["body"] + if "fail" in body: + raise Exception("Failed to process record.") + return body + + return handler + + +@pytest.fixture(scope="module") +def config() -> Config: + return Config(region_name="us-east-1") + + +@pytest.fixture(scope="function") +def partial_processor(config) -> PartialSQSProcessor: + return
PartialSQSProcessor(config=config) + + +@pytest.fixture(scope="function") +def partial_processor_suppressed(config) -> PartialSQSProcessor: + return PartialSQSProcessor(config=config, suppress_exception=True) + + +@pytest.fixture(scope="function") +def stubbed_partial_processor(config) -> PartialSQSProcessor: + processor = PartialSQSProcessor(config=config) + with Stubber(processor.client) as stubber: + yield stubber, processor + + +@pytest.fixture(scope="function") +def stubbed_partial_processor_suppressed(config) -> PartialSQSProcessor: + processor = PartialSQSProcessor(config=config, suppress_exception=True) + with Stubber(processor.client) as stubber: + yield stubber, processor + + +def test_partial_sqs_processor_context_with_failure(sqs_event_factory, record_handler, partial_processor): + """ + Test processor with one failing record + """ + fail_record = sqs_event_factory("fail") + success_record = sqs_event_factory("success") + + records = [fail_record, success_record] + + response = {"Successful": [{"Id": fail_record["messageId"]}], "Failed": []} + + with Stubber(partial_processor.client) as stubber: + stubber.add_response("delete_message_batch", response) + + with pytest.raises(SQSBatchProcessingError) as error: + with partial_processor(records, record_handler) as ctx: + ctx.process() + + assert len(error.value.args[0]) == 1 + stubber.assert_no_pending_responses() + + +def test_partial_sqs_processor_context_only_success(sqs_event_factory, record_handler, partial_processor): + """ + Test processor without failure + """ + first_record = sqs_event_factory("success") + second_record = sqs_event_factory("success") + + records = [first_record, second_record] + + with partial_processor(records, record_handler) as ctx: + result = ctx.process() + + assert result == [ + ("success", first_record["body"], first_record), + ("success", second_record["body"], second_record), + ] + + +def test_partial_sqs_processor_context_multiple_calls(sqs_event_factory, record_handler, partial_processor): + """ + Test processor without failure + """ + first_record = sqs_event_factory("success") + second_record = sqs_event_factory("success") + + records = [first_record, second_record] + + with partial_processor(records, record_handler) as ctx: + ctx.process() + + with partial_processor([first_record], record_handler) as ctx: + ctx.process() + + assert partial_processor.success_messages == [first_record] + + +def test_batch_processor_middleware_with_partial_sqs_processor(sqs_event_factory, record_handler, partial_processor): + """ + Test middleware's integration with PartialSQSProcessor + """ + + @batch_processor(record_handler=record_handler, processor=partial_processor) + def lambda_handler(event, context): + return True + + fail_record = sqs_event_factory("fail") + + event = {"Records": [sqs_event_factory("fail"), sqs_event_factory("fail"), sqs_event_factory("success")]} + response = {"Successful": [{"Id": fail_record["messageId"]}], "Failed": []} + + with Stubber(partial_processor.client) as stubber: + stubber.add_response("delete_message_batch", response) + with pytest.raises(SQSBatchProcessingError) as error: + lambda_handler(event, {}) + + assert len(error.value.args[0]) == 2 + stubber.assert_no_pending_responses() + + +@patch("aws_lambda_powertools.utilities.batch.sqs.PartialSQSProcessor") +def test_sqs_batch_processor_middleware( + patched_sqs_processor, sqs_event_factory, record_handler, stubbed_partial_processor +): + """ + Test middleware's integration with PartialSQSProcessor + """ + + 
@sqs_batch_processor(record_handler=record_handler) + def lambda_handler(event, context): + return True + + stubber, processor = stubbed_partial_processor + patched_sqs_processor.return_value = processor + + fail_record = sqs_event_factory("fail") + + event = {"Records": [sqs_event_factory("fail"), sqs_event_factory("success")]} + response = {"Successful": [{"Id": fail_record["messageId"]}], "Failed": []} + stubber.add_response("delete_message_batch", response) + with pytest.raises(SQSBatchProcessingError) as error: + lambda_handler(event, {}) + + assert len(error.value.args[0]) == 1 + stubber.assert_no_pending_responses() + + +def test_batch_processor_middleware_with_custom_processor(capsys, sqs_event_factory, record_handler, config): + """ + Test middlewares' integration with custom batch processor + """ + + class CustomProcessor(PartialSQSProcessor): + def failure_handler(self, record, exception): + print("Oh no ! It's a failure.") + return super().failure_handler(record, exception) + + processor = CustomProcessor(config=config) + + @batch_processor(record_handler=record_handler, processor=processor) + def lambda_handler(event, context): + return True + + fail_record = sqs_event_factory("fail") + + event = {"Records": [sqs_event_factory("fail"), sqs_event_factory("success")]} + response = {"Successful": [{"Id": fail_record["messageId"]}], "Failed": []} + + with Stubber(processor.client) as stubber: + stubber.add_response("delete_message_batch", response) + with pytest.raises(SQSBatchProcessingError) as error: + lambda_handler(event, {}) + + stubber.assert_no_pending_responses() + + assert len(error.value.args[0]) == 1 + assert capsys.readouterr().out == "Oh no ! It's a failure.\n" + + +def test_batch_processor_middleware_suppressed_exceptions( + sqs_event_factory, record_handler, partial_processor_suppressed +): + """ + Test middleware's integration with PartialSQSProcessor + """ + + @batch_processor(record_handler=record_handler, processor=partial_processor_suppressed) + def lambda_handler(event, context): + return True + + fail_record = sqs_event_factory("fail") + + event = {"Records": [sqs_event_factory("fail"), sqs_event_factory("fail"), sqs_event_factory("success")]} + response = {"Successful": [{"Id": fail_record["messageId"]}], "Failed": []} + + with Stubber(partial_processor_suppressed.client) as stubber: + stubber.add_response("delete_message_batch", response) + result = lambda_handler(event, {}) + + stubber.assert_no_pending_responses() + assert result is True + + +def test_partial_sqs_processor_suppressed_exceptions(sqs_event_factory, record_handler, partial_processor_suppressed): + """ + Test processor without failure + """ + + first_record = sqs_event_factory("success") + second_record = sqs_event_factory("fail") + records = [first_record, second_record] + + fail_record = sqs_event_factory("fail") + response = {"Successful": [{"Id": fail_record["messageId"]}], "Failed": []} + + with Stubber(partial_processor_suppressed.client) as stubber: + stubber.add_response("delete_message_batch", response) + with partial_processor_suppressed(records, record_handler) as ctx: + ctx.process() + + assert partial_processor_suppressed.success_messages == [first_record] + + +@patch("aws_lambda_powertools.utilities.batch.sqs.PartialSQSProcessor") +def test_sqs_batch_processor_middleware_suppressed_exception( + patched_sqs_processor, sqs_event_factory, record_handler, stubbed_partial_processor_suppressed +): + """ + Test middleware's integration with PartialSQSProcessor + """ + + 
@sqs_batch_processor(record_handler=record_handler) + def lambda_handler(event, context): + return True + + stubber, processor = stubbed_partial_processor_suppressed + patched_sqs_processor.return_value = processor + + fail_record = sqs_event_factory("fail") + + event = {"Records": [sqs_event_factory("fail"), sqs_event_factory("success")]} + response = {"Successful": [{"Id": fail_record["messageId"]}], "Failed": []} + stubber.add_response("delete_message_batch", response) + result = lambda_handler(event, {}) + + stubber.assert_no_pending_responses() + assert result is True diff --git a/tests/functional/test_utilities_parameters.py b/tests/functional/test_utilities_parameters.py index abd121540a6..55f643924ad 100644 --- a/tests/functional/test_utilities_parameters.py +++ b/tests/functional/test_utilities_parameters.py @@ -233,6 +233,48 @@ def test_dynamodb_provider_get_multiple(mock_name, mock_value, config): stubber.deactivate() +def test_dynamodb_provider_get_multiple_auto(mock_name, mock_value, config): + """ + Test DynamoDBProvider.get_multiple() with transform = "auto" + """ + mock_binary = mock_value.encode() + mock_binary_data = base64.b64encode(mock_binary).decode() + mock_json_data = json.dumps({mock_name: mock_value}) + mock_params = {"D.json": mock_json_data, "E.binary": mock_binary_data, "F": mock_value} + table_name = "TEST_TABLE_AUTO" + + # Create a new provider + provider = parameters.DynamoDBProvider(table_name, config=config) + + # Stub the boto3 client + stubber = stub.Stubber(provider.table.meta.client) + response = { + "Items": [ + {"id": {"S": mock_name}, "sk": {"S": name}, "value": {"S": value}} for (name, value) in mock_params.items() + ] + } + expected_params = {"TableName": table_name, "KeyConditionExpression": Key("id").eq(mock_name)} + stubber.add_response("query", response, expected_params) + stubber.activate() + + try: + values = provider.get_multiple(mock_name, transform="auto") + + stubber.assert_no_pending_responses() + + assert len(values) == len(mock_params) + for key in mock_params.keys(): + assert key in values + if key.endswith(".json"): + assert values[key][mock_name] == mock_value + elif key.endswith(".binary"): + assert values[key] == mock_binary + else: + assert values[key] == mock_value + finally: + stubber.deactivate() + + def test_dynamodb_provider_get_multiple_next_token(mock_name, mock_value, config): """ Test DynamoDBProvider.get_multiple() with a non-cached path @@ -1481,3 +1523,34 @@ def test_transform_value_ignore_error(mock_value): value = parameters.base.transform_value(mock_value, "INCORRECT", raise_on_transform_error=False) assert value is None + + +@pytest.mark.parametrize("original_transform", ["json", "binary", "other", "Auto", None]) +def test_get_transform_method_preserve_original(original_transform): + """ + Check if original transform method is returned for anything other than "auto" + """ + transform = parameters.base.get_transform_method("key", original_transform) + + assert transform == original_transform + + +@pytest.mark.parametrize("extension", ["json", "binary"]) +def test_get_transform_method_preserve_auto(extension, mock_name): + """ + Check if we can auto detect the transform method by the support extensions json / binary + """ + transform = parameters.base.get_transform_method(f"{mock_name}.{extension}", "auto") + + assert transform == extension + + +@pytest.mark.parametrize("key", ["json", "binary", "example", "example.jsonp"]) +def test_get_transform_method_preserve_auto_unhandled(key): + """ + Check if any key that 
does not end with a supported extension returns None when + using the transform="auto" + """ + transform = parameters.base.get_transform_method(key, "auto") + + assert transform is None diff --git a/tests/functional/test_utilities_typing.py b/tests/functional/test_utilities_typing.py new file mode 100644 index 00000000000..8522cfcbf99 --- /dev/null +++ b/tests/functional/test_utilities_typing.py @@ -0,0 +1,51 @@ +from aws_lambda_powertools.utilities.typing import LambdaContext +from aws_lambda_powertools.utilities.typing.lambda_client_context import LambdaClientContext +from aws_lambda_powertools.utilities.typing.lambda_client_context_mobile_client import LambdaClientContextMobileClient +from aws_lambda_powertools.utilities.typing.lambda_cognito_identity import LambdaCognitoIdentity + + +def test_typing(): + context = LambdaContext() + context._function_name = "_function_name" + context._function_version = "_function_version" + context._invoked_function_arn = "_invoked_function_arn" + context._memory_limit_in_mb = "_memory_limit_in_mb" + context._aws_request_id = "_aws_request_id" + context._log_group_name = "_log_group_name" + context._log_stream_name = "_log_stream_name" + identity = LambdaCognitoIdentity() + identity._cognito_identity_id = "_cognito_identity_id" + identity._cognito_identity_pool_id = "_cognito_identity_pool_id" + context._identity = identity + client_context = LambdaClientContext() + client = LambdaClientContextMobileClient() + client._installation_id = "_installation_id" + client._app_title = "_app_title" + client._app_version_name = "_app_version_name" + client._app_version_code = "_app_version_code" + client._app_package_name = "_app_package_name" + client_context._client = client + client_context._custom = {} + client_context._env = {} + context._client_context = client_context + + assert context.function_name == context._function_name + assert context.function_version == context._function_version + assert context.invoked_function_arn == context._invoked_function_arn + assert context.memory_limit_in_mb == context._memory_limit_in_mb + assert context.aws_request_id == context._aws_request_id + assert context.log_group_name == context._log_group_name + assert context.log_stream_name == context._log_stream_name + assert context.identity == context._identity + assert context.identity.cognito_identity_id == identity._cognito_identity_id + assert context.identity.cognito_identity_pool_id == identity._cognito_identity_pool_id + assert context.client_context == context._client_context + assert context.client_context.client == client_context._client + assert context.client_context.client.installation_id == client._installation_id + assert context.client_context.client.app_title == client._app_title + assert context.client_context.client.app_version_name == client._app_version_name + assert context.client_context.client.app_version_code == client._app_version_code + assert context.client_context.client.app_package_name == client._app_package_name + assert context.client_context.custom == client_context._custom + assert context.client_context.env == client_context._env + assert context.get_remaining_time_in_millis() == 0 diff --git a/tests/unit/test_utilities_batch.py b/tests/unit/test_utilities_batch.py new file mode 100644 index 00000000000..136e6ff2e8c --- /dev/null +++ b/tests/unit/test_utilities_batch.py @@ -0,0 +1,135 @@ +import pytest +from botocore.config import Config + +from aws_lambda_powertools.utilities.batch import PartialSQSProcessor +from 
aws_lambda_powertools.utilities.batch.exceptions import SQSBatchProcessingError + + +@pytest.fixture(scope="function") +def sqs_event(): + return { + "messageId": "059f36b4-87a3-44ab-83d2-661975830a7d", + "receiptHandle": "AQEBwJnKyrHigUMZj6rYigCgxlaS3SLy0a", + "body": "", + "attributes": {}, + "messageAttributes": {}, + "md5OfBody": "e4e68fb7bd0e697a0ae8f1bb342846b3", + "eventSource": "aws:sqs", + "eventSourceARN": "arn:aws:sqs:us-east-2:123456789012:my-queue", + "awsRegion": "us-east-1", + } + + +@pytest.fixture(scope="module") +def config() -> Config: + return Config(region_name="us-east-1") + + +@pytest.fixture(scope="function") +def partial_sqs_processor(config) -> PartialSQSProcessor: + return PartialSQSProcessor(config=config) + + +def test_partial_sqs_get_queue_url_with_records(mocker, sqs_event, partial_sqs_processor): + expected_url = "https://queue.amazonaws.com/123456789012/my-queue" + + records_mock = mocker.patch.object(PartialSQSProcessor, "records", create=True, new_callable=mocker.PropertyMock) + records_mock.return_value = [sqs_event] + + result = partial_sqs_processor._get_queue_url() + assert result == expected_url + + +def test_partial_sqs_get_queue_url_without_records(partial_sqs_processor): + assert partial_sqs_processor._get_queue_url() is None + + +def test_partial_sqs_get_entries_to_clean_with_success(mocker, sqs_event, partial_sqs_processor): + expected_entries = [{"Id": sqs_event["messageId"], "ReceiptHandle": sqs_event["receiptHandle"]}] + + success_messages_mock = mocker.patch.object( + PartialSQSProcessor, "success_messages", create=True, new_callable=mocker.PropertyMock + ) + success_messages_mock.return_value = [sqs_event] + + result = partial_sqs_processor._get_entries_to_clean() + + assert result == expected_entries + + +def test_partial_sqs_get_entries_to_clean_without_success(mocker, partial_sqs_processor): + expected_entries = [] + + success_messages_mock = mocker.patch.object( + PartialSQSProcessor, "success_messages", create=True, new_callable=mocker.PropertyMock + ) + success_messages_mock.return_value = [] + + result = partial_sqs_processor._get_entries_to_clean() + + assert result == expected_entries + + +def test_partial_sqs_process_record_success(mocker, partial_sqs_processor): + expected_value = mocker.sentinel.expected_value + + success_result = mocker.sentinel.success_result + record = mocker.sentinel.record + + handler_mock = mocker.patch.object(PartialSQSProcessor, "handler", create=True, return_value=success_result) + success_handler_mock = mocker.patch.object(PartialSQSProcessor, "success_handler", return_value=expected_value) + + result = partial_sqs_processor._process_record(record) + + handler_mock.assert_called_once_with(record) + success_handler_mock.assert_called_once_with(record, success_result) + + assert result == expected_value + + +def test_partial_sqs_process_record_failure(mocker, partial_sqs_processor): + expected_value = mocker.sentinel.expected_value + + failure_result = Exception() + record = mocker.sentinel.record + + handler_mock = mocker.patch.object(PartialSQSProcessor, "handler", create=True, side_effect=failure_result) + failure_handler_mock = mocker.patch.object(PartialSQSProcessor, "failure_handler", return_value=expected_value) + + result = partial_sqs_processor._process_record(record) + + handler_mock.assert_called_once_with(record) + failure_handler_mock.assert_called_once_with(record, failure_result) + + assert result == expected_value + + +def test_partial_sqs_prepare(mocker, partial_sqs_processor): + 
success_messages_mock = mocker.patch.object(partial_sqs_processor, "success_messages", spec=list) + failed_messages_mock = mocker.patch.object(partial_sqs_processor, "fail_messages", spec=list) + + partial_sqs_processor._prepare() + + success_messages_mock.clear.assert_called_once() + failed_messages_mock.clear.assert_called_once() + + +def test_partial_sqs_clean(monkeypatch, mocker, partial_sqs_processor): + records = [mocker.sentinel.record] + + monkeypatch.setattr(partial_sqs_processor, "fail_messages", records) + monkeypatch.setattr(partial_sqs_processor, "success_messages", records) + + queue_url_mock = mocker.patch.object(PartialSQSProcessor, "_get_queue_url") + entries_to_clean_mock = mocker.patch.object(PartialSQSProcessor, "_get_entries_to_clean") + + queue_url_mock.return_value = mocker.sentinel.queue_url + entries_to_clean_mock.return_value = mocker.sentinel.entries_to_clean + + client_mock = mocker.patch.object(partial_sqs_processor, "client", autospec=True) + with pytest.raises(SQSBatchProcessingError): + partial_sqs_processor._clean() + + client_mock.delete_message_batch.assert_called_once_with( + QueueUrl=mocker.sentinel.queue_url, Entries=mocker.sentinel.entries_to_clean + )
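
For reference, the new SQS batch processing utility exercised by the functional and unit tests above is used roughly as follows. This is a minimal sketch inferred only from the decorator, record-handler signature, and error behaviour shown in those tests; the handler bodies are illustrative placeholders, not part of this change.

```python
from aws_lambda_powertools.utilities.batch import sqs_batch_processor


def record_handler(record):
    # Called once per SQS record; raising an exception marks only this record as failed
    return record["body"]


@sqs_batch_processor(record_handler=record_handler)
def lambda_handler(event, context):
    # Records processed successfully are deleted from the queue via delete_message_batch;
    # if any record failed, SQSBatchProcessingError is raised after the whole batch has run
    return {"statusCode": 200}
```

As the suppressed-exception tests show, constructing `PartialSQSProcessor(suppress_exception=True)` and wiring it in with the `batch_processor` decorator makes the handler return normally instead of raising when some records fail.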