From 3603b0ef584d4c1eeb2dee12fa6656494a721b4e Mon Sep 17 00:00:00 2001 From: Markus Gerstel Date: Fri, 12 Nov 2021 16:39:05 +0000 Subject: [PATCH 1/9] Add a subscribe_temporary() function to common transport API --- src/workflows/transport/common_transport.py | 51 +++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/src/workflows/transport/common_transport.py b/src/workflows/transport/common_transport.py index d7fd6e7a..42fce5e0 100644 --- a/src/workflows/transport/common_transport.py +++ b/src/workflows/transport/common_transport.py @@ -8,6 +8,13 @@ MessageCallback = Callable[[Mapping[str, Any], Any], None] +from typing import NamedTuple + + +class TemporarySubscription(NamedTuple): + subscription_id: int + queue_name: str + class CommonTransport: """A common transport class, containing e.g. the logic to manage @@ -78,6 +85,50 @@ def mangled_callback(header, message): self._subscribe(self.__subscription_id, channel, mangled_callback, **kwargs) return self.__subscription_id + def subscribe_temporary( + self, channel_hint: Optional[str], callback: MessageCallback, **kwargs + ) -> TemporarySubscription: + """Listen to a new queue that is specifically created for this connection, + and has a limited lifetime. Notify for messages via callback function. + :param channel_hint: Suggested queue name to subscribe to, the actual + queue name will be decided by both transport layer + and server. + :param callback: Function to be called when messages are received. + The callback will pass two arguments, the header as a + dictionary structure, and the message. + :param **kwargs: Further parameters for the transport layer. For example + disable_mangling: Receive messages as unprocessed strings. + acknowledgement: If true receipt of each message needs to be + acknowledged. + :return: A named tuple containing a unique subscription ID and the actual + queue name which can then be referenced by other senders. + """ + self.__subscription_id += 1 + + def mangled_callback(header: Mapping[str, Any], message: Any, /) -> None: + callback(header, self._mangle_for_receiving(message)) + + if "disable_mangling" in kwargs: + if kwargs["disable_mangling"]: + mangled_callback = callback # noqa:F811 + del kwargs["disable_mangling"] + self.__subscriptions[self.__subscription_id] = { + # "channel": channel, + "callback": mangled_callback, + "ack": kwargs.get("acknowledgement"), + "unsubscribed": False, + } + self.log.debug( + "Subscribing to temporary queue (name hint: %r) with ID %d", + channel_hint, + self.__subscription_id, + ) + # self._subscribe(self.__subscription_id, channel, mangled_callback, **kwargs) + + return TemporarySubscription( + subscription_id=self.__subscription_id, queue_name="" + ) + def unsubscribe(self, subscription: int, drop_callback_reference=False, **kwargs): """Stop listening to a queue or a broadcast :param subscription: Subscription ID to cancel From 2832950d08b7f6a74934d0534bde499d53ff8b2b Mon Sep 17 00:00:00 2001 From: Markus Gerstel Date: Mon, 15 Nov 2021 08:20:31 +0000 Subject: [PATCH 2/9] Add transport interface API call --- src/workflows/transport/common_transport.py | 24 +++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/src/workflows/transport/common_transport.py b/src/workflows/transport/common_transport.py index 42fce5e0..15e4d453 100644 --- a/src/workflows/transport/common_transport.py +++ b/src/workflows/transport/common_transport.py @@ -123,10 +123,12 @@ def mangled_callback(header: Mapping[str, Any], message: Any, /) -> None: channel_hint, self.__subscription_id, ) - # self._subscribe(self.__subscription_id, channel, mangled_callback, **kwargs) + queue_name = self._subscribe_temporary( + self.__subscription_id, channel_hint, mangled_callback, **kwargs + ) return TemporarySubscription( - subscription_id=self.__subscription_id, queue_name="" + subscription_id=self.__subscription_id, queue_name=queue_name ) def unsubscribe(self, subscription: int, drop_callback_reference=False, **kwargs): @@ -413,6 +415,24 @@ def _subscribe_broadcast(self, sub_id: int, channel, callback, **kwargs): """ raise NotImplementedError("Transport interface not implemented") + def _subscribe_temporary( + self, + sub_id: int, + channel_hint: Optional[str], + callback: MessageCallback, + **kwargs, + ) -> str: + """Create and then listen to a temporary queue, notify via callback function. + :param sub_id: ID for this subscription in the transport layer + :param channel_hint: Name suggestion for the temporary queue + :param callback: Function to be called when messages are received + :param **kwargs: Further parameters for the transport layer. For example + acknowledgement: If true receipt of each message needs to be + acknowledged. + :returns: The name of the temporary queue + """ + raise NotImplementedError("Transport interface not implemented") + def _unsubscribe(self, sub_id: int, **kwargs): """Stop listening to a queue or a broadcast :param sub_id: ID for this subscription in the transport layer From 2f06fcbd9b3210578e170fd8d0638b3990f15ede Mon Sep 17 00:00:00 2001 From: Markus Gerstel Date: Mon, 15 Nov 2021 11:08:15 +0000 Subject: [PATCH 3/9] Pass with Python 3.7-compatible syntax --- src/workflows/transport/common_transport.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/workflows/transport/common_transport.py b/src/workflows/transport/common_transport.py index 15e4d453..a1d85be2 100644 --- a/src/workflows/transport/common_transport.py +++ b/src/workflows/transport/common_transport.py @@ -105,9 +105,11 @@ def subscribe_temporary( """ self.__subscription_id += 1 - def mangled_callback(header: Mapping[str, Any], message: Any, /) -> None: + def _(header: Mapping[str, Any], message: Any) -> None: callback(header, self._mangle_for_receiving(message)) + mangled_callback: MessageCallback = _ + if "disable_mangling" in kwargs: if kwargs["disable_mangling"]: mangled_callback = callback # noqa:F811 From 223cc200df941fd82ec8aaa93d1dc614437779f4 Mon Sep 17 00:00:00 2001 From: Markus Gerstel Date: Mon, 15 Nov 2021 11:13:30 +0000 Subject: [PATCH 4/9] Tidy up imports --- src/workflows/transport/common_transport.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/workflows/transport/common_transport.py b/src/workflows/transport/common_transport.py index a1d85be2..120c3665 100644 --- a/src/workflows/transport/common_transport.py +++ b/src/workflows/transport/common_transport.py @@ -2,14 +2,12 @@ import decimal import logging -from typing import Any, Callable, Dict, Mapping, Optional, Set +from typing import Any, Callable, Dict, Mapping, NamedTuple, Optional, Set import workflows MessageCallback = Callable[[Mapping[str, Any], Any], None] -from typing import NamedTuple - class TemporarySubscription(NamedTuple): subscription_id: int From 5a37ad74c72d6efc770778bea0ef6ad6410c2851 Mon Sep 17 00:00:00 2001 From: Markus Gerstel Date: Tue, 16 Nov 2021 08:22:03 +0000 Subject: [PATCH 5/9] Implement StompTransport._subscribe_temporary --- src/workflows/transport/stomp_transport.py | 36 ++++++++++++++++++++-- src/workflows/util/__init__.py | 2 +- 2 files changed, 34 insertions(+), 4 deletions(-) diff --git a/src/workflows/transport/stomp_transport.py b/src/workflows/transport/stomp_transport.py index c484fb60..aa9a9973 100644 --- a/src/workflows/transport/stomp_transport.py +++ b/src/workflows/transport/stomp_transport.py @@ -2,12 +2,17 @@ import json import threading import time -from typing import Any, Dict +import uuid +from typing import Any, Dict, Optional import stomp -import workflows -from workflows.transport.common_transport import CommonTransport, json_serializer +import workflows.util +from workflows.transport.common_transport import ( + CommonTransport, + MessageCallback, + json_serializer, +) class StompTransport(CommonTransport): @@ -334,6 +339,31 @@ def _subscribe_broadcast(self, sub_id, channel, callback, **kwargs): headers["activemq.retroactive"] = "true" self._conn.subscribe(destination, sub_id, headers=headers) + def _subscribe_temporary( + self, + sub_id: int, + channel_hint: Optional[str], + callback: MessageCallback, + **kwargs, + ) -> str: + """Create and then listen to a temporary queue, notify via callback function. + :param sub_id: ID for this subscription in the transport layer + :param channel_hint: Name suggestion for the temporary queue + :param callback: Function to be called when messages are received + :param **kwargs: Further parameters for the transport layer. + See _subscribe() above. + :returns: The name of the temporary queue + """ + + channel = channel_hint or workflows.util.generate_unique_host_id() + channel = channel + "." + str(uuid.uuid4()) + if not channel.startswith("transient."): + channel = "transient." + channel + + self._subscribe(sub_id, channel, callback, **kwargs) + + return channel + def _unsubscribe(self, subscription, **kwargs): """Stop listening to a queue or a broadcast :param subscription: Subscription ID to cancel diff --git a/src/workflows/util/__init__.py b/src/workflows/util/__init__.py index c0593d53..709025d0 100644 --- a/src/workflows/util/__init__.py +++ b/src/workflows/util/__init__.py @@ -2,7 +2,7 @@ import socket -def generate_unique_host_id(): +def generate_unique_host_id() -> str: """Generate a unique ID, that is somewhat guaranteed to be unique among all instances running at the same time.""" host = ".".join(reversed(socket.gethostname().split("."))) From 001fb3d56c13589f50db8335ba5e65b456f051f7 Mon Sep 17 00:00:00 2001 From: Markus Gerstel Date: Tue, 16 Nov 2021 12:39:47 +0000 Subject: [PATCH 6/9] Implement PikaTransport._subscribe_temporary --- src/workflows/transport/pika_transport.py | 107 +++++++++++++++++++++- tests/transport/test_pika.py | 63 ++++++++++++- 2 files changed, 164 insertions(+), 6 deletions(-) diff --git a/src/workflows/transport/pika_transport.py b/src/workflows/transport/pika_transport.py index b073205e..d4e991bf 100644 --- a/src/workflows/transport/pika_transport.py +++ b/src/workflows/transport/pika_transport.py @@ -9,15 +9,15 @@ import sys import threading import time +import uuid from concurrent.futures import Future from enum import Enum, auto from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union -import pika import pika.exceptions from pika.adapters.blocking_connection import BlockingChannel -import workflows +import workflows.util from workflows.transport.common_transport import ( CommonTransport, MessageCallback, @@ -414,8 +414,7 @@ def _subscribe_broadcast( Args: sub_id: Internal ID for this subscription channel: Name of the exchange to bind to - callback: - Function to be called when message are received + callback: Function to be called when messages are received reconnectable: Can we reconnect to this exchange if the connection is lost. Currently, this means that messages can be missed @@ -429,6 +428,47 @@ def _subscribe_broadcast( reconnectable=reconnectable, ).result() + def _subscribe_temporary( + self, + sub_id: int, + channel_hint: Optional[str], + callback: MessageCallback, + *, + acknowledgement: bool = False, + **kwargs, + ) -> str: + """ + Create and then listen to a temporary queue, notify via callback function. + + Wait until subscription is complete. + + Args: + sub_id: Internal ID for this subscription + channel_hint: Name suggestion for the temporary queue + callback: Function to be called when messages are received + acknowledgement: + Each message will need to be explicitly acknowledged. + Returns: + The name of the temporary queue + """ + queue: str = channel_hint or "" + if queue and not queue.startswith("transient."): + queue = "transient." + queue + queue = queue + "." + str(uuid.uuid4()) + + try: + return self._pika_thread.subscribe_temporary( + queue=queue, + callback=functools.partial(self._call_message_callback, sub_id), + auto_ack=not acknowledgement, + subscription_id=sub_id, + ).result() + except ( + pika.exceptions.AMQPChannelError, + pika.exceptions.AMQPConnectionError, + ) as e: + raise workflows.Disconnected(e) + def _unsubscribe(self, sub_id: int, **kwargs): """Stop listening to a queue :param sub_id: Consumer Tag to cancel @@ -897,6 +937,65 @@ def subscribe_broadcast( ) return result + def subscribe_temporary( + self, + queue: str, + callback: PikaCallback, + subscription_id: int, + *, + auto_ack: bool = True, + ) -> Future[str]: + """ + Create and then subscribe to a temporary queue. Thread-safe. + + Args: + queue: The queue to listen for messages on, may be empty + callback: The function to call when receiving messages on this queue + subscription_id: Internal ID representing this subscription. + auto_ack: Should this subscription auto-acknowledge messages? + + Returns: + A Future representing the resulting queue name. + It will be set upon subscription success. + """ + + if not self._connection: + raise RuntimeError("Cannot subscribe to unstarted connection") + + result: Future[str] = Future() + + def _declare_subscribe_queue_in_thread(): + try: + if result.set_running_or_notify_cancel(): + assert ( + subscription_id not in self._subscriptions + ), f"Subscription request {subscription_id} rejected due to existing subscription {self._subscriptions[subscription_id]}" + temporary_queue_name = ( + self._get_shared_channel() + .queue_declare( + queue, auto_delete=True, exclusive=True, durable=False + ) + .method.queue + ) + temporary_subscription = _PikaSubscription( + arguments={}, + auto_ack=auto_ack, + destination=temporary_queue_name, + kind=_PikaSubscriptionKind.DIRECT, + on_message_callback=callback, + prefetch_count=0, + reconnectable=False, + ) + self._add_subscription(subscription_id, temporary_subscription) + result.set_result(temporary_queue_name) + except BaseException as e: + result.set_exception(e) + raise + + self._connection.add_callback_threadsafe(_declare_subscribe_queue_in_thread) + + return result + def unsubscribe(self, subscription_id: int) -> Future[None]: if subscription_id not in self._subscriptions: raise KeyError( diff --git a/tests/transport/test_pika.py b/tests/transport/test_pika.py index c5cc8076..dfe91850 100644 --- a/tests/transport/test_pika.py +++ b/tests/transport/test_pika.py @@ -12,9 +12,8 @@ import pika import pytest -import workflows -import workflows.transport import workflows.transport.pika_transport +from workflows.transport.common_transport import TemporarySubscription from workflows.transport.pika_transport import PikaTransport, _PikaThread @@ -47,6 +46,14 @@ def revert_classvariables(): PikaTransport.defaults = defaults +@pytest.fixture +def pikatransport(revert_classvariables): + pt = PikaTransport() + pt.connect() + yield pt + pt.disconnect() + + def test_lookup_and_initialize_pika_transport_layer(): """Find the pika transport layer via the lookup mechanism and run its constructor with default settings @@ -1088,3 +1095,55 @@ def _get_message(*args): def test_pikathread_ack(): pytest.xfail("Not Implemented") + + +def test_full_stack_temporary_queue_roundtrip(pikatransport): + known_subscriptions = set() + known_queues = set() + + def assert_not_seen_before(ts: TemporarySubscription): + assert ts.subscription_id, "Temporary subscription is missing an ID" + assert ( + ts.subscription_id not in known_subscriptions + ), "Duplicate subscription ID" + assert ts.queue_name, "Temporary queue does not have a name" + assert ts.queue_name not in known_queues, "Duplicate temporary queue name" + known_subscriptions.add(ts.subscription_id) + known_queues.add(ts.queue_name) + print(f"Temporary subscription: {ts}") + + replies = Queue() + + def callback(subscription): + def _callback(header, message): + print(f"Received message for {subscription}: {message}") + replies.put((subscription, header, message)) + + return _callback + + ts = {} + for n, queue_hint in enumerate( + ("", "", "hint", "hint", "transient.hint", "transient.hint") + ): + ts[n] = pikatransport.subscribe_temporary(queue_hint, callback(n)) + assert_not_seen_before(ts[n]) + assert queue_hint in ts[n].queue_name + assert "transient.transient." not in ts[n].queue_name + + assert replies.empty() + + outstanding_messages = set() + for n in range(6): + outstanding_messages.add((n, f"message {n}")) + pikatransport.send(ts[n].queue_name, f"message {n}") + + try: + while outstanding_messages: + s, _, m = replies.get(timeout=1.5) + if (s, m) not in outstanding_messages: + raise RuntimeError( + f"Received unexpected message {m} on subscription {s}" + ) + outstanding_messages.remove((s, m)) + except Empty: + raise RuntimeError(f"Missing replies for {len(outstanding_messages)} messages") From ce4ca5acef3572c9f06dd2f14241263787ba6db6 Mon Sep 17 00:00:00 2001 From: Markus Gerstel Date: Tue, 16 Nov 2021 14:42:00 +0000 Subject: [PATCH 7/9] Skip full stack tests without RabbitMQ server --- tests/transport/test_pika.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/transport/test_pika.py b/tests/transport/test_pika.py index dfe91850..d99ab1d3 100644 --- a/tests/transport/test_pika.py +++ b/tests/transport/test_pika.py @@ -47,7 +47,10 @@ def revert_classvariables(): @pytest.fixture -def pikatransport(revert_classvariables): +def pikatransport(revert_classvariables, connection_params): + # connection_params is unused here, but implements the fixture skipping + # logic following a single test, instead of attempting a connection for + # every individual test. pt = PikaTransport() pt.connect() yield pt From 9208ca0bc8942b8ce4bdf4c9597f53d8843616f2 Mon Sep 17 00:00:00 2001 From: Markus Gerstel Date: Wed, 17 Nov 2021 10:28:09 +0000 Subject: [PATCH 8/9] Pika: Use server-assigned queue name if no hint --- src/workflows/transport/pika_transport.py | 7 ++++--- tests/transport/test_pika.py | 2 ++ 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/workflows/transport/pika_transport.py b/src/workflows/transport/pika_transport.py index d4e991bf..ba65e74c 100644 --- a/src/workflows/transport/pika_transport.py +++ b/src/workflows/transport/pika_transport.py @@ -452,9 +452,10 @@ def _subscribe_temporary( The name of the temporary queue """ queue: str = channel_hint or "" - if queue and not queue.startswith("transient."): - queue = "transient." + queue - queue = queue + "." + str(uuid.uuid4()) + if queue: + if not queue.startswith("transient."): + queue = "transient." + queue + queue = queue + "." + str(uuid.uuid4()) try: return self._pika_thread.subscribe_temporary( diff --git a/tests/transport/test_pika.py b/tests/transport/test_pika.py index d99ab1d3..61c3480f 100644 --- a/tests/transport/test_pika.py +++ b/tests/transport/test_pika.py @@ -1132,6 +1132,8 @@ def _callback(header, message): assert_not_seen_before(ts[n]) assert queue_hint in ts[n].queue_name assert "transient.transient." not in ts[n].queue_name + if not queue_hint: + assert ts[n].queue_name.startswith("amq.gen-") assert replies.empty() From 1aa1d8100798c9f3488da29f19e9febe5f8b80eb Mon Sep 17 00:00:00 2001 From: Richard Gildea Date: Wed, 17 Nov 2021 12:20:45 +0000 Subject: [PATCH 9/9] StompTransport: test_subscribe_to_temporary_queue --- tests/transport/test_stomp.py | 40 +++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/tests/transport/test_stomp.py b/tests/transport/test_stomp.py index dbdd9d4a..552d84d6 100644 --- a/tests/transport/test_stomp.py +++ b/tests/transport/test_stomp.py @@ -14,6 +14,7 @@ import workflows import workflows.transport +from workflows.transport.common_transport import TemporarySubscription from workflows.transport.stomp_transport import StompTransport _frame = namedtuple("frame", "headers, body") @@ -586,6 +587,45 @@ def callback_resolver(cbid): mockconn.unsubscribe.assert_called_with(id=2) +@mock.patch("workflows.transport.stomp_transport.stomp") +def test_subscribe_to_temporary_queue(mockstomp): + """Test subscribing to a topic (publish-subscribe) and callback functions.""" + mock_cb = mock.Mock() + stomp = StompTransport() + stomp.connect() + mockconn = mockstomp.Connection.return_value + + known_subscriptions = set() + known_queues = set() + + def assert_not_seen_before(ts: TemporarySubscription): + assert ts.subscription_id, "Temporary subscription is missing an ID" + assert ( + ts.subscription_id not in known_subscriptions + ), "Duplicate subscription ID" + assert ts.queue_name, "Temporary queue does not have a name" + assert ts.queue_name not in known_queues, "Duplicate temporary queue name" + known_subscriptions.add(ts.subscription_id) + known_queues.add(ts.queue_name) + print(f"Temporary subscription: {ts}") + + mockconn.set_listener.assert_called_once() + listener = mockconn.set_listener.call_args[0][1] + assert listener is not None + + ts = {} + for n, queue_hint in enumerate( + ("", "", "hint", "hint", "transient.hint", "transient.hint") + ): + ts[n] = stomp.subscribe_temporary( + channel_hint=queue_hint, + callback=mock_cb, + ) + assert_not_seen_before(ts[n]) + assert ts[n].queue_name.startswith("transient.") + return + + @mock.patch("workflows.transport.stomp_transport.stomp") def test_transaction_calls(mockstomp): """Test that calls to create, commit, abort transactions are passed to stomp properly."""