Skip to content

Conversation

@aryanrahar
Copy link

This PR fixes an off-by-one bug in pi0_fast.py when computing positions for the next decoded token. Prefix token positions are 0-indexed (0..L-1), so the first new token after a prefix of length L must be placed at position L (not L+1).
Fix: drop the + 1 when building positions.

Fixes #705.
Changes-
src/openpi/models/pi0_fast.py: 1-line change to compute contiguous, zero-indexed positions.
src/openpi/models/pi0_fast_positions_test.py: small regression test asserting:
next token after length L is at L
advancing step increments positions by 1

Scope / Compatibility
Affects JAX pi0-FAST decode path only.
No API changes. No impact on the PyTorch path (pi0-FAST not supported there).
No performance impact.

How I tested
uv run pytest -q src/openpi/models/pi0_fast_positions_test.py passes.
Verified the minimal repro before/after values locally.
Ran ruff + ruff-format hooks.

@jimmyt857 jimmyt857 removed their request for review October 3, 2025 01:10
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

potential bug in setting positions value for new token during decoding

1 participant