32 changes: 30 additions & 2 deletions src/lerobot/policies/pi05/modeling_pi05.py
@@ -17,6 +17,7 @@
 import builtins
 import logging
 import math
+import threading
 from collections import deque
 from pathlib import Path
 from typing import TYPE_CHECKING, Literal
@@ -1028,9 +1029,36 @@ def _fix_pytorch_state_dict_keys(
     def get_optim_params(self) -> dict:
         return self.parameters()
 
+    def _new_action_queue(self) -> deque:
+        """Create a fresh action queue honoring n_action_steps."""
+        return deque(maxlen=self.config.n_action_steps)
+
+    def _get_thread_action_queue(self) -> deque:
+        """Return the action queue scoped to the current thread."""
+        if not hasattr(self, "_thread_local"):
+            self._thread_local = threading.local()
+        action_queue = getattr(self._thread_local, "action_queue", None)
+        if action_queue is None:
+            action_queue = self._new_action_queue()
+            self._thread_local.action_queue = action_queue
+        return action_queue
+
+    @property
+    def _action_queue(self) -> deque:
+        """Expose the thread-local action queue (backwards compatible attribute)."""
+        return self._get_thread_action_queue()
+
+    @_action_queue.setter
+    def _action_queue(self, queue: deque) -> None:
+        if not hasattr(self, "_thread_local"):
+            self._thread_local = threading.local()
+
+        self._thread_local.action_queue = queue

Comment on lines +1032 to +1057
Copilot AI Nov 14, 2025

[nitpick] Consider adding a module-level or class-level docstring explaining the thread-safety approach. This would help future maintainers understand why _action_queue is implemented as a property with thread-local storage. For example:

"""
Thread-Safety:
    This policy uses thread-local storage for the action queue to support
    multi-threaded evaluation scenarios (e.g., env.max_parallel_tasks > 1).
    Each thread maintains its own action queue to prevent cross-contamination
    of actions between parallel environments.
"""
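
The isolation this docstring describes can be illustrated with a minimal sketch. `ToyPolicy` and `run_isolation_demo` are hypothetical names that stand in for the real policy and model only the queue logic; they are not part of the PR.

```python
import threading
from collections import deque


class ToyPolicy:
    """Hypothetical stand-in for the policy; only the queue logic is modeled."""

    def __init__(self, n_action_steps: int = 3) -> None:
        self.n_action_steps = n_action_steps
        self._thread_local = threading.local()

    @property
    def _action_queue(self) -> deque:
        # Lazily create one queue per thread on first access.
        queue = getattr(self._thread_local, "action_queue", None)
        if queue is None:
            queue = deque(maxlen=self.n_action_steps)
            self._thread_local.action_queue = queue
        return queue


def run_isolation_demo() -> dict[str, list[str]]:
    policy = ToyPolicy()
    seen: dict[str, list[str]] = {}

    def worker(name: str) -> None:
        # Each worker only ever touches the queue bound to its own thread.
        for step in range(2):
            policy._action_queue.append(f"{name}-{step}")
        seen[name] = list(policy._action_queue)

    threads = [threading.Thread(target=worker, args=(n,)) for n in ("env0", "env1")]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    # The main thread never appended, so its own queue is still empty.
    seen["main"] = list(policy._action_queue)
    return seen
```

Each worker sees only the actions it appended, and the main thread's queue stays empty even though all three share one policy object.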

Copilot uses AI. Check for mistakes.
     def reset(self):
         """Reset internal state - called when environment resets."""
-        self._action_queue = deque(maxlen=self.config.n_action_steps)
+        self._thread_local = threading.local()
Copilot AI Nov 14, 2025

The reset() method unconditionally creates a new threading.local() object, which defeats the purpose of thread-local storage. This will clear the action queue for ALL threads, not just the current thread.

Instead, only the current thread's action queue should be reset. Consider this approach:

def reset(self):
    """Reset internal state - called when environment resets."""
    if not hasattr(self, "_thread_local"):
        self._thread_local = threading.local()
    self._action_queue = self._new_action_queue()
    self._queues = {
        ACTION: deque(maxlen=self.config.n_action_steps),
    }

This ensures that _thread_local is initialized only once per policy instance, while still resetting the current thread's action queue.
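
The difference between the two reset strategies can be checked with a small sketch. `ToyPolicy`, `reset_all_threads`, `reset_current_thread`, and `worker_queue_after` are hypothetical names used only for illustration; the first method mirrors the behavior in the diff and the second mirrors the suggestion.

```python
import threading
from collections import deque


class ToyPolicy:
    """Hypothetical stand-in that models only the thread-local queue."""

    def __init__(self) -> None:
        self._thread_local = threading.local()

    def queue(self) -> deque:
        q = getattr(self._thread_local, "action_queue", None)
        if q is None:
            q = deque(maxlen=4)
            self._thread_local.action_queue = q
        return q

    def reset_all_threads(self) -> None:
        # Behavior in the diff: a new local() discards every thread's queue.
        self._thread_local = threading.local()

    def reset_current_thread(self) -> None:
        # Suggested behavior: only the calling thread's queue is replaced.
        self._thread_local.action_queue = deque(maxlen=4)


def worker_queue_after(reset_name: str) -> list[int]:
    """Fill a worker's queue, reset from the main thread, return what the worker sees."""
    policy = ToyPolicy()
    filled = threading.Event()
    reset_done = threading.Event()
    result: list[int] = []

    def worker() -> None:
        policy.queue().extend([1, 2, 3])
        filled.set()
        reset_done.wait()
        result.extend(policy.queue())

    t = threading.Thread(target=worker)
    t.start()
    filled.wait()
    getattr(policy, reset_name)()  # the reset is issued from the main thread
    reset_done.set()
    t.join()
    return result
```

With `reset_all_threads`, the worker's pending actions are lost after a reset issued from another thread; with `reset_current_thread`, they survive.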

Suggested change
-        self._thread_local = threading.local()
+        if not hasattr(self, "_thread_local"):
+            self._thread_local = threading.local()

+        self._action_queue = self._new_action_queue()
         self._queues = {
             ACTION: deque(maxlen=self.config.n_action_steps),
         }
Comment on lines 1062 to 1064
Copilot AI Nov 14, 2025

The _queues attribute is not thread-safe. Similar to _action_queue, this dictionary is shared across all threads and could cause the same cross-contamination issues when env.max_parallel_tasks > 1.

Consider applying the same thread-local pattern to _queues or clarifying its intended usage. If it's not actively used elsewhere in the codebase, it might be safe to leave it as-is, but if it's used in multi-threaded contexts, it should also be made thread-local.
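
One way the same thread-local pattern might be applied to `_queues` is sketched below. This is an assumption-laden illustration, not code from the PR: `ToyPolicy`, `_new_queues`, and `queues_are_per_thread` are hypothetical, and `ACTION` is redefined locally with an assumed value so the block runs on its own.

```python
import threading
from collections import deque

# Assumed value; defined here only to keep the sketch self-contained.
ACTION = "action"


class ToyPolicy:
    """Hypothetical stand-in that keeps the whole _queues dict per thread."""

    def __init__(self, n_action_steps: int = 2) -> None:
        self.n_action_steps = n_action_steps
        self._thread_local = threading.local()

    def _new_queues(self) -> dict[str, deque]:
        return {ACTION: deque(maxlen=self.n_action_steps)}

    @property
    def _queues(self) -> dict[str, deque]:
        # Lazily build a separate dict of queues for each thread.
        queues = getattr(self._thread_local, "queues", None)
        if queues is None:
            queues = self._new_queues()
            self._thread_local.queues = queues
        return queues

    @_queues.setter
    def _queues(self, queues: dict[str, deque]) -> None:
        self._thread_local.queues = queues

    def reset(self) -> None:
        # Resets only the calling thread's queues.
        self._queues = self._new_queues()


def queues_are_per_thread() -> bool:
    policy = ToyPolicy()
    policy._queues[ACTION].append("main-action")
    seen_by_other: list[str] = []

    def worker() -> None:
        seen_by_other.extend(policy._queues[ACTION])

    t = threading.Thread(target=worker)
    t.start()
    t.join()
    return seen_by_other == [] and list(policy._queues[ACTION]) == ["main-action"]
```

Keeping the property name `_queues` means existing call sites would not need to change, which is the same trade-off the PR makes for `_action_queue`.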

@@ -1160,4 +1188,4 @@ def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
             "loss_per_dim": losses.mean(dim=[0, 1]).detach().cpu().numpy().tolist(),
         }
 
-        return loss, loss_dict
+        return loss, loss_dict