Naokiyokoyama/reality experiments #3

Merged 5 commits on Aug 1, 2023
9 changes: 5 additions & 4 deletions pyproject.toml
@@ -15,12 +15,13 @@ readme = "README.md"
requires-python = ">=3.9"
dependencies = [
    "torch >= 1.13.1",
-    "habitat-sim @ git+https://github.com/facebookresearch/habitat-sim.git",
-    "habitat-baselines >= 0.2.4",
-    "habitat-lab",
+    # "habitat-sim @ git+https://github.com/facebookresearch/habitat-sim.git",
+    # "habitat-baselines >= 0.2.4",
+    # "habitat-lab",
    "frontier_exploration @ git+https://github.com/naokiyokoyama/frontier_exploration.git",
    "transformers == 4.28.0", # higher versions break BLIP-2
-    "flask >= 2.3.2"
+    "flask >= 2.3.2",
+    "gym >= 0.26.2"
]

[project.optional-dependencies]
3 changes: 2 additions & 1 deletion scripts/launch_vlm_servers.sh
@@ -17,7 +17,8 @@ tmux send-keys -t vlm_servers:0.0 "${OS_PYTHON} -m zsos.vlm.grounding_dino ; sle

# Split the second pane horizontally and run ${OS_PYTHON} -m zsos.vlm.blip2 in the new pane
tmux split-window -h -t vlm_servers:0
-tmux send-keys -t vlm_servers:0.1 "${OS_PYTHON} -m zsos.vlm.blip2 ; sleep 30" C-m
+#tmux send-keys -t vlm_servers:0.1 "${OS_PYTHON} -m zsos.vlm.blip2 ; sleep 30" C-m
+tmux send-keys -t vlm_servers:0.1 "${OS_PYTHON} -m zsos.vlm.blip2itm ; sleep 30" C-m

# Select the third pane and run ${OS_PYTHON} -m zsos.vlm.fiber
tmux select-pane -t vlm_servers:0.2
2 changes: 1 addition & 1 deletion test/test_setup.py
@@ -3,7 +3,7 @@
import torch
from habitat_baselines.common.baseline_registry import baseline_registry # noqa

-from zsos import get_config
+from zsos.run import get_config


def test_load_and_save_config():
4 changes: 0 additions & 4 deletions zsos/__init__.py
@@ -1,4 +0,0 @@
import frontier_exploration
from habitat import get_config

import zsos.obs_transformers.resize
76 changes: 76 additions & 0 deletions zsos/mapping/frontier_map.py
@@ -0,0 +1,76 @@
from typing import List, Tuple

import numpy as np

from zsos.vlm.blip2itm import BLIP2ITMClient


class Frontier:
    def __init__(self, xyz: np.ndarray, cosine: float):
        self.xyz = xyz
        self.cosine = cosine


class FrontierMap:
    frontiers: List[Frontier] = []

    def __init__(self, encoding_type: str = "cosine"):
        self.encoder: BLIP2ITMClient = BLIP2ITMClient()

    def reset(self):
        self.frontiers = []

    def update(
        self, frontier_locations: List[np.ndarray], curr_image: np.ndarray, text: str
    ):
"""
Takes in a list of frontier coordinates and the current image observation from
the robot. Any stored frontiers that are not present in the given list are
removed. Any frontiers in the given list that are not already stored are added.
When these frontiers are added, their cosine field is set to the encoding
of the given image. The image will only be encoded if a new frontier is added.

Args:
frontier_locations (List[np.ndarray]): A list of frontier coordinates.
curr_image (np.ndarray): The current image observation from the robot.
text (str): The text to compare the image to.
"""
        # Remove any frontiers that are not in the given list. Use np.array_equal.
        self.frontiers = [
            frontier
            for frontier in self.frontiers
            if any(
                np.array_equal(frontier.xyz, location)
                for location in frontier_locations
            )
        ]

        # Add any frontiers that are not already stored. Set their cosine field to the
        # encoding of the given image.
        cosine = None
        for location in frontier_locations:
            if not any(
                np.array_equal(frontier.xyz, location) for frontier in self.frontiers
            ):
                if cosine is None:
                    cosine = self._encode(curr_image, text)
                self.frontiers.append(Frontier(location, cosine))

    def _encode(self, image: np.ndarray, text: str) -> float:
        """
        Encodes the given image using the encoding type specified in the constructor.

        Args:
            image (np.ndarray): The image to encode.
            text (str): The text to compare the image to.

        Returns:
            float: The cosine similarity between the image and the text.
        """
        return self.encoder.cosine(image, text)

    def get_best_frontier(self) -> Tuple[np.ndarray, float]:
        """
        Returns the frontier with the highest cosine and the value of that cosine.
        """
        best_frontier = max(self.frontiers, key=lambda frontier: frontier.cosine)
        return best_frontier.xyz, best_frontier.cosine
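
A minimal usage sketch of the new FrontierMap (illustration, not part of the diff). It assumes the BLIP-2 ITM server launched by scripts/launch_vlm_servers.sh is already running; the frontier coordinates and image below are made-up placeholders.

import numpy as np

from zsos.mapping.frontier_map import FrontierMap

frontier_map = FrontierMap()
frontier_map.reset()

# Placeholder frontier coordinates and a blank frame standing in for a real RGB observation.
frontiers = [np.array([1.0, 2.0, 0.0]), np.array([3.0, -1.0, 0.0])]
rgb = np.zeros((480, 640, 3), dtype=np.uint8)

# Only newly added frontiers trigger an encoding call; existing ones keep their cosine score.
frontier_map.update(frontiers, rgb, "Seems like there is a couch ahead.")
goal, cosine = frontier_map.get_best_frontier()
print(goal, cosine)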
47 changes: 2 additions & 45 deletions zsos/obs_transformers/resize.py
@@ -16,7 +16,8 @@
)
from hydra.core.config_store import ConfigStore
from omegaconf import DictConfig
-from torch import Tensor
+
+from zsos.obs_transformers.utils import image_resize


@baseline_registry.register_obs_transformer()
@@ -89,50 +90,6 @@ def from_config(cls, config: "DictConfig"):
)


-def image_resize(
-    img: Tensor,
-    size: Tuple[int, int],
-    channels_last: bool = False,
-    interpolation_mode="area",
-) -> torch.Tensor:
-    """Resizes an img.
-
-    Args:
-        img: the array object that needs to be resized (HWC) or (NHWC)
-        size: the size that you want
-        channels: a boolean that channel is the last dimension
-    Returns:
-        The resized array as a torch tensor.
-    """
-    img = torch.as_tensor(img)
-    no_batch_dim = len(img.shape) == 3
-    if len(img.shape) < 3 or len(img.shape) > 5:
-        raise NotImplementedError()
-    if no_batch_dim:
-        img = img.unsqueeze(0)  # Adds a batch dimension
-    if channels_last:
-        if len(img.shape) == 4:
-            # NHWC -> NCHW
-            img = img.permute(0, 3, 1, 2)
-        else:
-            # NDHWC -> NDCHW
-            img = img.permute(0, 1, 4, 2, 3)
-
-    img = torch.nn.functional.interpolate(
-        img.float(), size=size, mode=interpolation_mode
-    ).to(dtype=img.dtype)
-    if channels_last:
-        if len(img.shape) == 4:
-            # NCHW -> NHWC
-            img = img.permute(0, 2, 3, 1)
-        else:
-            # NDCHW -> NDHWC
-            img = img.permute(0, 1, 3, 4, 2)
-    if no_batch_dim:
-        img = img.squeeze(dim=0)  # Removes the batch dimension
-    return img
-
-
@dataclass
class ResizeConfig(ObsTransformConfig):
    type: str = Resize.__name__
48 changes: 48 additions & 0 deletions zsos/obs_transformers/utils.py
@@ -0,0 +1,48 @@
from typing import Tuple

import torch
from torch import Tensor


def image_resize(
    img: Tensor,
    size: Tuple[int, int],
    channels_last: bool = False,
    interpolation_mode="area",
) -> torch.Tensor:
    """Resizes an img.

    Args:
        img: the array object that needs to be resized (HWC) or (NHWC)
        size: the size that you want
        channels_last: a boolean indicating whether the channel is the last dimension
    Returns:
        The resized array as a torch tensor.
    """
    img = torch.as_tensor(img)
    no_batch_dim = len(img.shape) == 3
    if len(img.shape) < 3 or len(img.shape) > 5:
        raise NotImplementedError()
    if no_batch_dim:
        img = img.unsqueeze(0)  # Adds a batch dimension
    if channels_last:
        if len(img.shape) == 4:
            # NHWC -> NCHW
            img = img.permute(0, 3, 1, 2)
        else:
            # NDHWC -> NDCHW
            img = img.permute(0, 1, 4, 2, 3)

    img = torch.nn.functional.interpolate(
        img.float(), size=size, mode=interpolation_mode
    ).to(dtype=img.dtype)
    if channels_last:
        if len(img.shape) == 4:
            # NCHW -> NHWC
            img = img.permute(0, 2, 3, 1)
        else:
            # NDCHW -> NDHWC
            img = img.permute(0, 1, 3, 4, 2)
    if no_batch_dim:
        img = img.squeeze(dim=0)  # Removes the batch dimension
    return img
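
A quick sanity check of the relocated helper (illustration only, not part of the diff); the input shape and target size below are arbitrary examples.

import torch

from zsos.obs_transformers.utils import image_resize

# Arbitrary example: a 480x640 RGB frame in channels-last (HWC) layout.
img = torch.zeros(480, 640, 3, dtype=torch.uint8)
resized = image_resize(img, size=(224, 224), channels_last=True)
print(resized.shape)  # torch.Size([224, 224, 3]), same dtype as the input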
7 changes: 6 additions & 1 deletion zsos/policy/base_policy.py
@@ -9,6 +9,8 @@

@baseline_registry.register_policy
class BasePolicy(Policy):
+    """The bare minimum needed to load a policy for evaluation using ppo_trainer.py"""
+
    def __init__(self, *args, **kwargs):
        super().__init__()

@@ -34,6 +36,7 @@ def act(
        masks,
        deterministic=False,
    ):
+        # Just moves forwards
        num_envs = observations["rgb"].shape[0]
        action = torch.ones(num_envs, 1, dtype=torch.long)
        return PolicyActionData(actions=action, rnn_hidden_states=rnn_hidden_states)
@@ -51,7 +54,9 @@ def parameters(self):


if __name__ == "__main__":
-    # Save a dummy state_dict using torch.save
+    # Save a dummy state_dict using torch.save. This is useful for generating a pth file
+    # that can be used to load other policies that don't even read from checkpoints,
+    # even though habitat requires a checkpoint to be loaded.
    config = get_config(
        "habitat-lab/habitat-baselines/habitat_baselines/config/pointnav/ppo_pointnav_example.yaml"
    )
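
The rest of the __main__ block is truncated above; a rough sketch of what saving such a dummy checkpoint could look like (the get_config import path, dictionary layout, and output filename are assumptions, not the PR's actual code):

import torch
from habitat_baselines.config.default import get_config  # assumed import path

config = get_config(
    "habitat-lab/habitat-baselines/habitat_baselines/config/pointnav/ppo_pointnav_example.yaml"
)
# A checkpoint with an empty state_dict is enough to satisfy habitat's checkpoint-loading
# machinery for policies that never actually restore weights from it.
torch.save({"state_dict": {}, "config": config}, "dummy_policy.pth")  # assumed filename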
46 changes: 46 additions & 0 deletions zsos/policy/itm_policy.py
@@ -0,0 +1,46 @@
import os
from typing import Tuple

from habitat_baselines.common.baseline_registry import baseline_registry
from habitat_baselines.common.tensor_dict import TensorDict
from torch import Tensor

from zsos.llm.llm import BaseLLM
from zsos.mapping.frontier_map import FrontierMap
from zsos.policy.semantic_policy import SemanticPolicy
from zsos.vlm.blip2itm import BLIP2ITMClient


@baseline_registry.register_policy
class ITMPolicy(SemanticPolicy):
    llm: BaseLLM = None
    visualize: bool = True
    current_best_object: str = ""
    depth_image_shape: Tuple[int, int] = (244, 224)
    camera_height: float = 0.88
    det_conf_threshold: float = 0.5
    pointnav_stop_radius: float = 0.65

    def __init__(self, *args, **kwargs):
        super().__init__()
        # VL models
        self.itm = BLIP2ITMClient()
        self.frontier_map: FrontierMap = FrontierMap()

    def _reset(self):
        super()._reset()
        self.frontier_map.reset()

    def _explore(self, observations: TensorDict) -> Tensor:
        frontiers = observations["frontier_sensor"][0].cpu().numpy()
        rgb = observations["rgb"][0].cpu().numpy()
        text = f"Seems like there is a {self.target_object} ahead."
        self.frontier_map.update(frontiers, rgb, text)
        goal, cosine = self.frontier_map.get_best_frontier()
        os.environ["DEBUG_INFO"] = f"Best frontier: {cosine}"
        print(f"Step: {self.num_steps} Best frontier: {cosine}")
        pointnav_action = self._pointnav(
            observations, goal[:2], deterministic=True, stop=False
        )

        return pointnav_action