From a3f97e24e6f06fff2d1ca20ba627ca6c66854607 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=D0=90=D0=BD=D0=BD=D0=B0=20=D0=A9=D0=B5=D1=80=D0=B5=D0=BD?= =?UTF-8?q?=D0=BA=D0=BE?= Date: Wed, 18 Jun 2025 11:20:20 +0300 Subject: [PATCH 1/2] Add OAuth Token Manager class --- shapeshifter_uftp/client/base_client.py | 17 ++- shapeshifter_uftp/service/base_service.py | 25 ++++- shapeshifter_uftp/token_manager.py | 122 ++++++++++++++++++++++ test/helpers/services.py | 2 +- 4 files changed, 162 insertions(+), 4 deletions(-) create mode 100644 shapeshifter_uftp/token_manager.py diff --git a/shapeshifter_uftp/client/base_client.py b/shapeshifter_uftp/client/base_client.py index a09c367..31a7542 100644 --- a/shapeshifter_uftp/client/base_client.py +++ b/shapeshifter_uftp/client/base_client.py @@ -11,6 +11,7 @@ from ..exceptions import ClientTransportException from ..logging import logger from ..uftp import PayloadMessage, PayloadMessageResponse, SignedMessage +from shapeshifter_uftp.token_manager import AuthTokenManager class ShapeshifterClient: @@ -34,6 +35,7 @@ def __init__( recipient_domain: str, recipient_endpoint: str = None, recipient_signing_key: str = None, + oauth_token_manager: AuthTokenManager = None, ): """ Shapeshifter client class that allows you to initiate messages to a different party. @@ -55,6 +57,7 @@ def __init__( self.recipient_domain = recipient_domain self.recipient_endpoint = recipient_endpoint self.recipient_signing_key = recipient_signing_key + self.oauth_token_manager = oauth_token_manager # The outgoing queue and scheduler are used when queueing # messages for delivery later. This allows the Shapeshifter @@ -114,11 +117,23 @@ def _send_message(self, message: PayloadMessage) -> PayloadMessageResponse: logger.debug(f"Sending message to {self.recipient_endpoint}:") logger.debug(serialized_message) + # Find the right headers to use for the request. If we have + # an OAuth2 token manager, we will use that to get the + # request headers. If not, we will use the basic Content-Type + try: + if self.oauth_token_manager: + header = self.oauth_token_manager.get_request_headers() + else: + header = {"Content-Type": "text/xml; charset=utf-8"} + except Exception as e: + logger.warning(f"Failed to get OAuth2 headers, falling back to basic headers: {e}") + header = {"Content-Type": "text/xml; charset=utf-8"} + # Send the request to the relevant endpoint response = requests.post( self.recipient_endpoint, data=serialized_message, - headers={"Content-Type": "text/xml; charset=utf-8"}, + headers=header, timeout=self.request_timeout, ) if response.status_code != 200: diff --git a/shapeshifter_uftp/service/base_service.py b/shapeshifter_uftp/service/base_service.py index 30c3d01..d6e7634 100644 --- a/shapeshifter_uftp/service/base_service.py +++ b/shapeshifter_uftp/service/base_service.py @@ -25,6 +25,7 @@ PayloadMessageResponse, SignedMessage, ) +from ..token_manager import AuthTokenManager class ShapeshifterService(): @@ -46,11 +47,15 @@ def __init__( self, sender_domain, signing_key, + oauth_token_endpoint: str = None, + oauth_client_id: str = None, + oauth_client_secret: str = None, + token_refresh_buffer: int = 30, key_lookup_function=None, endpoint_lookup_function=None, host: str = "0.0.0.0", port: int = 8080, - path="/shapeshifter/api/v3/message", + path="/shapeshifter/api/v3/message" ): """ :param sender_domain: our sender domain (FQDN) that the recipient uses to look us up. @@ -64,6 +69,9 @@ def __init__( :param host: the host to bind the server to (usually 127.0.0.1 or 0.0.0.0) :param port: the port to bind the server to (default: 8080) :param path: the URL path that the server listens on (default: /shapeshifter/api/v3/message) + :param oauth_token_endpoint: the OAuth2 token endpoint to use for obtaining access tokens + :param oauth_client_id: the OAuth2 client ID to use for obtaining access tokens + :param oauth_client_secret: the OAuth2 client secret to use for obtaining access tokens """ # Set the sender domain, which is used @@ -87,6 +95,18 @@ def __init__( # The FastAPI web app takes care of routing messages to the # (one) endpoint, and by virtue of FastAPI-XML convert the # python-friendly objects into XML and vice versa. + + # Create Auth Manager for OAuth2 Client Credentials flow (if configured) + if oauth_token_endpoint and oauth_client_id and oauth_client_secret: + self.auth_token_manager = AuthTokenManager( + oauth_token_endpoint=oauth_token_endpoint, + oauth_client_id=oauth_client_id, + oauth_client_secret=oauth_client_secret, + token_refresh_buffer=token_refresh_buffer + ) + else: + self.auth_token_manager = None + self.app = FastAPI(default_response_class=XmlAppResponse) self.app.router.route_class = XmlRoute self.app.router.add_api_route( @@ -249,7 +269,8 @@ def _get_client(self, recipient_domain, recipient_role): signing_key = self.signing_key, recipient_domain = recipient_domain, recipient_endpoint = recipient_endpoint, - recipient_signing_key = recipient_signing_key + recipient_signing_key = recipient_signing_key, + oauth_token_manager = self.auth_token_manager ) def __enter__(self): diff --git a/shapeshifter_uftp/token_manager.py b/shapeshifter_uftp/token_manager.py new file mode 100644 index 0000000..04e4427 --- /dev/null +++ b/shapeshifter_uftp/token_manager.py @@ -0,0 +1,122 @@ +from datetime import datetime, timezone, timedelta + +import requests + +from .logging import logger + +from typing import Optional +from threading import Lock + +class AuthTokenManager: + """ + A token manager that can be used to manage tokens for the Shapeshifter client. + It handles OAuth2 Client Credentials flow to obtain and refresh tokens. + This class is thread-safe and ensures that tokens are refreshed only when necessary. + It provides a method to get request headers with the Bearer token included. + """ + request_timeout: int = 30 + + def __init__(self, + oauth_token_endpoint: str, + oauth_client_id: str, + oauth_client_secret: str, + token_refresh_buffer: int = 30): + self.oauth_token_endpoint = oauth_token_endpoint + self.oauth_client_id = oauth_client_id + self.oauth_client_secret = oauth_client_secret + self.token_refresh_buffer = token_refresh_buffer + self._access_token: Optional[str] = None + self._token_expires_at: Optional[datetime] = None + self._token_lock = Lock() + + def _is_oauth_configured(self) -> bool: + """Check if OAuth2 is properly configured.""" + return all([ + self.oauth_token_endpoint, + self.oauth_client_id, + self.oauth_client_secret + ]) + + def _is_token_valid(self) -> bool: + """Check if the current token is valid and not close to expiring.""" + if not self._access_token or not self._token_expires_at: + return False + + buffer_time = datetime.now(timezone.utc) + timedelta(seconds=self.token_refresh_buffer) + return self._token_expires_at > buffer_time + + def _obtain_bearer_token(self) -> str: + """ + Obtain a Bearer token using OAuth2 Client Credentials flow. + + :return: Access token string + :raises: Exception if token acquisition fails + """ + if not self._is_oauth_configured(): + raise ValueError("OAuth2 not configured. Please provide oauth_token_endpoint, oauth_client_id, and oauth_client_secret.") + + token_data = { + 'grant_type': 'client_credentials', + 'client_id': self.oauth_client_id, + 'client_secret': self.oauth_client_secret + } + + headers = { + 'Content-Type': 'application/x-www-form-urlencoded' + } + + try: + response = requests.post( + self.oauth_token_endpoint, + data=token_data, + headers=headers, + timeout=self.request_timeout + ) + response.raise_for_status() + + token_response = response.json() + access_token = token_response['access_token'] + expires_in = token_response.get('expires_in', 300) + + self._token_expires_at = datetime.now(timezone.utc) + timedelta(seconds=expires_in) + + logger.info(f"Successfully obtained OAuth2 token, expires at {self._token_expires_at}") + return access_token + + except requests.exceptions.RequestException as e: + logger.error(f"Failed to obtain OAuth2 token: {e}") + raise + except KeyError as e: + logger.error(f"Invalid token response format, missing key: {e}") + raise + + def _get_valid_token(self) -> Optional[str]: + """ + Get a valid Bearer token, refreshing if necessary. + Thread-safe implementation. + + :return: Valid access token or None if OAuth2 not configured + """ + if not self._is_oauth_configured(): + return None + + with self._token_lock: + if not self._is_token_valid(): + logger.debug("Token invalid or expired, obtaining new token") + self._access_token = self._obtain_bearer_token() + + return self._access_token + + def get_request_headers(self) -> dict: + """ + Get headers for HTTP requests, including Bearer token if configured. + + :return: Dictionary of headers + """ + headers = {"Content-Type": "text/xml; charset=utf-8"} + + token = self._get_valid_token() + if token: + headers["Authorization"] = f"Bearer {token}" + + return headers \ No newline at end of file diff --git a/test/helpers/services.py b/test/helpers/services.py index dea0598..72a7929 100644 --- a/test/helpers/services.py +++ b/test/helpers/services.py @@ -36,7 +36,7 @@ def key_lookup_function(domain, role): return CRO_PUBLIC_KEY elif domain == "dso.dev": return DSO_PUBLIC_KEY - + class DummyAgrService(ShapeshifterAgrService): From 812c0486a2abb4dddfd641310a63c73fc55fd1db Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=D0=90=D0=BD=D0=BD=D0=B0=20=D0=A9=D0=B5=D1=80=D0=B5=D0=BD?= =?UTF-8?q?=D0=BA=D0=BE?= Date: Wed, 18 Jun 2025 11:24:35 +0300 Subject: [PATCH 2/2] Add unit tests for auth --- test/test_oauth_token_unit.py | 472 ++++++++++++++++++++++++++++++++++ 1 file changed, 472 insertions(+) create mode 100644 test/test_oauth_token_unit.py diff --git a/test/test_oauth_token_unit.py b/test/test_oauth_token_unit.py new file mode 100644 index 0000000..41f6d43 --- /dev/null +++ b/test/test_oauth_token_unit.py @@ -0,0 +1,472 @@ +import pytest +import requests +from unittest.mock import Mock, patch, MagicMock +from datetime import datetime, timezone, timedelta +import json +import time + +from shapeshifter_uftp.token_manager import AuthTokenManager +from shapeshifter_uftp.service.base_service import ShapeshifterService +from shapeshifter_uftp.client.base_client import ShapeshifterClient +from shapeshifter_uftp.uftp import PayloadMessage + + +class TestAuthTokenManager: + """Test suite for AuthTokenManager class""" + + @pytest.fixture + def token_manager(self): + """Fixture for AuthTokenManager instance""" + return AuthTokenManager( + oauth_token_endpoint="https://test.example.com/oauth2/token", + oauth_client_id="test_client_id", + oauth_client_secret="test_client_secret", + token_refresh_buffer=30 + ) + + def test_init(self, token_manager): + """Test AuthTokenManager initialization""" + assert token_manager.oauth_token_endpoint == "https://test.example.com/oauth2/token" + assert token_manager.oauth_client_id == "test_client_id" + assert token_manager.oauth_client_secret == "test_client_secret" + assert token_manager.token_refresh_buffer == 30 + assert token_manager._access_token is None + assert token_manager._token_expires_at is None + + def test_is_oauth_configured(self, token_manager): + """Test OAuth2 configuration check""" + assert token_manager._is_oauth_configured() is True + + # Test with missing configuration + incomplete_manager = AuthTokenManager("", "", "") + assert incomplete_manager._is_oauth_configured() is False + + def test_is_token_valid_no_token(self, token_manager): + """Test token validity when no token exists""" + assert token_manager._is_token_valid() is False + + def test_is_token_valid_expired(self, token_manager): + """Test token validity when token is expired""" + token_manager._access_token = "test_token" + token_manager._token_expires_at = datetime.now(timezone.utc) - timedelta(seconds=60) + assert token_manager._is_token_valid() is False + + def test_is_token_valid_near_expiry(self, token_manager): + """Test token validity when token is near expiry (within buffer)""" + token_manager._access_token = "test_token" + token_manager._token_expires_at = datetime.now(timezone.utc) + timedelta(seconds=20) + assert token_manager._is_token_valid() is False + + def test_is_token_valid_good_token(self, token_manager): + """Test token validity when token is valid""" + token_manager._access_token = "test_token" + token_manager._token_expires_at = datetime.now(timezone.utc) + timedelta(seconds=300) # 5 minutes + assert token_manager._is_token_valid() is True + + @patch('requests.post') + def test_obtain_bearer_token_success(self, mock_post, token_manager): + """Test successful token acquisition""" + mock_response = Mock() + mock_response.raise_for_status.return_value = None + mock_response.json.return_value = { + 'access_token': 'test_access_token_123', + 'expires_in': 300 + } + mock_post.return_value = mock_response + + token = token_manager._obtain_bearer_token() + + assert token == 'test_access_token_123' + assert token_manager._token_expires_at is not None + + mock_post.assert_called_once_with( + "https://test.example.com/oauth2/token", + data={ + 'grant_type': 'client_credentials', + 'client_id': 'test_client_id', + 'client_secret': 'test_client_secret' + }, + headers={'Content-Type': 'application/x-www-form-urlencoded'}, + timeout=30 + ) + + @patch('requests.post') + def test_obtain_bearer_token_http_error(self, mock_post, token_manager): + """Test token acquisition with HTTP error""" + mock_response = Mock() + mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError("401 Unauthorized") + mock_post.return_value = mock_response + + with pytest.raises(requests.exceptions.HTTPError): + token_manager._obtain_bearer_token() + + @patch('requests.post') + def test_obtain_bearer_token_invalid_response(self, mock_post, token_manager): + """Test token acquisition with invalid JSON response""" + mock_response = Mock() + mock_response.raise_for_status.return_value = None + mock_response.json.return_value = {'invalid': 'response'} # Missing access_token + mock_post.return_value = mock_response + + with pytest.raises(KeyError): + token_manager._obtain_bearer_token() + + @patch.object(AuthTokenManager, '_obtain_bearer_token') + def test_get_valid_token_refresh_needed(self, mock_obtain, token_manager): + """Test getting valid token when refresh is needed""" + mock_obtain.return_value = 'new_token_123' + + # Set up expired token + token_manager._access_token = 'old_token' + token_manager._token_expires_at = datetime.now(timezone.utc) - timedelta(seconds=60) + + token = token_manager._get_valid_token() + + assert token == 'new_token_123' + mock_obtain.assert_called_once() + + def test_get_valid_token_no_refresh_needed(self, token_manager): + """Test getting valid token when no refresh is needed""" + # Set up valid token + token_manager._access_token = 'valid_token' + token_manager._token_expires_at = datetime.now(timezone.utc) + timedelta(seconds=300) + + with patch.object(token_manager, '_obtain_bearer_token') as mock_obtain: + token = token_manager._get_valid_token() + + assert token == 'valid_token' + mock_obtain.assert_not_called() + + @patch.object(AuthTokenManager, '_get_valid_token') + def test_get_request_headers_with_token(self, mock_get_token, token_manager): + """Test getting request headers with Bearer token""" + mock_get_token.return_value = 'test_bearer_token' + + headers = token_manager.get_request_headers() + + expected_headers = { + "Content-Type": "text/xml; charset=utf-8", + "Authorization": "Bearer test_bearer_token" + } + assert headers == expected_headers + + @patch.object(AuthTokenManager, '_get_valid_token') + def test_get_request_headers_no_token(self, mock_get_token, token_manager): + """Test getting request headers without Bearer token""" + mock_get_token.return_value = None + + headers = token_manager.get_request_headers() + + expected_headers = {"Content-Type": "text/xml; charset=utf-8"} + assert headers == expected_headers + assert "Authorization" not in headers + + +class TestShapeshifterServiceOAuth: + """Test suite for ShapeshifterService OAuth2 integration""" + + def test_service_with_oauth_config(self): + """Test service initialization with OAuth2 configuration""" + service = ShapeshifterService( + sender_domain="test.example.com", + signing_key="test_signing_key", + oauth_token_endpoint="https://oauth.example.com/token", + oauth_client_id="test_client", + oauth_client_secret="test_secret", + token_refresh_buffer=60 + ) + + assert service.auth_token_manager is not None + assert service.auth_token_manager.oauth_token_endpoint == "https://oauth.example.com/token" + assert service.auth_token_manager.oauth_client_id == "test_client" + assert service.auth_token_manager.oauth_client_secret == "test_secret" + assert service.auth_token_manager.token_refresh_buffer == 60 + + def test_service_without_oauth_config(self): + """Test service initialization without OAuth2 configuration""" + service = ShapeshifterService( + sender_domain="test.example.com", + signing_key="test_signing_key", + oauth_token_endpoint=None, + oauth_client_id=None, + oauth_client_secret=None + ) + + assert service.auth_token_manager is None + + @patch('shapeshifter_uftp.service.base_service.client_map') + @patch('shapeshifter_uftp.service.base_service.transport') + def test_get_client_with_oauth(self, mock_transport, mock_client_map): + """Test _get_client method passes OAuth token manager""" + # Setup mocks + mock_client_class = Mock() + mock_client_map.__getitem__.return_value = mock_client_class + mock_transport.get_endpoint.return_value = "https://recipient.example.com/api" + mock_transport.get_key.return_value = "recipient_public_key" + + # Create service with OAuth2 + service = ShapeshifterService( + sender_domain="test.example.com", + signing_key="test_signing_key", + oauth_token_endpoint="https://oauth.example.com/token", + oauth_client_id="test_client", + oauth_client_secret="test_secret" + ) + service.sender_role = "AGR" # Set sender role for client_map lookup + + # Call _get_client + client = service._get_client("recipient.example.com", "DSO") + + # Verify client was created with oauth_token_manager + mock_client_class.assert_called_once() + call_kwargs = mock_client_class.call_args[1] + assert call_kwargs['oauth_token_manager'] == service.auth_token_manager + + +class TestShapeshifterClientOAuth: + """Test suite for ShapeshifterClient OAuth2 integration""" + + @pytest.fixture + def mock_token_manager(self): + """Fixture for mock token manager""" + return Mock(spec=AuthTokenManager) + + @pytest.fixture + def client_with_oauth(self, mock_token_manager): + """Fixture for client with OAuth2""" + client = ShapeshifterClient( + sender_domain="sender.example.com", + recipient_domain="recipient.example.com", + recipient_endpoint="https://recipient.example.com/api", + signing_key="test_signing_key", + recipient_signing_key="recipient_public_key", + oauth_token_manager=mock_token_manager + ) + client.sender_role = "test_sender" + client.recipient_role = "test_recipient" + return client + + @pytest.fixture + def client_without_oauth(self): + """Fixture for client without OAuth2""" + client = ShapeshifterClient( + sender_domain="sender.example.com", + recipient_domain="recipient.example.com", + recipient_endpoint="https://recipient.example.com/api", + signing_key="test_signing_key", + recipient_signing_key="recipient_public_key" + ) + client.sender_role = "test_sender" + client.recipient_role = "test_recipient" + return client + + @pytest.fixture + def mock_payload_message(self): + """Fixture for mock PayloadMessage""" + mock_message = Mock(spec=PayloadMessage) + # Set required attributes that _send_message expects + mock_message.__class__.__name__ = "TestMessage" + mock_message.version = None + mock_message.sender_domain = None + mock_message.recipient_domain = None + mock_message.time_stamp = None + mock_message.message_id = None + mock_message.conversation_id = None + return mock_message + + @patch('requests.post') + @patch('shapeshifter_uftp.client.base_client.transport') + def test_send_message_with_oauth_success(self, mock_transport, mock_post, client_with_oauth, mock_token_manager, mock_payload_message): + """Test _send_message with successful OAuth2 token""" + # Setup mocks + mock_token_manager.get_request_headers.return_value = { + "Content-Type": "text/xml; charset=utf-8", + "Authorization": "Bearer test_token_123" + } + + mock_transport.seal_message.return_value = "sealed_message" + mock_transport.to_xml.return_value = "message" + mock_transport.parser.from_bytes.return_value = Mock(body="sealed_response") + mock_transport.unseal_message.return_value = "unsealed_response" + + mock_response = Mock() + mock_response.status_code = 200 + mock_response.content = b"response" + mock_post.return_value = mock_response + + # Create a mock message + # mock_message = Mock(spec=PayloadMessage) + # mock_message.__class__.__name__ = "TestMessage" + + # Call _send_message + result = client_with_oauth._send_message(mock_payload_message) + + # Verify OAuth2 headers were used + mock_token_manager.get_request_headers.assert_called_once() + mock_post.assert_called_once() + + # Check that the Authorization header was included + call_kwargs = mock_post.call_args[1] + assert call_kwargs['headers']['Authorization'] == "Bearer test_token_123" + + @patch('requests.post') + @patch('shapeshifter_uftp.client.base_client.transport') + def test_send_message_oauth_failure_fallback(self, mock_transport, mock_post, client_with_oauth, mock_token_manager, mock_payload_message): + """Test _send_message falls back to basic headers when OAuth2 fails""" + # Setup OAuth2 to fail + mock_token_manager.get_request_headers.side_effect = Exception("OAuth2 failed") + + mock_transport.seal_message.return_value = "sealed_message" + mock_transport.to_xml.return_value = "message" + mock_transport.parser.from_bytes.return_value = Mock(body="sealed_response") + mock_transport.unseal_message.return_value = "unsealed_response" + + mock_response = Mock() + mock_response.status_code = 200 + mock_response.content = b"response" + mock_post.return_value = mock_response + + # # Create a mock message + # mock_message = Mock() + # mock_message.__class__.__name__ = "TestMessage" + + # Call _send_message + result = client_with_oauth._send_message(mock_payload_message) + + # Verify fallback headers were used + mock_post.assert_called_once() + call_kwargs = mock_post.call_args[1] + assert call_kwargs['headers'] == {"Content-Type": "text/xml; charset=utf-8"} + assert "Authorization" not in call_kwargs['headers'] + + @patch('requests.post') + @patch('shapeshifter_uftp.client.base_client.transport') + def test_send_message_without_oauth(self, mock_transport, mock_post, client_without_oauth, mock_payload_message): + """Test _send_message without OAuth2 token manager""" + mock_transport.seal_message.return_value = "sealed_message" + mock_transport.to_xml.return_value = "message" + mock_transport.parser.from_bytes.return_value = Mock(body="sealed_response") + mock_transport.unseal_message.return_value = "unsealed_response" + + mock_response = Mock() + mock_response.status_code = 200 + mock_response.content = b"response" + mock_post.return_value = mock_response + + # Call _send_message + result = client_without_oauth._send_message(mock_payload_message) + + # Verify basic headers were used + mock_post.assert_called_once() + call_kwargs = mock_post.call_args[1] + assert call_kwargs['headers'] == {"Content-Type": "text/xml; charset=utf-8"} + assert "Authorization" not in call_kwargs['headers'] + + +class TestTokenRefreshLogic: + """Test token refresh timing and logic""" + + @pytest.fixture + def token_manager(self): + return AuthTokenManager( + oauth_token_endpoint="https://test.example.com/oauth2/token", + oauth_client_id="test_client_id", + oauth_client_secret="test_client_secret", + token_refresh_buffer=30 + ) + + @pytest.mark.parametrize("expires_in_seconds,refresh_buffer,should_refresh", [ + (300, 30, False), # 5 minutes left, 30s buffer - no refresh needed + (25, 30, True), # 25 seconds left, 30s buffer - refresh needed + (30, 30, True), # Exactly at buffer - refresh needed + (31, 30, False), # Just over buffer - no refresh needed + (0, 30, True), # Expired - refresh needed + ]) + def test_token_refresh_timing(self, token_manager, expires_in_seconds, refresh_buffer, should_refresh): + """Test token refresh timing with different scenarios""" + token_manager.token_refresh_buffer = refresh_buffer + token_manager._access_token = "test_token" + token_manager._token_expires_at = datetime.now(timezone.utc) + timedelta(seconds=expires_in_seconds) + + result = token_manager._is_token_valid() + assert result != should_refresh # should_refresh means token is NOT valid + + +class TestErrorHandling: + """Test error handling scenarios""" + + @pytest.fixture + def token_manager(self): + return AuthTokenManager( + oauth_token_endpoint="https://test.example.com/oauth2/token", + oauth_client_id="test_client_id", + oauth_client_secret="test_client_secret", + token_refresh_buffer=30 + ) + + @patch('requests.post') + def test_network_timeout(self, mock_post, token_manager): + """Test handling of network timeouts""" + mock_post.side_effect = requests.exceptions.Timeout("Request timed out") + + with pytest.raises(requests.exceptions.Timeout): + token_manager._obtain_bearer_token() + + @patch('requests.post') + def test_connection_error(self, mock_post, token_manager): + """Test handling of connection errors""" + mock_post.side_effect = requests.exceptions.ConnectionError("Connection failed") + + with pytest.raises(requests.exceptions.ConnectionError): + token_manager._obtain_bearer_token() + + @patch('requests.post') + def test_invalid_json_response(self, mock_post, token_manager): + """Test handling of invalid JSON response""" + mock_response = Mock() + mock_response.raise_for_status.return_value = None + mock_response.json.side_effect = json.JSONDecodeError("Invalid JSON", "", 0) + mock_post.return_value = mock_response + + with pytest.raises(json.JSONDecodeError): + token_manager._obtain_bearer_token() + + +# Pytest configuration +@pytest.fixture(scope="session") +def setup_logging(): + """Setup logging for tests""" + import logging + logging.basicConfig( + level=logging.DEBUG, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' + ) + + +# Parametrized tests for different scenarios +@pytest.mark.parametrize("endpoint,client_id,secret,expected_configured", [ + ("https://oauth.example.com/token", "client123", "secret456", True), + ("", "client123", "secret456", False), + ("https://oauth.example.com/token", "", "secret456", False), + ("https://oauth.example.com/token", "client123", "", False), + (None, None, None, False), +]) +def test_oauth_configuration_scenarios(endpoint, client_id, secret, expected_configured): + """Test different OAuth2 configuration scenarios""" + if endpoint is None: + # Handle case where service is created without OAuth2 params + service = ShapeshifterService( + sender_domain="test.example.com", + signing_key="test_key", + oauth_token_endpoint=None, + oauth_client_id=None, + oauth_client_secret=None + ) + assert (service.auth_token_manager is not None) == expected_configured + else: + token_manager = AuthTokenManager( + oauth_token_endpoint=endpoint, + oauth_client_id=client_id, + oauth_client_secret=secret + ) + assert token_manager._is_oauth_configured() == expected_configured \ No newline at end of file