Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions src/google/adk/agents/invocation_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ class LlmCallsLimitExceededError(Exception):
"""Error thrown when the number of LLM calls exceed the limit."""


class ToolIterationsLimitExceededError(Exception):
"""Error thrown when the number of tool iterations exceed the limit."""


class RealtimeCacheEntry(BaseModel):
"""Store audio data chunks for caching before flushing."""

Expand Down Expand Up @@ -76,6 +80,9 @@ class _InvocationCostManager(BaseModel):
_number_of_llm_calls: int = 0
"""A counter that keeps track of number of llm calls made."""

_number_of_tool_iterations: int = 0
"""A counter that keeps track of consecutive tool calling iterations in current agent call."""

def increment_and_enforce_llm_calls_limit(
self, run_config: Optional[RunConfig]
):
Expand All @@ -94,6 +101,29 @@ def increment_and_enforce_llm_calls_limit(
f" `{run_config.max_llm_calls}` exceeded"
)

def increment_and_enforce_tool_iterations_limit(
self, run_config: Optional[RunConfig]
):
"""Increments _number_of_tool_iterations and enforces the limit."""
# We first increment the counter and then check the conditions.
self._number_of_tool_iterations += 1

if (
run_config
and run_config.max_tool_iterations > 0
and self._number_of_tool_iterations > run_config.max_tool_iterations
):
# We only enforce the limit if the limit is a positive number.
raise ToolIterationsLimitExceededError(
"Max number of tool iterations limit of"
f" `{run_config.max_tool_iterations}` exceeded. This prevents"
" infinite loops when using FunctionCallingConfig mode='ANY'."
)

def reset_tool_iterations_counter(self):
"""Resets the tool iterations counter. Called when agent provides final response."""
self._number_of_tool_iterations = 0


class InvocationContext(BaseModel):
"""An invocation context represents the data of a single invocation of an agent.
Expand Down Expand Up @@ -316,6 +346,33 @@ def increment_llm_call_count(
self.run_config
)

def increment_tool_iteration_count(
self,
):
"""Tracks number of tool calling iterations in the current agent call.

This method should be called each time the agent makes an LLM call that
returns function calls, to prevent infinite loops in FunctionCallingConfig
mode="ANY" scenarios.

Raises:
ToolIterationsLimitExceededError: If number of tool iterations exceed
the set threshold.
"""
self._invocation_cost_manager.increment_and_enforce_tool_iterations_limit(
self.run_config
)

def reset_tool_iteration_count(
self,
):
"""Resets the tool iterations counter.

This should be called when the agent provides a final response (not tool calls),
as it indicates the tool calling loop has completed successfully.
"""
self._invocation_cost_manager.reset_tool_iterations_counter()

@property
def app_name(self) -> str:
return self.session.app_name
Expand Down
34 changes: 34 additions & 0 deletions src/google/adk/agents/run_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,25 @@ class RunConfig(BaseModel):
- Less than or equal to 0: This allows for unbounded number of llm calls.
"""

max_tool_iterations: int = 50
"""
A limit on the number of consecutive tool calling iterations in a single agent call.

This prevents infinite loops when FunctionCallingConfig mode="ANY" is used,
where the model could keep calling tools indefinitely without providing a
final response. An iteration is counted each time the agent calls the LLM and
it returns function calls (regardless of how many functions are called).

Valid Values:
- More than 0 and less than sys.maxsize: The bound on the number of tool
iterations is enforced, if the value is set in this range.
- Less than or equal to 0: This allows for unbounded number of tool iterations.

Note: This is different from max_llm_calls which limits total LLM calls across
the entire invocation. max_tool_iterations limits consecutive tool-calling
cycles within a single agent's execution flow.
"""

custom_metadata: Optional[dict[str, Any]] = None
"""Custom metadata for the current invocation."""

Expand Down Expand Up @@ -284,3 +303,18 @@ def validate_max_llm_calls(cls, value: int) -> int:
)

return value

@field_validator('max_tool_iterations', mode='after')
@classmethod
def validate_max_tool_iterations(cls, value: int) -> int:
if value == sys.maxsize:
raise ValueError(f'max_tool_iterations should be less than {sys.maxsize}.')
elif value <= 0:
logger.warning(
'max_tool_iterations is less than or equal to 0. This will result in'
' no enforcement on total number of tool iterations that will be made'
' for an agent call. This may not be ideal, as this could result in'
' infinite loops when using FunctionCallingConfig mode="ANY".',
)

return value
11 changes: 11 additions & 0 deletions src/google/adk/flows/llm_flows/base_llm_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,6 +550,8 @@ async def _postprocess_async(

# Handles function calls.
if model_response_event.get_function_calls():
# Increment tool iteration counter to prevent infinite loops (Issue #4179)
invocation_context.increment_tool_iteration_count()

if is_feature_enabled(FeatureName.PROGRESSIVE_SSE_STREAMING):
# In progressive SSE streaming mode stage 1, we skip partial FC events
Expand All @@ -567,6 +569,9 @@ async def _postprocess_async(
) as agen:
async for event in agen:
yield event
else:
# No function calls means we got a final response, reset counter
invocation_context.reset_tool_iteration_count()

async def _postprocess_live(
self,
Expand Down Expand Up @@ -649,6 +654,9 @@ async def _postprocess_live(

# Handles function calls.
if model_response_event.get_function_calls():
# Increment tool iteration counter to prevent infinite loops (Issue #4179)
invocation_context.increment_tool_iteration_count()

function_response_event = await functions.handle_function_calls_live(
invocation_context, model_response_event, llm_request.tools_dict
)
Expand All @@ -666,6 +674,9 @@ async def _postprocess_live(
)
)
yield final_event
else:
# No function calls means we got a final response, reset counter
invocation_context.reset_tool_iteration_count()

async def _postprocess_run_processors_async(
self, invocation_context: InvocationContext, llm_response: LlmResponse
Expand Down
113 changes: 113 additions & 0 deletions tests/unittests/flows/llm_flows/test_tool_iteration_limit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for tool iteration limit to prevent infinite loops (Issue #4179)."""

import pytest
from google.adk.agents.invocation_context import ToolIterationsLimitExceededError
from google.adk.agents.llm_agent import Agent
from google.adk.agents.run_config import RunConfig
from google.genai import types

from ... import testing_utils


@pytest.mark.asyncio
async def test_default_max_tool_iterations_value():
"""Test that the default max_tool_iterations is 50."""
run_config = RunConfig()
assert run_config.max_tool_iterations == 50


@pytest.mark.asyncio
async def test_increment_tool_iteration_count():
"""Test that tool iteration counter increments and enforces limit."""
agent = Agent(name='test_agent')
run_config = RunConfig(max_tool_iterations=3)

invocation_context = await testing_utils.create_invocation_context(
agent=agent, user_content='test', run_config=run_config
)

# Should not raise for first 3 increments
invocation_context.increment_tool_iteration_count() # 1
invocation_context.increment_tool_iteration_count() # 2
invocation_context.increment_tool_iteration_count() # 3

# 4th increment should raise ToolIterationsLimitExceededError
with pytest.raises(ToolIterationsLimitExceededError) as exc_info:
invocation_context.increment_tool_iteration_count() # 4 - exceeds limit

assert 'Max number of tool iterations limit of' in str(exc_info.value)
assert '3' in str(exc_info.value)


@pytest.mark.asyncio
async def test_reset_tool_iteration_count():
"""Test that tool iteration counter resets properly."""
agent = Agent(name='test_agent')
run_config = RunConfig(max_tool_iterations=2)

invocation_context = await testing_utils.create_invocation_context(
agent=agent, user_content='test', run_config=run_config
)

# First cycle: increment twice
invocation_context.increment_tool_iteration_count() # 1
invocation_context.increment_tool_iteration_count() # 2

# Reset the counter
invocation_context.reset_tool_iteration_count()

# Should not raise after reset - can increment again
invocation_context.increment_tool_iteration_count() # 1 (reset)
invocation_context.increment_tool_iteration_count() # 2 (reset)

# 3rd increment should raise
with pytest.raises(ToolIterationsLimitExceededError):
invocation_context.increment_tool_iteration_count() # 3 - exceeds limit


@pytest.mark.asyncio
async def test_max_tool_iterations_disabled():
"""Test that setting max_tool_iterations to 0 disables enforcement."""
agent = Agent(name='test_agent')
run_config = RunConfig(max_tool_iterations=0)

invocation_context = await testing_utils.create_invocation_context(
agent=agent, user_content='test', run_config=run_config
)

# Should not raise even after many increments when limit is disabled
for _ in range(100):
invocation_context.increment_tool_iteration_count()

# No exception raised - test passes


@pytest.mark.asyncio
async def test_max_tool_iterations_validator():
"""Test that RunConfig validator warns about disabled limit."""
import logging
import warnings

# Setting to 0 should trigger a warning
with warnings.catch_warnings(record=True):
warnings.simplefilter("always")
run_config = RunConfig(max_tool_iterations=0)
assert run_config.max_tool_iterations == 0

# Setting to positive value should not raise
run_config = RunConfig(max_tool_iterations=50)
assert run_config.max_tool_iterations == 50
Comment on lines +100 to +113
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The current implementation of test_max_tool_iterations_validator does not correctly test for the warning message. The validator uses logger.warning, which is not captured by warnings.catch_warnings by default. To properly test that the warning is logged, you should use the caplog fixture provided by pytest. This will ensure the test is robust and correctly verifies the validator's behavior.

Suggested change
async def test_max_tool_iterations_validator():
"""Test that RunConfig validator warns about disabled limit."""
import logging
import warnings
# Setting to 0 should trigger a warning
with warnings.catch_warnings(record=True):
warnings.simplefilter("always")
run_config = RunConfig(max_tool_iterations=0)
assert run_config.max_tool_iterations == 0
# Setting to positive value should not raise
run_config = RunConfig(max_tool_iterations=50)
assert run_config.max_tool_iterations == 50
async def test_max_tool_iterations_validator(caplog):
"""Test that RunConfig validator warns about disabled limit."""
import logging
# Setting to 0 should trigger a warning
with caplog.at_level(logging.WARNING):
run_config = RunConfig(max_tool_iterations=0)
assert run_config.max_tool_iterations == 0
assert 'max_tool_iterations is less than or equal to 0' in caplog.text
# Setting to positive value should not raise or log a warning
caplog.clear()
run_config = RunConfig(max_tool_iterations=50)
assert run_config.max_tool_iterations == 50
assert not caplog.text