Upgrade OnlineBuffer to DataBuffer and use it in the train script. #445

Status: Draft. Wants to merge 30 commits into main.

Commits (30):
2a03007
backup wip
alexander-soare Sep 16, 2024
1140a85
backup wip
alexander-soare Sep 16, 2024
745787d
backup wip
alexander-soare Sep 16, 2024
f455f1b
backup wip
alexander-soare Sep 17, 2024
e6864f0
backup wip
alexander-soare Sep 17, 2024
d6304e1
from_huggingface_hub base test passing
alexander-soare Sep 17, 2024
edae440
Tests passing with improved API
alexander-soare Sep 18, 2024
51609fe
All test_data_buffer passing
alexander-soare Sep 18, 2024
59b7e0e
ALL tests passing
alexander-soare Sep 18, 2024
c99691c
Merge remote-tracking branch 'upstream/main' into data_buffer
alexander-soare Sep 18, 2024
d4b40f8
backup wip
alexander-soare Sep 18, 2024
731634d
ALL tests passing
alexander-soare Sep 19, 2024
6944a27
backup wip
alexander-soare Sep 19, 2024
52da98d
backup wip
alexander-soare Sep 19, 2024
05cb6ff
backup wip
alexander-soare Sep 19, 2024
7da0988
backup wip
alexander-soare Sep 19, 2024
5bc1500
ready for review
alexander-soare Sep 19, 2024
63c2320
temporarily revert file names for diff
alexander-soare Sep 19, 2024
f0848cd
remove redundant kwarg
alexander-soare Sep 20, 2024
f7f2972
revision
alexander-soare Sep 23, 2024
2d05318
backup wip
alexander-soare Sep 24, 2024
c86c755
add png option
alexander-soare Sep 25, 2024
375c5f9
from_huggingface_hub adds episodes one at at time
alexander-soare Sep 25, 2024
e88ab2d
nicer list comp
alexander-soare Sep 25, 2024
d13104d
backup wip
alexander-soare Sep 26, 2024
7dd7cb3
backup wip
alexander-soare Sep 26, 2024
2e72b6a
backup wip
alexander-soare Sep 26, 2024
90ea3df
backup wip
alexander-soare Sep 27, 2024
fb6841d
tests passing
alexander-soare Sep 27, 2024
c312d26
backup wip
alexander-soare Sep 27, 2024
43 changes: 43 additions & 0 deletions Makefile
@@ -20,6 +20,8 @@ build-gpu:

test-end-to-end:
    ${MAKE} DEVICE=$(DEVICE) test-act-ete-train
+   ${MAKE} DEVICE=$(DEVICE) test-act-ete-train-data-buffer
+   ${MAKE} DEVICE=$(DEVICE) test-act-ete-train-data-buffer-decode-video
    ${MAKE} DEVICE=$(DEVICE) test-act-ete-eval
    ${MAKE} DEVICE=$(DEVICE) test-act-ete-train-amp
    ${MAKE} DEVICE=$(DEVICE) test-act-ete-eval-amp
@@ -50,6 +52,47 @@ test-act-ete-train:
        training.image_transforms.enable=true \
        hydra.run.dir=tests/outputs/act/

+test-act-ete-train-data-buffer:
+   python lerobot/scripts/train.py \
+       policy=act \
+       policy.dim_model=64 \
+       env=aloha \
+       wandb.enable=False \
+       training.offline_steps=2 \
+       training.online_steps=0 \
+       eval.n_episodes=1 \
+       eval.batch_size=1 \
+       device=$(DEVICE) \
+       training.save_checkpoint=true \
+       training.save_freq=2 \
+       policy.n_action_steps=20 \
+       policy.chunk_size=20 \
+       training.batch_size=2 \
+       training.image_transforms.enable=true \
+       hydra.run.dir=tests/outputs/act_buffer/ \
+       +use_lerobot_data_buffer=true \

+test-act-ete-train-data-buffer-decode-video:
+   python lerobot/scripts/train.py \
+       policy=act \
+       policy.dim_model=64 \
+       env=aloha \
+       wandb.enable=False \
+       training.offline_steps=2 \
+       training.online_steps=0 \
+       eval.n_episodes=1 \
+       eval.batch_size=1 \
+       device=$(DEVICE) \
+       training.save_checkpoint=true \
+       training.save_freq=2 \
+       policy.n_action_steps=20 \
+       policy.chunk_size=20 \
+       training.batch_size=2 \
+       training.image_transforms.enable=true \
+       hydra.run.dir=tests/outputs/act_buffer_decode_video/ \
+       +use_lerobot_data_buffer=true \
+       +lerobot_data_buffer_decode_video=true \

test-act-ete-eval:
    python lerobot/scripts/eval.py \
        -p tests/outputs/act/checkpoints/000002/pretrained_model \
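The two new targets exercise the DataBuffer code path via appended Hydra overrides (+use_lerobot_data_buffer=true and +lerobot_data_buffer_decode_video=true). The train.py side of this is not shown in this view; below is a minimal sketch of how appended overrides like these are typically read, assuming an OmegaConf config (the accessor pattern and the defaults are assumptions, not the PR's actual code).

from omegaconf import DictConfig


def resolve_data_buffer_flags(cfg: DictConfig) -> tuple[bool, bool]:
    # `+key=value` appends a key that is absent from the base config, so it
    # must be read with a default for runs that never pass the override.
    use_data_buffer = cfg.get("use_lerobot_data_buffer", False)
    decode_video = cfg.get("lerobot_data_buffer_decode_video", False)
    return use_data_buffer, decode_video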
1,094 changes: 910 additions & 184 deletions lerobot/common/datasets/online_buffer.py

Large diffs are not rendered by default.
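The heart of this PR, a 1,094-line rewrite of online_buffer.py, sits in this unrendered diff. For orientation only, a hypothetical sketch of the kind of interface the rest of the PR appears to exercise; every name and signature below is a guess inferred from the commit messages and Makefile flags, not the actual API.

import numpy as np


class DataBuffer:
    """Guess: a dataset usable by both offline and online training (the old OnlineBuffer was numpy-memmap-backed)."""

    @classmethod
    def from_huggingface_hub(cls, repo_id: str, decode_video: bool = False) -> "DataBuffer":
        # Per the commit messages, builds a buffer from a hub dataset,
        # adding episodes one at a time.
        raise NotImplementedError

    def add_episodes(self, data: dict[str, np.ndarray]) -> None:
        # Assumed entry point for appending complete episodes of numpy data.
        raise NotImplementedError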

62 changes: 54 additions & 8 deletions lerobot/common/datasets/video_utils.py
@@ -14,13 +14,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+import os
import subprocess
import warnings
from collections import OrderedDict
+from contextlib import contextmanager
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, ClassVar

+import einops
import pyarrow as pa
import torch
import torchvision
@@ -33,6 +36,7 @@ def load_from_videos(
    videos_dir: Path,
    tolerance_s: float,
    backend: str = "pyav",
+   to_pytorch_format: bool = True,
):
    """Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
    in the main process (e.g. by using a second DataLoader with num_workers=0); it will result in a segmentation fault.
@@ -51,14 +55,18 @@
            raise NotImplementedError("All video paths are expected to be the same for now.")
        video_path = data_dir / paths[0]

-       frames = decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
+       frames = decode_video_frames_torchvision(
+           video_path, timestamps, tolerance_s, backend, to_pytorch_format=to_pytorch_format
+       )
        item[key] = frames
    else:
        # load one frame
        timestamps = [item[key]["timestamp"]]
        video_path = data_dir / item[key]["path"]

-       frames = decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
+       frames = decode_video_frames_torchvision(
+           video_path, timestamps, tolerance_s, backend, to_pytorch_format=to_pytorch_format
+       )
        item[key] = frames[0]

    return item
@@ -70,6 +78,7 @@ def decode_video_frames_torchvision(
    tolerance_s: float,
    backend: str = "pyav",
    log_loaded_timestamps: bool = False,
+   to_pytorch_format: bool = True,
) -> torch.Tensor:
    """Loads frames associated with the requested timestamps of a video

@@ -129,8 +138,8 @@

    reader = None

-   query_ts = torch.tensor(timestamps)
-   loaded_ts = torch.tensor(loaded_ts)
+   query_ts = torch.tensor(timestamps, dtype=torch.float32)
+   loaded_ts = torch.tensor(loaded_ts, dtype=torch.float32)

    # compute distances between each query timestamp and timestamps of all loaded frames
    dist = torch.cdist(query_ts[:, None], loaded_ts[:, None], p=1)
@@ -155,13 +164,47 @@
    if log_loaded_timestamps:
        logging.info(f"{closest_ts=}")

-   # convert to the pytorch format which is float32 in [0,1] range (and channel first)
-   closest_frames = closest_frames.type(torch.float32) / 255
+   # Note that at this point the images are torch.uint8, in [0, 255], channel-first.
+   if to_pytorch_format:
+       # Return in pytorch format: float32, normalized to [0, 1], channel-first.
+       closest_frames = closest_frames.type(torch.float32) / 255
+   else:
+       # Return in numpy format: np.uint8, in [0, 255], channel-last.
+       closest_frames = einops.rearrange(closest_frames.numpy(), "... c h w -> ... h w c")

    assert len(timestamps) == len(closest_frames)
    return closest_frames


+@contextmanager
+def _ffmpeg_log_level_to_svt(ffmpeg_log_level: str | None):
+    """Context manager that propagates an ffmpeg log level to the corresponding SVT log level.
+
+    SVT-AV1 reads its verbosity from the SVT_LOG environment variable, so we set it for the
+    duration of the block and restore the prior value afterwards.
+    """
+    if ffmpeg_log_level is not None:
+        prior_svt_log_value = os.environ.get("SVT_LOG")
+        # Mapping from ffmpeg log levels to the corresponding SVT log levels.
+        svt_log_levels = {
+            "quiet": 0,  # SVT_LOG_FATAL
+            "panic": 0,  # SVT_LOG_FATAL
+            "fatal": 0,  # SVT_LOG_FATAL
+            "error": 1,  # SVT_LOG_ERROR
+            "warning": 2,  # SVT_LOG_WARN
+            "info": 3,  # SVT_LOG_INFO
+            "verbose": 3,  # SVT_LOG_INFO
+            "debug": 4,  # SVT_LOG_DEBUG
+            "trace": -1,  # SVT_LOG_ALL
+        }
+        os.environ["SVT_LOG"] = str(svt_log_levels[ffmpeg_log_level])
+        try:
+            yield
+        finally:
+            # Restore the environment even if the wrapped command raises.
+            if prior_svt_log_value is not None:
+                os.environ["SVT_LOG"] = prior_svt_log_value
+            else:
+                os.environ.pop("SVT_LOG")
+    else:
+        yield


def encode_video_frames(
    imgs_dir: Path,
    video_path: Path,
@@ -175,6 +218,7 @@ def encode_video_frames(
    overwrite: bool = False,
) -> None:
    """More info on tuning the ffmpeg arguments in `benchmark/video/README.md`."""

video_path = Path(video_path)
video_path.parent.mkdir(parents=True, exist_ok=True)

@@ -207,8 +251,10 @@
        ffmpeg_args.append("-y")

    ffmpeg_cmd = ["ffmpeg"] + ffmpeg_args + [str(video_path)]
-   # redirect stdin to subprocess.DEVNULL to prevent reading random keyboard inputs from terminal
-   subprocess.run(ffmpeg_cmd, check=True, stdin=subprocess.DEVNULL)

+   with _ffmpeg_log_level_to_svt(log_level):
+       # redirect stdin to subprocess.DEVNULL to prevent reading random keyboard inputs from terminal
+       subprocess.run(ffmpeg_cmd, check=True, stdin=subprocess.DEVNULL)

    if not video_path.exists():
        raise OSError(
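In short, decode_video_frames_torchvision can now return numpy-format frames (np.uint8, channel-last) for storage in a buffer, and encode_video_frames silences SVT-AV1 in line with the requested ffmpeg verbosity. A usage sketch under stated assumptions: the paths and timestamps are placeholders, and fps and log_level are presumed keyword arguments of encode_video_frames (only log_level is visible in this diff).

from pathlib import Path

from lerobot.common.datasets.video_utils import (
    decode_video_frames_torchvision,
    encode_video_frames,
)

video_path = Path("episode_000000.mp4")  # placeholder

# Default pytorch format: float32 in [0, 1], channel-first.
frames_pt = decode_video_frames_torchvision(video_path, [0.0, 0.1], tolerance_s=1e-4)

# Numpy format: np.uint8 in [0, 255], channel-last, e.g. for a DataBuffer.
frames_np = decode_video_frames_torchvision(
    video_path, [0.0, 0.1], tolerance_s=1e-4, to_pytorch_format=False
)

# With log_level="error", _ffmpeg_log_level_to_svt sets SVT_LOG=1 around the ffmpeg
# call and restores the prior environment afterwards. fps and log_level are assumed
# to sit in the elided part of the signature.
encode_video_frames(Path("imgs_dir"), video_path, fps=30, log_level="error", overwrite=True)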
38 changes: 24 additions & 14 deletions lerobot/common/envs/utils.py
@@ -19,12 +19,16 @@
from torch import Tensor


-def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
+def preprocess_observation(
+    observations: dict[str, np.ndarray], to_pytorch_format: bool = True
+) -> dict[str, Tensor | np.ndarray]:
    """Convert environment observation to LeRobot format observation.
    Args:
-       observation: Dictionary of observation batches from a Gym vector environment.
+       observations: Dictionary of observation batches from a Gym vector environment.
+       to_pytorch_format: Whether to return tensors instead of numpy arrays. For image observations, this
+           also implies switching to channel-first, float32, normalized to the range [0, 1].
    Returns:
-       Dictionary of observation batches with keys renamed to LeRobot format and values as tensors.
+       Dictionary of observation batches in LeRobot format, as tensors or numpy arrays depending on
+           `to_pytorch_format`.
    """
    # map to expected inputs for the policy
    return_observations = {}
@@ -35,28 +39,34 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
        imgs = {"observation.image": observations["pixels"]}

    for imgkey, img in imgs.items():
-       img = torch.from_numpy(img)
-
        # sanity check that images are channel last
        _, h, w, c = img.shape
        assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"

        # sanity check that images are uint8
-       assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
+       assert img.dtype == np.uint8, f"expected np.uint8, but instead got {img.dtype=}"

-       # convert to channel first of type float32 in range [0,1]
-       img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
-       img = img.type(torch.float32)
-       img /= 255
+       if to_pytorch_format:
+           img = torch.from_numpy(img)
+           # convert to channel first of type float32 in range [0,1]
+           img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
+           img = img.type(torch.float32)
+           img /= 255

        return_observations[imgkey] = img

    if "environment_state" in observations:
-       return_observations["observation.environment_state"] = torch.from_numpy(
-           observations["environment_state"]
-       ).float()
+       return_observations["observation.environment_state"] = observations["environment_state"].astype(
+           np.float32
+       )
+       if to_pytorch_format:
+           return_observations["observation.environment_state"] = torch.from_numpy(
+               return_observations["observation.environment_state"]
+           )

    # TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing
    # requirement for "agent_pos"
-   return_observations["observation.state"] = torch.from_numpy(observations["agent_pos"]).float()
+   return_observations["observation.state"] = observations["agent_pos"].astype(np.float32)
+   if to_pytorch_format:
+       return_observations["observation.state"] = torch.from_numpy(return_observations["observation.state"])
    return return_observations
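A quick sketch of the two output modes of preprocess_observation (the observation shapes here are illustrative placeholders).

import numpy as np

from lerobot.common.envs.utils import preprocess_observation

# A fake batched observation, as returned by a Gym vector environment.
obs = {
    "pixels": np.zeros((2, 96, 96, 3), dtype=np.uint8),
    "agent_pos": np.zeros((2, 14), dtype=np.float64),
}

# Default pytorch format: images become float32, channel-first, in [0, 1].
out_pt = preprocess_observation(obs)
assert tuple(out_pt["observation.image"].shape) == (2, 3, 96, 96)  # torch.Tensor

# Numpy format: images stay np.uint8, channel-last; states are cast to float32.
out_np = preprocess_observation(obs, to_pytorch_format=False)
assert out_np["observation.image"].shape == (2, 96, 96, 3)  # np.ndarray
assert out_np["observation.state"].dtype == np.float32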
61 changes: 34 additions & 27 deletions lerobot/scripts/eval.py
@@ -47,7 +47,6 @@
import threading
import time
from contextlib import nullcontext
-from copy import deepcopy
from datetime import datetime as dt
from pathlib import Path
from typing import Callable
@@ -145,9 +144,9 @@ def rollout(
    )
    while not np.all(done):
        # Numpy array to tensor and changing dictionary keys to LeRobot policy format.
-       observation = preprocess_observation(observation)
        if return_observations:
-           all_observations.append(deepcopy(observation))
+           all_observations.append(preprocess_observation(observation, to_pytorch_format=False))
+       observation = preprocess_observation(observation)

        observation = {key: observation[key].to(device, non_blocking=True) for key in observation}

@@ -173,34 +172,32 @@
        # Keep track of which environments are done so far.
        done = terminated | truncated | done

-       all_actions.append(torch.from_numpy(action))
-       all_rewards.append(torch.from_numpy(reward))
-       all_dones.append(torch.from_numpy(done))
-       all_successes.append(torch.tensor(successes))
+       all_actions.append(action)
+       all_rewards.append(reward)
+       all_dones.append(done)
+       all_successes.append(np.array(successes))

        step += 1
-       running_success_rate = (
-           einops.reduce(torch.stack(all_successes, dim=1), "b n -> b", "any").numpy().mean()
-       )
+       running_success_rate = einops.reduce(np.stack(all_successes, axis=1), "b n -> b", "any").mean()
        progbar.set_postfix({"running_success_rate": f"{running_success_rate.item() * 100:.1f}%"})
        progbar.update()

    # Track the final observation.
    if return_observations:
-       observation = preprocess_observation(observation)
-       all_observations.append(deepcopy(observation))
+       all_observations.append(preprocess_observation(observation, to_pytorch_format=False))

    # Stack the sequence along the first dimension so that we have (batch, sequence, *) tensors.
+   # All arrays go to float32 for more efficient storage in a DataBuffer.
    ret = {
-       "action": torch.stack(all_actions, dim=1),
-       "reward": torch.stack(all_rewards, dim=1),
-       "success": torch.stack(all_successes, dim=1),
-       "done": torch.stack(all_dones, dim=1),
+       "action": np.stack(all_actions, axis=1).astype(np.float32),
+       "reward": np.stack(all_rewards, axis=1).astype(np.float32),
+       "success": np.stack(all_successes, axis=1).astype(bool),
+       "done": np.stack(all_dones, axis=1).astype(bool),
    }
    if return_observations:
        stacked_observations = {}
        for key in all_observations[0]:
-           stacked_observations[key] = torch.stack([obs[key] for obs in all_observations], dim=1)
+           stacked_observations[key] = np.stack([obs[key] for obs in all_observations], axis=1)
        ret["observation"] = stacked_observations

    return ret
@@ -292,11 +289,11 @@ def render_frame(env: gym.vector.VectorEnv):
        # this won't be included).
        n_steps = rollout_data["done"].shape[1]
        # Note: this relies on a property of argmax: that it returns the first occurrence as a tiebreaker.
-       done_indices = torch.argmax(rollout_data["done"].to(int), dim=1)
+       done_indices = np.argmax(rollout_data["done"], axis=1)

        # Make a mask with shape (batch, n_steps) to mask out rollout data after the first done
        # (batch-element-wise). Note the `done_indices + 1` to make sure to keep the data from the done step.
-       mask = (torch.arange(n_steps) <= einops.repeat(done_indices + 1, "b -> b s", s=n_steps)).int()
+       mask = np.arange(n_steps) <= einops.repeat(done_indices + 1, "b -> b s", s=n_steps)
        # Extend metrics.
        batch_sum_rewards = einops.reduce((rollout_data["reward"] * mask), "b n -> b", "sum")
        sum_rewards.extend(batch_sum_rewards.tolist())
@@ -416,17 +413,27 @@ def _compile_episode_data(
        # Here we do `num_frames - 1` as we don't want to include the last observation frame just yet.
        ep_dict = {
            "action": rollout_data["action"][ep_ix, : num_frames - 1],
-           "episode_index": torch.tensor([start_episode_index + ep_ix] * (num_frames - 1)),
-           "frame_index": torch.arange(0, num_frames - 1, 1),
-           "timestamp": torch.arange(0, num_frames - 1, 1) / fps,
+           "episode_index": np.array([start_episode_index + ep_ix] * (num_frames - 1)),
+           "frame_index": np.arange(0, num_frames - 1, 1),
+           "timestamp": np.arange(0, num_frames - 1, 1) / fps,
            "next.done": rollout_data["done"][ep_ix, : num_frames - 1],
            "next.success": rollout_data["success"][ep_ix, : num_frames - 1],
-           "next.reward": rollout_data["reward"][ep_ix, : num_frames - 1].type(torch.float32),
+           "next.reward": rollout_data["reward"][ep_ix, : num_frames - 1],
        }

-       # For the last observation frame, all other keys will just be copy padded.
+       # For the last observation frame, all other keys will be padded.
        for k in ep_dict:
-           ep_dict[k] = torch.cat([ep_dict[k], ep_dict[k][-1:]])
+           if k not in ["timestamp", "frame_index"]:
+               # Copy-pad.
+               ep_dict[k] = np.concatenate([ep_dict[k], ep_dict[k][-1:]])
+           elif k == "timestamp":
+               # Pad with the last timestamp + 1 / fps.
+               ep_dict[k] = np.append(ep_dict[k], ep_dict[k][-1] + 1 / fps)
+           elif k == "frame_index":
+               # Pad with the next frame index.
+               ep_dict[k] = np.append(ep_dict[k], ep_dict[k][-1] + 1)
+           else:
+               raise AssertionError

        for key in rollout_data["observation"]:
            ep_dict[key] = rollout_data["observation"][key][ep_ix, :num_frames]
@@ -435,9 +442,9 @@

    data_dict = {}
    for key in ep_dicts[0]:
-       data_dict[key] = torch.cat([x[key] for x in ep_dicts])
+       data_dict[key] = np.concatenate([x[key] for x in ep_dicts])

-   data_dict["index"] = torch.arange(start_data_index, start_data_index + total_frames, 1)
+   data_dict["index"] = np.arange(start_data_index, start_data_index + total_frames, 1)

    return data_dict

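With rollout data now numpy end to end, the downstream masking in eval_policy works unchanged; a small worked example of that computation (the values are invented for illustration).

import einops
import numpy as np

# Two rollouts of 5 steps; episode 0 first finishes at step 2, episode 1 at step 4.
done = np.array([
    [False, False, True, False, False],
    [False, False, False, False, True],
])
n_steps = done.shape[1]

# argmax returns the first True as a tiebreaker.
done_indices = np.argmax(done, axis=1)  # array([2, 4])

# The `+ 1` makes sure the data from the done step is kept (see the comment in eval.py).
mask = np.arange(n_steps) <= einops.repeat(done_indices + 1, "b -> b s", s=n_steps)
# mask[0]: [True, True, True, True, False]
# mask[1]: [True, True, True, True, True]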