Commit 02792b2

add tests for run hooks
1 parent cdcfa4e commit 02792b2

File tree

1 file changed: +223 -0 lines changed

tests/test_run_hooks.py

Lines changed: 223 additions & 0 deletions
@@ -0,0 +1,223 @@
from collections import defaultdict
from typing import Any, Optional

import pytest

from agents.agent import Agent
from agents.items import ItemHelpers, ModelResponse, TResponseInputItem
from agents.lifecycle import RunHooks
from agents.models.interface import Model
from agents.run import Runner
from agents.run_context import RunContextWrapper, TContext
from agents.tool import Tool
from tests.test_agent_llm_hooks import AgentHooksForTests

from .fake_model import FakeModel
from .test_responses import (
    get_function_tool,
    get_text_message,
)


class RunHooksForTests(RunHooks):
    def __init__(self):
        self.events: dict[str, int] = defaultdict(int)

    def reset(self):
        self.events.clear()

    async def on_agent_start(
        self, context: RunContextWrapper[TContext], agent: Agent[TContext]
    ) -> None:
        self.events["on_agent_start"] += 1

    async def on_agent_end(
        self, context: RunContextWrapper[TContext], agent: Agent[TContext], output: Any
    ) -> None:
        self.events["on_agent_end"] += 1

    async def on_handoff(
        self,
        context: RunContextWrapper[TContext],
        from_agent: Agent[TContext],
        to_agent: Agent[TContext],
    ) -> None:
        self.events["on_handoff"] += 1

    async def on_tool_start(
        self, context: RunContextWrapper[TContext], agent: Agent[TContext], tool: Tool
    ) -> None:
        self.events["on_tool_start"] += 1

    async def on_tool_end(
        self,
        context: RunContextWrapper[TContext],
        agent: Agent[TContext],
        tool: Tool,
        result: str,
    ) -> None:
        self.events["on_tool_end"] += 1

    async def on_llm_start(
        self,
        context: RunContextWrapper[TContext],
        agent: Agent[TContext],
        system_prompt: Optional[str],
        input_items: list[TResponseInputItem],
    ) -> None:
        self.events["on_llm_start"] += 1

    async def on_llm_end(
        self,
        context: RunContextWrapper[TContext],
        agent: Agent[TContext],
        response: ModelResponse,
    ) -> None:
        self.events["on_llm_end"] += 1


# Example test using the above hooks
@pytest.mark.asyncio
async def test_async_run_hooks_with_llm():
    hooks = RunHooksForTests()
    model = FakeModel()

    agent = Agent(name="A", model=model, tools=[get_function_tool("f", "res")], handoffs=[])
    # Simulate a single LLM call producing an output:
    model.set_next_output([get_text_message("hello")])
    await Runner.run(agent, input="hello", hooks=hooks)
    # Expect one on_agent_start, one on_llm_start, one on_llm_end, and one on_agent_end
    assert hooks.events == {
        "on_agent_start": 1,
        "on_llm_start": 1,
        "on_llm_end": 1,
        "on_agent_end": 1,
    }


# test_sync_run_hook_with_llm()
def test_sync_run_hook_with_llm():
    hooks = RunHooksForTests()
    model = FakeModel()
    agent = Agent(name="A", model=model, tools=[get_function_tool("f", "res")], handoffs=[])
    # Simulate a single LLM call producing an output:
    model.set_next_output([get_text_message("hello")])
    Runner.run_sync(agent, input="hello", hooks=hooks)
    # Expect one on_agent_start, one on_llm_start, one on_llm_end, and one on_agent_end
    assert hooks.events == {
        "on_agent_start": 1,
        "on_llm_start": 1,
        "on_llm_end": 1,
        "on_agent_end": 1,
    }


# test_streamed_run_hooks_with_llm():
@pytest.mark.asyncio
async def test_streamed_run_hooks_with_llm():
    hooks = RunHooksForTests()
    model = FakeModel()
    agent = Agent(name="A", model=model, tools=[get_function_tool("f", "res")], handoffs=[])
    # Simulate a single LLM call producing an output:
    model.set_next_output([get_text_message("hello")])
    stream = Runner.run_streamed(agent, input="hello", hooks=hooks)

    async for event in stream.stream_events():
        if event.type == "raw_response_event":
            continue
        if event.type == "agent_updated_stream_event":
            print(f"[EVENT] agent_updated → {event.new_agent.name}")
        elif event.type == "run_item_stream_event":
            item = event.item
            if item.type == "tool_call_item":
                print("[EVENT] tool_call_item")
            elif item.type == "tool_call_output_item":
                print(f"[EVENT] tool_call_output_item → {item.output}")
            elif item.type == "message_output_item":
                text = ItemHelpers.text_message_output(item)
                print(f"[EVENT] message_output_item → {text}")

    # Expect one on_agent_start, one on_llm_start, one on_llm_end, and one on_agent_end
    assert hooks.events == {
        "on_agent_start": 1,
        "on_llm_start": 1,
        "on_llm_end": 1,
        "on_agent_end": 1,
    }


# test_async_run_hooks_with_agent_hooks_with_llm
@pytest.mark.asyncio
async def test_async_run_hooks_with_agent_hooks_with_llm():
    hooks = RunHooksForTests()
    agent_hooks = AgentHooksForTests()
    model = FakeModel()

    agent = Agent(
        name="A", model=model, tools=[get_function_tool("f", "res")], handoffs=[], hooks=agent_hooks
    )
    # Simulate a single LLM call producing an output:
    model.set_next_output([get_text_message("hello")])
    await Runner.run(agent, input="hello", hooks=hooks)
    # Expect one on_agent_start, one on_llm_start, one on_llm_end, and one on_agent_end
    assert hooks.events == {
        "on_agent_start": 1,
        "on_llm_start": 1,
        "on_llm_end": 1,
        "on_agent_end": 1,
    }
    # Expect one on_start, one on_llm_start, one on_llm_end, and one on_end
    assert agent_hooks.events == {"on_start": 1, "on_llm_start": 1, "on_llm_end": 1, "on_end": 1}


@pytest.mark.asyncio
async def test_run_hooks_llm_error_non_streaming(monkeypatch):
    hooks = RunHooksForTests()
    model = FakeModel()
    agent = Agent(name="A", model=model, tools=[get_function_tool("f", "res")], handoffs=[])

    async def boom(*args, **kwargs):
        raise RuntimeError("boom")

    monkeypatch.setattr(FakeModel, "get_response", boom, raising=True)

    with pytest.raises(RuntimeError, match="boom"):
        await Runner.run(agent, input="hello", hooks=hooks)

    # Current behavior is that hooks will not fire on LLM failure
    assert hooks.events["on_agent_start"] == 1
    assert hooks.events["on_llm_start"] == 1
    assert hooks.events["on_llm_end"] == 0
    assert hooks.events["on_agent_end"] == 0


class BoomModel(Model):
    async def get_response(self, *a, **k):
        raise AssertionError("get_response should not be called in streaming test")

    async def stream_response(self, *a, **k):  # type: ignore[override]
        yield {"foo": "bar"}
        raise RuntimeError("stream blew up")


@pytest.mark.asyncio
async def test_streamed_run_hooks_llm_error(monkeypatch):
    """
    Verify that when the streaming path raises, we still emit on_llm_start
    but do NOT emit on_llm_end (current behavior), and the exception propagates.
    """
    hooks = RunHooksForTests()
    agent = Agent(name="A", model=BoomModel(), tools=[get_function_tool("f", "res")], handoffs=[])

    stream = Runner.run_streamed(agent, input="hello", hooks=hooks)

    # Consuming the stream should surface the exception
    with pytest.raises(RuntimeError, match="stream blew up"):
        async for _ in stream.stream_events():
            pass

    # Current behavior: success-only on_llm_end; ensure starts fired but ends did not.
    assert hooks.events["on_agent_start"] == 1
    assert hooks.events["on_llm_start"] == 1
    assert hooks.events["on_llm_end"] == 0
    assert hooks.events["on_agent_end"] == 0
