From 6904dcbbdbf211aa2542a1e824db28985a0e914b Mon Sep 17 00:00:00 2001
From: Erik Anstine <32369420+erikanstine@users.noreply.github.com>
Date: Mon, 1 Sep 2025 23:32:17 -0400
Subject: [PATCH] fix(run): fire on_llm_start / on_llm_end in Runner.run() for
 streaming & non-streaming (aligns with docs) (#1619)

---
 examples/basic/lifecycle_example.py |  41 ++++-
 src/agents/run.py                   |  63 +++++---
 tests/test_run_hooks.py             | 223 ++++++++++++++++++++++++++++
 3 files changed, 302 insertions(+), 25 deletions(-)
 create mode 100644 tests/test_run_hooks.py

diff --git a/examples/basic/lifecycle_example.py b/examples/basic/lifecycle_example.py
index f37380b25..941b67768 100644
--- a/examples/basic/lifecycle_example.py
+++ b/examples/basic/lifecycle_example.py
@@ -1,10 +1,11 @@
 import asyncio
 import random
-from typing import Any
+from typing import Any, Optional
 
 from pydantic import BaseModel
 
 from agents import Agent, RunContextWrapper, RunHooks, Runner, Tool, Usage, function_tool
+from agents.items import ModelResponse, TResponseInputItem
 
 
 class ExampleHooks(RunHooks):
@@ -20,6 +21,22 @@ async def on_agent_start(self, context: RunContextWrapper, agent: Agent) -> None
             f"### {self.event_counter}: Agent {agent.name} started. Usage: {self._usage_to_str(context.usage)}"
         )
 
+    async def on_llm_start(
+        self,
+        context: RunContextWrapper,
+        agent: Agent,
+        system_prompt: Optional[str],
+        input_items: list[TResponseInputItem],
+    ) -> None:
+        self.event_counter += 1
+        print(f"### {self.event_counter}: LLM started. Usage: {self._usage_to_str(context.usage)}")
+
+    async def on_llm_end(
+        self, context: RunContextWrapper, agent: Agent, response: ModelResponse
+    ) -> None:
+        self.event_counter += 1
+        print(f"### {self.event_counter}: LLM ended. Usage: {self._usage_to_str(context.usage)}")
+
     async def on_agent_end(self, context: RunContextWrapper, agent: Agent, output: Any) -> None:
         self.event_counter += 1
         print(
@@ -109,13 +126,21 @@ async def main() -> None:
 
 Enter a max number: 250
 ### 1: Agent Start Agent started. Usage: 0 requests, 0 input tokens, 0 output tokens, 0 total tokens
-### 2: Tool random_number started. Usage: 1 requests, 148 input tokens, 15 output tokens, 163 total tokens
-### 3: Tool random_number ended with result 101. Usage: 1 requests, 148 input tokens, 15 output tokens, 163 total token
-### 4: Handoff from Start Agent to Multiply Agent. Usage: 2 requests, 323 input tokens, 30 output tokens, 353 total tokens
-### 5: Agent Multiply Agent started. Usage: 2 requests, 323 input tokens, 30 output tokens, 353 total tokens
-### 6: Tool multiply_by_two started. Usage: 3 requests, 504 input tokens, 46 output tokens, 550 total tokens
-### 7: Tool multiply_by_two ended with result 202. Usage: 3 requests, 504 input tokens, 46 output tokens, 550 total tokens
-### 8: Agent Multiply Agent ended with output number=202. Usage: 4 requests, 714 input tokens, 63 output tokens, 777 total tokens
+### 2: LLM started. Usage: 0 requests, 0 input tokens, 0 output tokens, 0 total tokens
+### 3: LLM ended. Usage: 1 requests, 143 input tokens, 15 output tokens, 158 total tokens
+### 4: Tool random_number started. Usage: 1 requests, 143 input tokens, 15 output tokens, 158 total tokens
+### 5: Tool random_number ended with result 69. Usage: 1 requests, 143 input tokens, 15 output tokens, 158 total tokens
+### 6: LLM started. Usage: 1 requests, 143 input tokens, 15 output tokens, 158 total tokens
+### 7: LLM ended. Usage: 2 requests, 310 input tokens, 29 output tokens, 339 total tokens
+### 8: Handoff from Start Agent to Multiply Agent. Usage: 2 requests, 310 input tokens, 29 output tokens, 339 total tokens
+### 9: Agent Multiply Agent started. Usage: 2 requests, 310 input tokens, 29 output tokens, 339 total tokens
+### 10: LLM started. Usage: 2 requests, 310 input tokens, 29 output tokens, 339 total tokens
+### 11: LLM ended. Usage: 3 requests, 472 input tokens, 45 output tokens, 517 total tokens
+### 12: Tool multiply_by_two started. Usage: 3 requests, 472 input tokens, 45 output tokens, 517 total tokens
+### 13: Tool multiply_by_two ended with result 138. Usage: 3 requests, 472 input tokens, 45 output tokens, 517 total tokens
+### 14: LLM started. Usage: 3 requests, 472 input tokens, 45 output tokens, 517 total tokens
+### 15: LLM ended. Usage: 4 requests, 660 input tokens, 56 output tokens, 716 total tokens
+### 16: Agent Multiply Agent ended with output number=138. Usage: 4 requests, 660 input tokens, 56 output tokens, 716 total tokens
 Done!
 """
 

diff --git a/src/agents/run.py b/src/agents/run.py
index 742917b87..c68e41989 100644
--- a/src/agents/run.py
+++ b/src/agents/run.py
@@ -994,10 +994,16 @@ async def _run_single_turn_streamed(
         )
 
         # Call hook just before the model is invoked, with the correct system_prompt.
-        if agent.hooks:
-            await agent.hooks.on_llm_start(
-                context_wrapper, agent, filtered.instructions, filtered.input
-            )
+        await asyncio.gather(
+            hooks.on_llm_start(context_wrapper, agent, filtered.instructions, filtered.input),
+            (
+                agent.hooks.on_llm_start(
+                    context_wrapper, agent, filtered.instructions, filtered.input
+                )
+                if agent.hooks
+                else _coro.noop_coroutine()
+            ),
+        )
 
         # 1. Stream the output events
         async for event in model.stream_response(
@@ -1056,8 +1062,15 @@ async def _run_single_turn_streamed(
             streamed_result._event_queue.put_nowait(RawResponsesStreamEvent(data=event))
 
         # Call hook just after the model response is finalized.
-        if agent.hooks and final_response is not None:
-            await agent.hooks.on_llm_end(context_wrapper, agent, final_response)
+        if final_response is not None:
+            await asyncio.gather(
+                (
+                    agent.hooks.on_llm_end(context_wrapper, agent, final_response)
+                    if agent.hooks
+                    else _coro.noop_coroutine()
+                ),
+                hooks.on_llm_end(context_wrapper, agent, final_response),
+            )
 
         # 2. At this point, the streaming is complete for this turn of the agent loop.
         if not final_response:
@@ -1150,6 +1163,7 @@ async def _run_single_turn(
             output_schema,
             all_tools,
             handoffs,
+            hooks,
             context_wrapper,
             run_config,
             tool_use_tracker,
@@ -1345,6 +1359,7 @@ async def _get_new_response(
         output_schema: AgentOutputSchemaBase | None,
         all_tools: list[Tool],
         handoffs: list[Handoff],
+        hooks: RunHooks[TContext],
         context_wrapper: RunContextWrapper[TContext],
         run_config: RunConfig,
         tool_use_tracker: AgentToolUseTracker,
@@ -1364,14 +1379,21 @@ async def _get_new_response(
         model = cls._get_model(agent, run_config)
         model_settings = agent.model_settings.resolve(run_config.model_settings)
         model_settings = RunImpl.maybe_reset_tool_choice(agent, tool_use_tracker, model_settings)
-        # If the agent has hooks, we need to call them before and after the LLM call
-        if agent.hooks:
-            await agent.hooks.on_llm_start(
-                context_wrapper,
-                agent,
-                filtered.instructions,  # Use filtered instructions
-                filtered.input,  # Use filtered input
-            )
+
+        # If we have run hooks, or if the agent has hooks, we need to call them before the LLM call
+        await asyncio.gather(
+            hooks.on_llm_start(context_wrapper, agent, filtered.instructions, filtered.input),
+            (
+                agent.hooks.on_llm_start(
+                    context_wrapper,
+                    agent,
+                    filtered.instructions,  # Use filtered instructions
+                    filtered.input,  # Use filtered input
+                )
+                if agent.hooks
+                else _coro.noop_coroutine()
+            ),
+        )
 
         new_response = await model.get_response(
             system_instructions=filtered.instructions,
@@ -1387,12 +1409,19 @@ async def _get_new_response(
             conversation_id=conversation_id,
             prompt=prompt_config,
         )
-        # If the agent has hooks, we need to call them after the LLM call
-        if agent.hooks:
-            await agent.hooks.on_llm_end(context_wrapper, agent, new_response)
 
         context_wrapper.usage.add(new_response.usage)
 
+        # If we have run hooks, or if the agent has hooks, we need to call them after the LLM call
+        await asyncio.gather(
+            (
+                agent.hooks.on_llm_end(context_wrapper, agent, new_response)
+                if agent.hooks
+                else _coro.noop_coroutine()
+            ),
+            hooks.on_llm_end(context_wrapper, agent, new_response),
+        )
+
         return new_response
 
     @classmethod
diff --git a/tests/test_run_hooks.py b/tests/test_run_hooks.py
new file mode 100644
index 000000000..988cd6dc2
--- /dev/null
+++ b/tests/test_run_hooks.py
@@ -0,0 +1,223 @@
+from collections import defaultdict
+from typing import Any, Optional
+
+import pytest
+
+from agents.agent import Agent
+from agents.items import ItemHelpers, ModelResponse, TResponseInputItem
+from agents.lifecycle import RunHooks
+from agents.models.interface import Model
+from agents.run import Runner
+from agents.run_context import RunContextWrapper, TContext
+from agents.tool import Tool
+from tests.test_agent_llm_hooks import AgentHooksForTests
+
+from .fake_model import FakeModel
+from .test_responses import (
+    get_function_tool,
+    get_text_message,
+)
+
+
+class RunHooksForTests(RunHooks):
+    def __init__(self):
+        self.events: dict[str, int] = defaultdict(int)
+
+    def reset(self):
+        self.events.clear()
+
+    async def on_agent_start(
+        self, context: RunContextWrapper[TContext], agent: Agent[TContext]
+    ) -> None:
+        self.events["on_agent_start"] += 1
+
+    async def on_agent_end(
+        self, context: RunContextWrapper[TContext], agent: Agent[TContext], output: Any
+    ) -> None:
+        self.events["on_agent_end"] += 1
+
+    async def on_handoff(
+        self,
+        context: RunContextWrapper[TContext],
+        from_agent: Agent[TContext],
+        to_agent: Agent[TContext],
+    ) -> None:
+        self.events["on_handoff"] += 1
+
+    async def on_tool_start(
+        self, context: RunContextWrapper[TContext], agent: Agent[TContext], tool: Tool
+    ) -> None:
+        self.events["on_tool_start"] += 1
+
+    async def on_tool_end(
+        self,
+        context: RunContextWrapper[TContext],
+        agent: Agent[TContext],
+        tool: Tool,
+        result: str,
+    ) -> None:
+        self.events["on_tool_end"] += 1
+
+    async def on_llm_start(
+        self,
+        context: RunContextWrapper[TContext],
+        agent: Agent[TContext],
+        system_prompt: Optional[str],
+        input_items: list[TResponseInputItem],
+    ) -> None:
+        self.events["on_llm_start"] += 1
+
+    async def on_llm_end(
+        self,
+        context: RunContextWrapper[TContext],
+        agent: Agent[TContext],
+        response: ModelResponse,
+    ) -> None:
+        self.events["on_llm_end"] += 1
+
+
+# Example test using the above hooks
+@pytest.mark.asyncio
+async def test_async_run_hooks_with_llm():
+    hooks = RunHooksForTests()
+    model = FakeModel()
+
+    agent = Agent(name="A", model=model, tools=[get_function_tool("f", "res")], handoffs=[])
+    # Simulate a single LLM call producing an output:
+    model.set_next_output([get_text_message("hello")])
+    await Runner.run(agent, input="hello", hooks=hooks)
+    # Expect one on_agent_start, one on_llm_start, one on_llm_end, and one on_agent_end
+    assert hooks.events == {
+        "on_agent_start": 1,
+        "on_llm_start": 1,
+        "on_llm_end": 1,
+        "on_agent_end": 1,
+    }
+
+
+# test_sync_run_hook_with_llm()
+def test_sync_run_hook_with_llm():
+    hooks = RunHooksForTests()
+    model = FakeModel()
+    agent = Agent(name="A", model=model, tools=[get_function_tool("f", "res")], handoffs=[])
+    # Simulate a single LLM call producing an output:
+    model.set_next_output([get_text_message("hello")])
+    Runner.run_sync(agent, input="hello", hooks=hooks)
+    # Expect one on_agent_start, one on_llm_start, one on_llm_end, and one on_agent_end
+    assert hooks.events == {
+        "on_agent_start": 1,
+        "on_llm_start": 1,
+        "on_llm_end": 1,
+        "on_agent_end": 1,
+    }
+
+
+# test_streamed_run_hooks_with_llm():
+@pytest.mark.asyncio
+async def test_streamed_run_hooks_with_llm():
+    hooks = RunHooksForTests()
+    model = FakeModel()
+    agent = Agent(name="A", model=model, tools=[get_function_tool("f", "res")], handoffs=[])
+    # Simulate a single LLM call producing an output:
+    model.set_next_output([get_text_message("hello")])
+    stream = Runner.run_streamed(agent, input="hello", hooks=hooks)
+
+    async for event in stream.stream_events():
+        if event.type == "raw_response_event":
+            continue
+        if event.type == "agent_updated_stream_event":
+            print(f"[EVENT] agent_updated → {event.new_agent.name}")
+        elif event.type == "run_item_stream_event":
+            item = event.item
+            if item.type == "tool_call_item":
+                print("[EVENT] tool_call_item")
+            elif item.type == "tool_call_output_item":
+                print(f"[EVENT] tool_call_output_item → {item.output}")
+            elif item.type == "message_output_item":
+                text = ItemHelpers.text_message_output(item)
+                print(f"[EVENT] message_output_item → {text}")
+
+    # Expect one on_agent_start, one on_llm_start, one on_llm_end, and one on_agent_end
+    assert hooks.events == {
+        "on_agent_start": 1,
+        "on_llm_start": 1,
+        "on_llm_end": 1,
+        "on_agent_end": 1,
+    }
+
+
+# test_async_run_hooks_with_agent_hooks_with_llm
+@pytest.mark.asyncio
+async def test_async_run_hooks_with_agent_hooks_with_llm():
+    hooks = RunHooksForTests()
+    agent_hooks = AgentHooksForTests()
+    model = FakeModel()
+
+    agent = Agent(
+        name="A", model=model, tools=[get_function_tool("f", "res")], handoffs=[], hooks=agent_hooks
+    )
+    # Simulate a single LLM call producing an output:
+    model.set_next_output([get_text_message("hello")])
+    await Runner.run(agent, input="hello", hooks=hooks)
+    # Expect one on_agent_start, one on_llm_start, one on_llm_end, and one on_agent_end
+    assert hooks.events == {
+        "on_agent_start": 1,
+        "on_llm_start": 1,
+        "on_llm_end": 1,
+        "on_agent_end": 1,
+    }
+    # Expect one on_start, one on_llm_start, one on_llm_end, and one on_end
+    assert agent_hooks.events == {"on_start": 1, "on_llm_start": 1, "on_llm_end": 1, "on_end": 1}
+
+
+@pytest.mark.asyncio
+async def test_run_hooks_llm_error_non_streaming(monkeypatch):
+    hooks = RunHooksForTests()
+    model = FakeModel()
+    agent = Agent(name="A", model=model, tools=[get_function_tool("f", "res")], handoffs=[])
+
+    async def boom(*args, **kwargs):
+        raise RuntimeError("boom")
+
+    monkeypatch.setattr(FakeModel, "get_response", boom, raising=True)
+
+    with pytest.raises(RuntimeError, match="boom"):
+        await Runner.run(agent, input="hello", hooks=hooks)
+
+    # Current behavior: on_llm_start fires, but on_llm_end does not fire on LLM failure
+    assert hooks.events["on_agent_start"] == 1
+    assert hooks.events["on_llm_start"] == 1
+    assert hooks.events["on_llm_end"] == 0
+    assert hooks.events["on_agent_end"] == 0
+
+
+class BoomModel(Model):
+    async def get_response(self, *a, **k):
+        raise AssertionError("get_response should not be called in streaming test")
+
+    async def stream_response(self, *a, **k):
+        yield {"foo": "bar"}
+        raise RuntimeError("stream blew up")
+
+
+@pytest.mark.asyncio
+async def test_streamed_run_hooks_llm_error(monkeypatch):
+    """
+    Verify that when the streaming path raises, we still emit on_llm_start
+    but do NOT emit on_llm_end (current behavior), and the exception propagates.
+    """
+    hooks = RunHooksForTests()
+    agent = Agent(name="A", model=BoomModel(), tools=[get_function_tool("f", "res")], handoffs=[])
+
+    stream = Runner.run_streamed(agent, input="hello", hooks=hooks)
+
+    # Consuming the stream should surface the exception
+    with pytest.raises(RuntimeError, match="stream blew up"):
+        async for _ in stream.stream_events():
+            pass
+
+    # Current behavior: success-only on_llm_end; ensure starts fired but ends did not.
+    assert hooks.events["on_agent_start"] == 1
+    assert hooks.events["on_llm_start"] == 1
+    assert hooks.events["on_llm_end"] == 0
+    assert hooks.events["on_agent_end"] == 0
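
Usage sketch (illustrative, not part of the patch): with this change, run-level
hooks passed to Runner.run(), Runner.run_sync(), or Runner.run_streamed() now
receive on_llm_start/on_llm_end, mirroring the agent-level hooks. A minimal
example, reusing the hook signatures from the diff above; the LoggingRunHooks
class name, agent name, and instructions are placeholders:

    import asyncio
    from typing import Optional

    from agents import Agent, RunContextWrapper, RunHooks, Runner
    from agents.items import ModelResponse, TResponseInputItem


    class LoggingRunHooks(RunHooks):
        # Fires just before each model call, for streaming and non-streaming runs.
        async def on_llm_start(
            self,
            context: RunContextWrapper,
            agent: Agent,
            system_prompt: Optional[str],
            input_items: list[TResponseInputItem],
        ) -> None:
            print(f"LLM call starting for agent {agent.name}")

        # Fires after the model response is finalized (success path only, per the tests above).
        async def on_llm_end(
            self, context: RunContextWrapper, agent: Agent, response: ModelResponse
        ) -> None:
            print(f"LLM call finished for agent {agent.name}")


    async def main() -> None:
        agent = Agent(name="Assistant", instructions="Reply concisely.")
        result = await Runner.run(agent, input="hello", hooks=LoggingRunHooks())
        print(result.final_output)


    asyncio.run(main())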