
Commit 3201580

added separate class and eval script for oracle FBE policy, miscellaneous cleaning of various files
naokiyokoyamabd committed Aug 2, 2023
1 parent 3f72c6c commit 3201580
Showing 6 changed files with 91 additions and 48 deletions.
13 changes: 0 additions & 13 deletions scripts/eval_fe_policy.sh

This file was deleted.

13 changes: 13 additions & 0 deletions scripts/eval_oracle_fbe_policy.sh
@@ -0,0 +1,13 @@
#!/usr/bin/env bash
# Copyright [2023] Boston Dynamics AI Institute, Inc.

python -um zsos.run \
habitat_baselines.evaluate=True \
habitat_baselines.eval_ckpt_path_dir=dummy_policy.pth \
habitat_baselines.load_resume_state_config=False \
habitat_baselines.rl.policy.name=OracleFBEPolicy \
habitat.task.lab_sensors.base_explorer.turn_angle=30 \
habitat_baselines.num_environments=1 \
habitat_baselines.eval.split=val_50 \
habitat.simulator.habitat_sim_v0.allow_sliding=True \
habitat_baselines.eval.video_option='["disk"]'
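
Note (not part of the commit): the script points eval_ckpt_path_dir at a placeholder file because OracleFBEPolicy has no learned weights to restore, and load_resume_state_config=False means the config is not read from it either. A minimal sketch of creating such a placeholder, assuming the evaluator only needs a torch-loadable file (the exact checkpoint keys it expects are not shown in this commit):

    # Hypothetical placeholder checkpoint; keys here are assumptions.
    import torch

    torch.save({"state_dict": {}, "extra_state": {"step": 0}}, "dummy_policy.pth")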
12 changes: 1 addition & 11 deletions zsos/policy/itm_policy.py
@@ -1,26 +1,16 @@
import os
from typing import Tuple

from habitat_baselines.common.baseline_registry import baseline_registry
from habitat_baselines.common.tensor_dict import TensorDict
from torch import Tensor

from zsos.llm.llm import BaseLLM
from zsos.mapping.frontier_map import FrontierMap
from zsos.policy.semantic_policy import SemanticPolicy
from zsos.vlm.blip2itm import BLIP2ITMClient


@baseline_registry.register_policy
class ITMPolicy(SemanticPolicy):
llm: BaseLLM = None
visualize: bool = True
current_best_object: str = ""
depth_image_shape: Tuple[int, int] = (244, 224)
camera_height: float = 0.88
det_conf_threshold: float = 0.5
pointnav_stop_radius: float = 0.65

def __init__(self, *args, **kwargs):
super().__init__()
# VL models
@@ -37,7 +27,7 @@ def _explore(self, observations: TensorDict) -> Tensor:
text = f"Seems like there is a {self.target_object} ahead."
self.frontier_map.update(frontiers, rgb, text)
goal, cosine = self.frontier_map.get_best_frontier()
os.environ["DEBUG_INFO"] = f"Best frontier: {cosine}"
os.environ["DEBUG_INFO"] = f"Best frontier: {cosine:.3f}"
print(f"Step: {self.num_steps} Best frontier: {cosine}")
pointnav_action = self._pointnav(
observations, goal[:2], deterministic=True, stop=False
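
The only functional change in this hunk is the :.3f format spec, which rounds the cosine similarity in the exported DEBUG_INFO string; the printed line keeps full precision. A quick illustration (values are made up):

    import os

    cosine = 0.43219876
    os.environ["DEBUG_INFO"] = f"Best frontier: {cosine:.3f}"  # "Best frontier: 0.432"
    print(f"Step: 12 Best frontier: {cosine}")  # printed value keeps full precision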
13 changes: 13 additions & 0 deletions zsos/policy/oracle_fbe_policy.py
@@ -0,0 +1,13 @@
from habitat_baselines.common.baseline_registry import baseline_registry
from habitat_baselines.common.tensor_dict import TensorDict
from torch import Tensor

from frontier_exploration.base_explorer import BaseExplorer
from zsos.policy.semantic_policy import SemanticPolicy


@baseline_registry.register_policy
class OracleFBEPolicy(SemanticPolicy):
def _explore(self, observations: TensorDict) -> Tensor:
pointnav_action = observations[BaseExplorer.cls_uuid]
return pointnav_action
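
For context: the frontier_exploration sensor already computes the next discrete action toward the nearest unexplored frontier, so this oracle policy simply forwards it. A minimal sketch of that pass-through (the uuid string and action id below are assumptions; the lab_sensors.base_explorer key in the eval script suggests the uuid, but the real value comes from BaseExplorer.cls_uuid):

    import torch

    BASE_EXPLORER_UUID = "base_explorer"  # hypothetical stand-in for BaseExplorer.cls_uuid
    observations = {BASE_EXPLORER_UUID: torch.tensor([[1]])}  # e.g., MOVE_FORWARD

    pointnav_action = observations[BASE_EXPLORER_UUID]  # _explore() returns this as-is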
81 changes: 58 additions & 23 deletions zsos/policy/semantic_policy.py
@@ -3,21 +3,35 @@

import numpy as np
import torch
from habitat.tasks.nav.object_nav_task import ObjectGoalSensor
from habitat_baselines.common.tensor_dict import TensorDict
from habitat_baselines.rl.ppo.policy import PolicyActionData
from torch import Tensor

from zsos.mapping.object_map import ObjectMap
from zsos.obs_transformers.utils import image_resize
from zsos.policy.base_policy import BasePolicy
from zsos.policy.utils.pointnav_policy import (
WrappedPointNavResNetPolicy,
rho_theta_from_gps_compass_goal,
)
from zsos.vlm.grounding_dino import GroundingDINOClient, ObjectDetections

try:
from habitat_baselines.rl.ppo.policy import PolicyActionData

from zsos.policy.base_policy import BasePolicy

HABITAT_BASELINES = True
except ModuleNotFoundError:

class BasePolicy:
pass

HABITAT_BASELINES = False


ID_TO_NAME = ["chair", "bed", "potted plant", "toilet", "tv", "couch"]
ID_TO_PADDING = {
"bed": 0.2,
"couch": 0.15,
}


class TorchActionIDs:
@@ -31,8 +45,9 @@ class SemanticPolicy(BasePolicy):
target_object: str = ""
camera_height: float = 0.88
depth_image_shape: Tuple[int, int] = (244, 224)
det_conf_threshold: float = 0.35
pointnav_stop_radius: float = 0.9
det_conf_threshold: float = 0.50
pointnav_stop_radius: float = 0.85
visualize: bool = True

def __init__(self, *args, **kwargs):
super().__init__()
@@ -58,7 +73,7 @@ def _reset(self):

def act(
self, observations, rnn_hidden_states, prev_actions, masks, deterministic=False
) -> PolicyActionData:
) -> Union["PolicyActionData", Tensor]:
"""
Starts the episode by 'initializing' and allowing robot to get its bearings
(e.g., spinning in place to get a good view of the scene).
@@ -69,7 +84,13 @@
assert masks.shape[1] == 1, "Currently only supporting one env at a time"
if masks[0] == 0:
self._reset()
self.target_object = ID_TO_NAME[observations[ObjectGoalSensor.cls_uuid][0]]
object_goal = observations["objectgoal"][0].item()
if isinstance(object_goal, str):
self.target_object = object_goal
elif isinstance(object_goal, int):
self.target_object = ID_TO_NAME[object_goal]
else:
raise ValueError("Invalid object goal")

detections = self._update_object_map(observations)
goal = self._get_target_object_location()
@@ -82,21 +103,24 @@ def act(
)
else:
pointnav_action = self._explore(observations)
action_data = PolicyActionData(
actions=pointnav_action,
rnn_hidden_states=rnn_hidden_states,
policy_info=self._get_policy_info(observations, detections),
)

self.num_steps += 1

return action_data
if HABITAT_BASELINES:
action_data = PolicyActionData(
actions=pointnav_action,
rnn_hidden_states=rnn_hidden_states,
policy_info=self._get_policy_info(observations, detections),
)

return action_data
else:
return pointnav_action  # raw action tensor; no PolicyActionData without habitat_baselines

def _initialize(self) -> Tensor:
self.done_initializing = not self.num_steps < 11
return TorchActionIDs.TURN_LEFT

def _explore(self, observations: TensorDict) -> Tensor:
def _explore(self, observations: "TensorDict") -> Tensor: # noqa: F821
raise NotImplementedError

def _get_target_object_location(self) -> Union[None, np.ndarray]:
@@ -108,7 +132,7 @@ def _get_target_object_location(self) -> Union[None, np.ndarray]:

def _get_policy_info(
self,
observations: TensorDict,
observations: "TensorDict", # noqa: F821
detections: ObjectDetections,
) -> List[Dict]:
policy_info = []
@@ -144,11 +168,19 @@ def _get_object_detections(self, img: np.ndarray) -> ObjectDetections:

def _pointnav(
self,
observations: TensorDict,
observations: "TensorDict", # noqa: F821
goal: np.ndarray,
deterministic=False,
stop=False,
) -> Tensor:
"""
Calculates rho and theta from the robot's current position to the goal using the
gps and heading sensors within the observations and the given goal, then uses
them to determine the next action to take with the pre-trained pointnav policy.
Args:
observations ("TensorDict"): The observations from the current timestep.
"""
masks = torch.tensor([self.num_steps != 0], dtype=torch.bool, device="cuda")
if not np.array_equal(goal, self.last_goal):
self.last_goal = goal
@@ -164,21 +196,24 @@
),
"pointgoal_with_gps_compass": rho_theta.unsqueeze(0),
}
if rho_theta[0] < self.pointnav_stop_radius and stop:
stop_dist = self.pointnav_stop_radius + ID_TO_PADDING.get(
self.target_object, 0.0
)
if rho_theta[0] < stop_dist and stop:
return TorchActionIDs.STOP
action = self.pointnav_policy.act(
obs_pointnav, masks, deterministic=deterministic
)
return action

def _update_object_map(self, observations: TensorDict) -> ObjectDetections:
def _update_object_map(
self, observations: "TensorDict" # noqa: F821
) -> ObjectDetections:
"""
Updates the object map with the detections from the current timestep.
Args:
observations (TensorDict): The observations from the current timestep.
detections (ObjectDetections): The detections from the current
timestep.
observations ("TensorDict"): The observations from the current timestep.
"""
rgb = observations["rgb"][0].cpu().numpy()
depth = observations["depth"][0].cpu().numpy()
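
A self-contained illustration (not from the commit) of the new stop-distance logic above: ID_TO_PADDING enlarges the stop radius for bulky goal categories, so the agent issues STOP farther from a bed or couch than from, say, a TV.

    ID_TO_PADDING = {"bed": 0.2, "couch": 0.15}
    POINTNAV_STOP_RADIUS = 0.85  # default from SemanticPolicy.pointnav_stop_radius

    def stop_distance(target_object: str) -> float:
        # Bulky categories get extra padding; all others use the default radius.
        return POINTNAV_STOP_RADIUS + ID_TO_PADDING.get(target_object, 0.0)

    assert abs(stop_distance("bed") - 1.05) < 1e-9
    assert abs(stop_distance("tv") - 0.85) < 1e-9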
7 changes: 6 additions & 1 deletion zsos/run.py
@@ -12,7 +12,12 @@
import frontier_exploration # noqa
from habitat import get_config # noqa
import zsos.obs_transformers.resize # noqa: F401
from zsos.policy import base_policy, itm_policy, llm_policy # noqa: F401
from zsos.policy import ( # noqa: F401
base_policy,
itm_policy,
llm_policy,
oracle_fbe_policy,
)


@hydra.main(
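
The new oracle_fbe_policy import exists purely for its side effect: the @baseline_registry.register_policy decorator runs at import time, which is what makes the name OracleFBEPolicy resolvable from the eval script. A sketch of the lookup, assuming habitat_baselines' standard registry API:

    from habitat_baselines.common.baseline_registry import baseline_registry

    from zsos.policy import oracle_fbe_policy  # noqa: F401  (import triggers registration)

    policy_cls = baseline_registry.get_policy("OracleFBEPolicy")
    assert policy_cls is not None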
