Skip to content

Commit 3a3be87

Browse files
authored
Add should_skip method to Middleware class for conditional request processing (#15)
1 parent 2f33a64 commit 3a3be87

File tree

4 files changed

+535
-1
lines changed

4 files changed

+535
-1
lines changed

py_spring_core/core/application/py_spring_application.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,6 @@ def __init_controllers(self) -> None:
228228
controller._register_decorated_routes(routes)
229229
router = controller.get_router()
230230
self.fastapi.include_router(router)
231-
self.__init_middlewares()
232231
logger.debug(f"[CONTROLLER INIT] Controller {name} initialized")
233232

234233
def __init_middlewares(self) -> None:
@@ -298,6 +297,7 @@ def run(self) -> None:
298297
self.__configure_logging()
299298
self.__init_app()
300299
self.__init_controllers()
300+
self.__init_middlewares()
301301
if self.app_config.server_config.enabled:
302302
self.__run_server()
303303
finally:

py_spring_core/core/entities/middlewares/middleware.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,18 @@ class Middleware(BaseHTTPMiddleware):
1111
Simpler to use, only need to implement the process_request method
1212
"""
1313

14+
def should_skip(self, request: Request) -> bool:
15+
"""
16+
Method to determine if the middleware should be skipped
17+
18+
Args:
19+
request: FastAPI request object
20+
21+
Returns:
22+
bool: True if the middleware should be skipped, False otherwise, default is False
23+
"""
24+
return False
25+
1426
@abstractmethod
1527
async def process_request(self, request: Request) -> Response | None:
1628
"""
@@ -30,7 +42,11 @@ async def dispatch(
3042
) -> Response:
3143
"""
3244
Middleware dispatch method, automatically called by FastAPI
45+
If should_skip returns True, the middleware will be skipped
3346
"""
47+
if self.should_skip(request):
48+
return await call_next(request)
49+
3450
# First execute custom request processing logic
3551
response = await self.process_request(request)
3652

tests/test_middleware.py

Lines changed: 227 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,233 @@ async def process_request(self, request: Request) -> Response | None:
140140

141141
assert received_request == mock_request
142142

143+
def test_should_skip_default_returns_false(self, mock_request):
144+
"""
145+
Test that should_skip method returns False by default.
146+
147+
This test verifies that:
148+
1. The default implementation of should_skip returns False
149+
2. This allows the middleware to process all requests by default
150+
"""
151+
middleware = Middleware(app=Mock())
152+
result = middleware.should_skip(mock_request)
153+
assert result is False
154+
155+
def test_should_skip_can_be_overridden(self, mock_request):
156+
"""
157+
Test that should_skip method can be overridden in subclasses.
158+
159+
This test verifies that:
160+
1. Subclasses can override should_skip to provide custom skip logic
161+
2. The overridden method is called with the correct request parameter
162+
"""
163+
class SkippingMiddleware(Middleware):
164+
def should_skip(self, request: Request) -> bool:
165+
return request.method == "GET"
166+
167+
middleware = SkippingMiddleware(app=Mock())
168+
result = middleware.should_skip(mock_request)
169+
assert result is True
170+
171+
@pytest.mark.asyncio
172+
async def test_dispatch_skips_middleware_when_should_skip_returns_true(
173+
self, mock_request, mock_call_next
174+
):
175+
"""
176+
Test that dispatch skips middleware processing when should_skip returns True.
177+
178+
This test verifies that:
179+
1. When should_skip returns True, process_request is not called
180+
2. The request is passed directly to call_next
181+
3. The response from call_next is returned
182+
"""
183+
expected_response = Response(content="skipped response", status_code=200)
184+
mock_call_next.return_value = expected_response
185+
186+
class SkippingMiddleware(Middleware):
187+
def should_skip(self, request: Request) -> bool:
188+
return True
189+
190+
async def process_request(self, request: Request) -> Response | None:
191+
# This should never be called when should_skip returns True
192+
raise AssertionError("process_request should not be called")
193+
194+
middleware = SkippingMiddleware(app=Mock())
195+
result = await middleware.dispatch(mock_request, mock_call_next)
196+
197+
# Verify call_next was called with the request
198+
mock_call_next.assert_called_once_with(mock_request)
199+
# Verify the response from call_next is returned
200+
assert result == expected_response
201+
202+
@pytest.mark.asyncio
203+
async def test_dispatch_processes_middleware_when_should_skip_returns_false(
204+
self, mock_request, mock_call_next
205+
):
206+
"""
207+
Test that dispatch processes middleware when should_skip returns False.
208+
209+
This test verifies that:
210+
1. When should_skip returns False, process_request is called
211+
2. The middleware processing logic is executed
212+
3. The normal dispatch flow continues
213+
"""
214+
expected_response = Response(content="processed response", status_code=200)
215+
mock_call_next.return_value = expected_response
216+
217+
process_request_called = False
218+
219+
class ProcessingMiddleware(Middleware):
220+
def should_skip(self, request: Request) -> bool:
221+
return False
222+
223+
async def process_request(self, request: Request) -> Response | None:
224+
nonlocal process_request_called
225+
process_request_called = True
226+
return None
227+
228+
middleware = ProcessingMiddleware(app=Mock())
229+
result = await middleware.dispatch(mock_request, mock_call_next)
230+
231+
# Verify process_request was called
232+
assert process_request_called is True
233+
# Verify call_next was called
234+
mock_call_next.assert_called_once_with(mock_request)
235+
# Verify the response from call_next is returned
236+
assert result == expected_response
237+
238+
def test_should_skip_receives_correct_request_parameter(self, mock_request):
239+
"""
240+
Test that should_skip method receives the correct request parameter.
241+
242+
This test verifies that:
243+
1. The should_skip method receives the exact same request object
244+
2. The request parameter is passed correctly
245+
"""
246+
received_request = None
247+
248+
class TestMiddleware(Middleware):
249+
def should_skip(self, request: Request) -> bool:
250+
nonlocal received_request
251+
received_request = request
252+
return False
253+
254+
middleware = TestMiddleware(app=Mock())
255+
middleware.should_skip(mock_request)
256+
257+
assert received_request == mock_request
258+
259+
@pytest.mark.asyncio
260+
async def test_dispatch_with_conditional_skip_logic(self, mock_request, mock_call_next):
261+
"""
262+
Test dispatch with conditional skip logic based on request properties.
263+
264+
This test verifies that:
265+
1. should_skip can use request properties to make skip decisions
266+
2. The skip logic works correctly in the dispatch flow
267+
3. Both skip and process paths work as expected
268+
"""
269+
expected_response = Response(content="test response", status_code=200)
270+
mock_call_next.return_value = expected_response
271+
272+
class ConditionalMiddleware(Middleware):
273+
def should_skip(self, request: Request) -> bool:
274+
# Skip GET requests, process others
275+
return request.method == "GET"
276+
277+
async def process_request(self, request: Request) -> Response | None:
278+
# This should only be called for non-GET requests
279+
return Response(content="processed", status_code=202)
280+
281+
middleware = ConditionalMiddleware(app=Mock())
282+
283+
# Test with GET request (should skip)
284+
mock_request.method = "GET"
285+
result = await middleware.dispatch(mock_request, mock_call_next)
286+
287+
assert result == expected_response
288+
mock_call_next.assert_called_once_with(mock_request)
289+
290+
# Reset mock for next test
291+
mock_call_next.reset_mock()
292+
mock_call_next.return_value = expected_response
293+
294+
# Test with POST request (should process)
295+
mock_request.method = "POST"
296+
result = await middleware.dispatch(mock_request, mock_call_next)
297+
298+
assert result.body == b"processed"
299+
assert result.status_code == 202
300+
mock_call_next.assert_not_called()
301+
302+
@pytest.mark.asyncio
303+
async def test_dispatch_with_url_based_skip_logic(self, mock_request, mock_call_next):
304+
"""
305+
Test dispatch with URL-based skip logic.
306+
307+
This test verifies that:
308+
1. should_skip can use request URL to make skip decisions
309+
2. URL-based filtering works correctly
310+
3. The middleware processes only relevant requests
311+
"""
312+
expected_response = Response(content="test response", status_code=200)
313+
mock_call_next.return_value = expected_response
314+
315+
class URLBasedMiddleware(Middleware):
316+
def should_skip(self, request: Request) -> bool:
317+
# Skip requests to /health endpoint
318+
return str(request.url).endswith("/health")
319+
320+
async def process_request(self, request: Request) -> Response | None:
321+
return Response(content="processed", status_code=202)
322+
323+
middleware = URLBasedMiddleware(app=Mock())
324+
325+
# Test with health endpoint (should skip)
326+
mock_request.url = "http://test.com/health"
327+
result = await middleware.dispatch(mock_request, mock_call_next)
328+
329+
assert result == expected_response
330+
mock_call_next.assert_called_once_with(mock_request)
331+
332+
# Reset mock for next test
333+
mock_call_next.reset_mock()
334+
mock_call_next.return_value = expected_response
335+
336+
# Test with other endpoint (should process)
337+
mock_request.url = "http://test.com/api/users"
338+
result = await middleware.dispatch(mock_request, mock_call_next)
339+
340+
assert result.body == b"processed"
341+
assert result.status_code == 202
342+
mock_call_next.assert_not_called()
343+
344+
def test_should_skip_method_signature(self):
345+
"""
346+
Test that should_skip method has the correct signature.
347+
348+
This test verifies that:
349+
1. should_skip is an instance method
350+
2. It takes a Request parameter
351+
3. It returns a boolean value
352+
"""
353+
middleware = Middleware(app=Mock())
354+
355+
# Check that should_skip is a method
356+
assert hasattr(middleware, 'should_skip')
357+
assert callable(middleware.should_skip)
358+
359+
# Check that it's an instance method (not a class method or static method)
360+
import inspect
361+
sig = inspect.signature(middleware.should_skip)
362+
params = list(sig.parameters.keys())
363+
364+
# Should have 'request' parameter (self is automatically handled by Python)
365+
assert params == ['request']
366+
367+
# Check return type annotation
368+
assert sig.return_annotation == bool
369+
143370

144371
class TestMiddlewareRegistry:
145372
"""Test suite for the MiddlewareRegistry abstract class."""

0 commit comments

Comments
 (0)