From 01ae4e349a689b5589111dc0f1a9aa70276dbb45 Mon Sep 17 00:00:00 2001 From: Naoki Yokoyama Date: Tue, 1 Aug 2023 20:11:57 -0400 Subject: [PATCH] working PointNav in the real world --- .../non_habitat_policy/nh_pointnav_policy.py | 36 ++++-- zsos/policy/utils/pointnav_policy.py | 79 ++++++++++--- zsos/reality/bdsw_nav_env.py | 49 ++++++++ zsos/reality/pointnav_env.py | 108 ++++++++++++++++++ zsos/reality/robots/base_robot.py | 66 +++++++++++ zsos/reality/robots/bdsw_robot.py | 41 +++++++ zsos/reality/robots/camera_ids.py | 56 +++++++++ 7 files changed, 412 insertions(+), 23 deletions(-) create mode 100644 zsos/reality/bdsw_nav_env.py create mode 100644 zsos/reality/pointnav_env.py create mode 100644 zsos/reality/robots/base_robot.py create mode 100644 zsos/reality/robots/bdsw_robot.py create mode 100644 zsos/reality/robots/camera_ids.py diff --git a/zsos/policy/utils/non_habitat_policy/nh_pointnav_policy.py b/zsos/policy/utils/non_habitat_policy/nh_pointnav_policy.py index a524d74..8353e34 100644 --- a/zsos/policy/utils/non_habitat_policy/nh_pointnav_policy.py +++ b/zsos/policy/utils/non_habitat_policy/nh_pointnav_policy.py @@ -1,8 +1,8 @@ from typing import Dict, Optional, Tuple import torch -import torch.functional as F import torch.nn as nn +import torch.nn.functional as F from torch import Size from .resnet import resnet18 @@ -15,7 +15,7 @@ class ResNetEncoder(nn.Module): def __init__(self): super().__init__() self.running_mean_and_var = nn.Sequential() - self.backbone = resnet18(1, 32, 32) + self.backbone = resnet18(1, 32, 16) self.compression = nn.Sequential( nn.Conv2d( 256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False @@ -121,7 +121,7 @@ def forward(self, x: torch.Tensor) -> CustomNormal: mu = torch.tanh(mu) - std = torch.clamp(std, self.min_std, self.max_std) + std = torch.clamp(std, self.min_log_std, self.max_log_std) std = torch.exp(std) return CustomNormal(mu, std, validate_args=False) @@ -162,8 +162,30 @@ def act( args = parser.parse_args() ckpt = torch.load(args.state_dict_path, map_location="cpu") - model = PointNavResNetPolicy() - print(model) - current_state_dict = model.state_dict() - model.load_state_dict({k: v for k, v in ckpt.items() if k in current_state_dict}) + policy = PointNavResNetPolicy() + print(policy) + current_state_dict = policy.state_dict() + policy.load_state_dict({k: v for k, v in ckpt.items() if k in current_state_dict}) print("Loaded model from checkpoint successfully!") + + policy = policy.to(torch.device("cuda")) + print("Successfully moved model to GPU!") + + observations = { + "depth": torch.ones(1, 212, 240, 1, device=torch.device("cuda")), + "pointgoal_with_gps_compass": torch.zeros(1, 2, device=torch.device("cuda")), + } + mask = torch.zeros(1, 1, device=torch.device("cuda"), dtype=torch.bool) + + rnn_state = torch.zeros(1, 4, 512, device=torch.device("cuda"), dtype=torch.float32) + + action = policy.act( + observations, + rnn_state, + torch.zeros(1, 2, device=torch.device("cuda"), dtype=torch.float32), + mask, + deterministic=True, + ) + + print("Forward pass successful!") + print(action[0].detach().cpu().numpy()) diff --git a/zsos/policy/utils/pointnav_policy.py b/zsos/policy/utils/pointnav_policy.py index b348003..800beda 100644 --- a/zsos/policy/utils/pointnav_policy.py +++ b/zsos/policy/utils/pointnav_policy.py @@ -27,21 +27,35 @@ class WrappedPointNavResNetPolicy: and previous action for the policy. 
""" - def __init__(self, ckpt_path: str): + def __init__( + self, + ckpt_path: str, + device: Union[str, torch.device] = "cuda", + discrete_actions: bool = True, + ): + if isinstance(device, str): + device = torch.device(device) self.policy = load_pointnav_policy(ckpt_path) - self.policy.to(torch.device("cuda")) + self.policy.to(device) self.pointnav_test_recurrent_hidden_states = torch.zeros( 1, # The number of environments. self.policy.net.num_recurrent_layers, 512, # hidden state size - device=torch.device("cuda"), + device=device, ) + if discrete_actions: + num_actions = 1 + action_dtype = torch.long + else: + num_actions = 2 + action_dtype = torch.float32 self.pointnav_prev_actions = torch.zeros( - 1, # The number of environments. - 1, # The number of actions. - device=torch.device("cuda"), - dtype=torch.long, + 1, # number of environments + num_actions, + device=device, + dtype=action_dtype, ) + self.device = device def act( self, @@ -68,6 +82,19 @@ def act( Tensor (torch.dtype.long): A tensor denoting the action to take: (0: STOP, 1: FWD, 2: LEFT, 3: RIGHT). """ + # Convert numpy arrays to torch tensors for each dict value + for k, v in observations.items(): + if isinstance(v, np.ndarray): + observations[k] = torch.from_numpy(v).to( + device=self.device, dtype=torch.float32 + ) + if k == "depth" and len(observations[k].shape) == 3: + observations[k] = observations[k].unsqueeze(0) + elif ( + k == "pointgoal_with_gps_compass" + and len(observations[k].shape) == 1 + ): + observations[k] = observations[k].unsqueeze(0) pointnav_action = self.policy.act( observations, self.pointnav_test_recurrent_hidden_states, @@ -76,10 +103,16 @@ def act( deterministic=deterministic, ) - self.pointnav_test_recurrent_hidden_states = pointnav_action.rnn_hidden_states - self.pointnav_prev_actions = pointnav_action.actions.clone() - - return pointnav_action.actions + if HABITAT_BASELINES_AVAILABLE: + self.pointnav_prev_actions = pointnav_action.actions.clone() + self.pointnav_test_recurrent_hidden_states = ( + pointnav_action.rnn_hidden_states + ) + return pointnav_action.actions + else: + self.pointnav_prev_actions = pointnav_action[0].clone() + self.pointnav_test_recurrent_hidden_states = pointnav_action[1] + return pointnav_action[0] def reset(self) -> None: """ @@ -92,7 +125,9 @@ def reset(self) -> None: def rho_theta_from_gps_compass_goal( - observations: "TensorDict", goal: np.ndarray # noqa: F821 + observations: "TensorDict", # noqa: F821 + goal: np.ndarray, + device: Union[str, torch.device] = "cuda", # noqa: F821 ) -> Tensor: """ Calculates polar coordinates (rho, theta) relative to the agent's current position @@ -109,6 +144,7 @@ def rho_theta_from_gps_compass_goal( the agent must turn to the left (CCW from above) from its initial heading to reach its current heading. goal (np.ndarray): Array of shape (2,) representing the goal position. + device (Union[str, torch.device]): The device to use for the tensor. Returns: Tensor: A tensor of shape (2,) representing the polar coordinates (rho, theta). @@ -120,9 +156,7 @@ def rho_theta_from_gps_compass_goal( heading = observations["compass"].squeeze(1).cpu().numpy()[0] gps_numpy[1] *= -1 # Flip y-axis to match habitat's coordinate system. 
rho, theta = rho_theta(gps_numpy, heading, goal) - rho_theta_tensor = torch.tensor( - [rho, theta], device=torch.device("cuda"), dtype=torch.float32 - ) + rho_theta_tensor = torch.tensor([rho, theta], device=device, dtype=torch.float32) return rho_theta_tensor @@ -233,4 +267,17 @@ def wrap_heading(theta: Union[float, np.ndarray]) -> Union[float, np.ndarray]: args = parser.parse_args() policy = load_pointnav_policy(args.ckpt_path) - print(policy) + print("Loaded model from checkpoint successfully!") + mask = torch.zeros(1, 1, device=torch.device("cuda"), dtype=torch.bool) + observations = { + "depth": torch.zeros(1, 224, 224, 1, device=torch.device("cuda")), + "pointgoal_with_gps_compass": torch.zeros(1, 2, device=torch.device("cuda")), + } + policy.to(torch.device("cuda")) + action = policy.act( + observations, + torch.zeros(1, 4, 512, device=torch.device("cuda"), dtype=torch.float32), + torch.zeros(1, 1, device=torch.device("cuda"), dtype=torch.long), + mask, + ) + print("Forward pass successful!") diff --git a/zsos/reality/bdsw_nav_env.py b/zsos/reality/bdsw_nav_env.py new file mode 100644 index 0000000..c566c3e --- /dev/null +++ b/zsos/reality/bdsw_nav_env.py @@ -0,0 +1,49 @@ +import numpy as np +import torch +from spot_wrapper.spot import Spot + +from zsos.policy.utils.pointnav_policy import WrappedPointNavResNetPolicy +from zsos.reality.pointnav_env import PointNavEnv +from zsos.reality.robots.bdsw_robot import BDSWRobot + + +def run_env(env: PointNavEnv, policy: WrappedPointNavResNetPolicy, goal: np.ndarray): + observations = env.reset(goal) + done = False + mask = torch.zeros(1, 1, device=policy.device, dtype=torch.bool) + action = policy.act(observations, mask) + while not done: + observations, _, done, info = env.step(action) + action = policy.act(observations, mask, deterministic=True) + mask = torch.ones_like(mask) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "pointnav_ckpt_path", + type=str, + default="pointnav_resnet_18.pth", + help="Path to the pointnav model checkpoint", + ) + parser.add_argument( + "-g", + "--goal", + type=str, + default="3.5,0.0", + help="Goal location in the form x,y", + ) + args = parser.parse_args() + pointnav_ckpt_path = args.pointnav_ckpt_path + policy = WrappedPointNavResNetPolicy(pointnav_ckpt_path, discrete_actions=False) + goal = np.array([float(x) for x in args.goal.split(",")]) + + spot = Spot("BDSW_env") # just a name, can be anything + with spot.get_lease(): # turns the robot on, and off upon any errors or completion + spot.power_on() + spot.blocking_stand() + robot = BDSWRobot(spot) + env = PointNavEnv(robot) + run_env(env, policy, goal) diff --git a/zsos/reality/pointnav_env.py b/zsos/reality/pointnav_env.py new file mode 100644 index 0000000..0bb6eed --- /dev/null +++ b/zsos/reality/pointnav_env.py @@ -0,0 +1,108 @@ +import time +from typing import Dict, Tuple, Union + +import cv2 +import numpy as np +import torch +from depth_camera_filtering import filter_depth + +from zsos.mapping.object_map import convert_to_global_frame +from zsos.policy.utils.pointnav_policy import rho_theta +from zsos.reality.robots.base_robot import BaseRobot +from zsos.reality.robots.camera_ids import SpotCamIds + + +class PointNavEnv: + """Gym environment for doing the PointNav task.""" + + max_depth: float = 3.5 + success_radius: float = 0.425 + goal: np.ndarray = np.array([0.0, 0.0]) + max_lin_dist: float = 0.25 + max_ang_dist: float = np.deg2rad(30) + time_step: float = 0.5 + depth_shape: 
Tuple[int, int] = (212, 240) # height, width + info: Dict = {} + + def __init__(self, robot: BaseRobot): + self.robot = robot + + def reset(self, goal: np.ndarray, relative=True) -> Dict[str, np.ndarray]: + if relative: + # Transform (x,y) goal from robot frame to global frame + pos, yaw = self.robot.xy_yaw + pos_w_z = np.array([pos[0], pos[1], 0.0]) # inject dummy z value + goal_w_z = np.array([goal[0], goal[1], 0.0]) # inject dummy z value + goal = convert_to_global_frame(pos_w_z, yaw, goal_w_z)[:2] # drop z + self.goal = goal + return self._get_obs() + + def step( + self, action: Union[np.ndarray, torch.Tensor] + ) -> Tuple[Dict, float, bool, Dict]: + self.info = {} + if isinstance(action, torch.Tensor): + action = action.detach().cpu().numpy() + ang_vel, lin_vel = self._compute_velocities(action) + self.robot.command_base_velocity(ang_vel, lin_vel) + time.sleep(self.time_step) + self.robot.command_base_velocity(0.0, 0.0) + r_t = self._get_rho_theta() + print("rho: ", r_t[0], "theta: ", np.rad2deg(r_t[1])) + return self._get_obs(), 0.0, self.done, self.info + + @property + def done(self) -> bool: + rho = self._get_rho_theta()[0] + return rho < self.success_radius + + def _compute_velocities(self, action: np.ndarray) -> Tuple[float, float]: + ang_dist, lin_dist = np.clip( + action[0], + -1.0, + 1.0, + ) + ang_dist *= self.max_ang_dist + lin_dist *= self.max_lin_dist + ang_vel = ang_dist / self.time_step + lin_vel = lin_dist / self.time_step + print("action: ", action[0]) + print("ang_vel: ", np.rad2deg(ang_vel), "lin_vel: ", lin_vel) + print("ang_dist: ", np.rad2deg(ang_dist), "lin_dist: ", lin_dist) + return ang_vel, lin_vel + + def _get_obs(self) -> Dict[str, np.ndarray]: + return { + "depth": self._get_depth(), + "pointgoal_with_gps_compass": self._get_rho_theta(), + } + + def _get_depth(self) -> np.ndarray: + images = self.robot.get_camera_images( + [SpotCamIds.FRONTRIGHT_DEPTH, SpotCamIds.FRONTLEFT_DEPTH] + ) + # Spot is cross-eyed, so right eye is on the left, and vice versa + img = np.hstack( + [images[SpotCamIds.FRONTRIGHT_DEPTH], images[SpotCamIds.FRONTLEFT_DEPTH]] + ) + img = img.astype(np.float32) / 1000.0 # Convert to meters from mm (uint16) + # Filter the image and re-scale based on max depth limit (self.max_depth) + img = filter_depth( + img, clip_far_thresh=self.max_depth, set_black_value=self.max_depth + ) + img = img / self.max_depth # Normalize to [0, 1] + # Down-sample to policy input shape + img = cv2.resize( + img, + (self.depth_shape[1], self.depth_shape[0]), + interpolation=cv2.INTER_AREA, + ) + # Add a channel dimension + img = img.reshape(img.shape + (1,)) + + return img + + def _get_rho_theta(self) -> np.ndarray: + curr_pos, yaw = self.robot.xy_yaw + r_t = rho_theta(curr_pos, yaw, self.goal) + return np.array(r_t) diff --git a/zsos/reality/robots/base_robot.py b/zsos/reality/robots/base_robot.py new file mode 100644 index 0000000..2aaf36b --- /dev/null +++ b/zsos/reality/robots/base_robot.py @@ -0,0 +1,66 @@ +from typing import Dict, List + +import numpy as np + +from zsos.reality.robots.camera_ids import CAM_ID_TO_SHAPE, SHOULD_ROTATE + + +class BaseRobot: + def get_camera_images(self, camera_source: List[str]) -> Dict[str, np.ndarray]: + raise NotImplementedError + + def command_base_velocity(self, ang_vel: float, lin_vel: float): + raise NotImplementedError + + @property + def xy_yaw(self) -> np.ndarray: + """Returns x, y, yaw""" + raise NotImplementedError + + @property + def arm_joints(self) -> np.ndarray: + """Returns current angle for each of the 7 arm 
joints""" + raise NotImplementedError + + @staticmethod + def _reorient_images(imgs_dict: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: + """ + Rotate images if necessary. + + Args: + imgs_dict: Dictionary of images. + Returns: + Dictionary of images, rotated if necessary. + """ + for camera_id, img in imgs_dict.items(): + if camera_id in SHOULD_ROTATE: + imgs_dict[camera_id] = np.rot90(img, k=3) + return imgs_dict + + +class FakeRobot(BaseRobot): + def get_camera_images(self, camera_source: List[str]) -> Dict[str, np.ndarray]: + """ + Return a list of random images. Ensure they are the right shapes. camera_source + is a list of camera ids that are attributes of SpotCamIds, and its shape is + stored in CAM_ID_TO_SHAPE. + """ + for source in camera_source: + assert source in CAM_ID_TO_SHAPE, f"Invalid camera source: {source}" + images = { + source: np.random.rand(*CAM_ID_TO_SHAPE[source]) for source in camera_source + } + return self._reorient_images(images) + + def command_base_velocity(self, ang_vel: float, lin_vel: float): + pass + + @property + def xy_yaw(self) -> np.ndarray: + """Returns a random x, y, yaw""" + return np.random.rand(3) + + @property + def arm_joints(self) -> np.ndarray: + """Returns a random angle for each of the 7 arm joints""" + return np.random.rand(7) diff --git a/zsos/reality/robots/bdsw_robot.py b/zsos/reality/robots/bdsw_robot.py new file mode 100644 index 0000000..eaa2388 --- /dev/null +++ b/zsos/reality/robots/bdsw_robot.py @@ -0,0 +1,41 @@ +from typing import Dict, List, Tuple + +import numpy as np +from spot_wrapper.spot import Spot, image_response_to_cv2 + +from zsos.reality.robots.base_robot import BaseRobot + +MAX_CMD_DURATION = 5 + + +class BDSWRobot(BaseRobot): + def __init__(self, spot: Spot): + self.spot = spot + + def get_camera_images(self, camera_source: List[str]) -> Dict[str, np.ndarray]: + # Get Spot camera image + image_responses = self.spot.get_image_responses(camera_source) + imgs = { + source: image_response_to_cv2(image_response, reorient=True) + for source, image_response in zip(camera_source, image_responses) + } + return imgs + + def command_base_velocity(self, ang_vel: float, lin_vel: float): + self.spot.set_base_velocity( + lin_vel, + 0.0, # no horizontal velocity + ang_vel, + MAX_CMD_DURATION, + ) + + @property + def xy_yaw(self) -> Tuple[np.ndarray, float]: + robot_state = self.spot.get_robot_state() + x, y, yaw = self.spot.get_xy_yaw(robot_state=robot_state) + return np.array([x, y]), yaw + + @property + def arm_joints(self) -> np.ndarray: + """Returns a random angle for each of the 7 arm joints""" + raise NotImplementedError diff --git a/zsos/reality/robots/camera_ids.py b/zsos/reality/robots/camera_ids.py new file mode 100644 index 0000000..066960b --- /dev/null +++ b/zsos/reality/robots/camera_ids.py @@ -0,0 +1,56 @@ +class SpotCamIds: + r"""Enumeration of types of cameras.""" + + BACK_DEPTH = "back_depth" + BACK_DEPTH_IN_VISUAL_FRAME = "back_depth_in_visual_frame" + BACK_FISHEYE = "back_fisheye_image" + FRONTLEFT_DEPTH = "frontleft_depth" + FRONTLEFT_DEPTH_IN_VISUAL_FRAME = "frontleft_depth_in_visual_frame" + FRONTLEFT_FISHEYE = "frontleft_fisheye_image" + FRONTRIGHT_DEPTH = "frontright_depth" + FRONTRIGHT_DEPTH_IN_VISUAL_FRAME = "frontright_depth_in_visual_frame" + FRONTRIGHT_FISHEYE = "frontright_fisheye_image" + HAND_COLOR = "hand_color_image" + HAND_COLOR_IN_HAND_DEPTH_FRAME = "hand_color_in_hand_depth_frame" + HAND_DEPTH = "hand_depth" + HAND_DEPTH_IN_HAND_COLOR_FRAME = "hand_depth_in_hand_color_frame" + HAND = 
"hand_image" + LEFT_DEPTH = "left_depth" + LEFT_DEPTH_IN_VISUAL_FRAME = "left_depth_in_visual_frame" + LEFT_FISHEYE = "left_fisheye_image" + RIGHT_DEPTH = "right_depth" + RIGHT_DEPTH_IN_VISUAL_FRAME = "right_depth_in_visual_frame" + RIGHT_FISHEYE = "right_fisheye_image" + + +# CamIds that need to be rotated by 270 degrees in order to appear upright +SHOULD_ROTATE = { + SpotCamIds.FRONTLEFT_DEPTH, + SpotCamIds.FRONTRIGHT_DEPTH, + SpotCamIds.HAND_DEPTH, + SpotCamIds.HAND, +} + +# Maps camera ids to the shapes of their images +CAM_ID_TO_SHAPE = { + SpotCamIds.BACK_DEPTH: (424, 240, 1), + SpotCamIds.BACK_DEPTH_IN_VISUAL_FRAME: (640, 480, 1), + SpotCamIds.BACK_FISHEYE: (640, 480, 3), + SpotCamIds.FRONTLEFT_DEPTH: (424, 240, 1), + SpotCamIds.FRONTLEFT_DEPTH_IN_VISUAL_FRAME: (640, 480, 1), + SpotCamIds.FRONTLEFT_FISHEYE: (640, 480, 3), + SpotCamIds.FRONTRIGHT_DEPTH: (424, 240, 1), + SpotCamIds.FRONTRIGHT_DEPTH_IN_VISUAL_FRAME: (640, 480, 1), + SpotCamIds.FRONTRIGHT_FISHEYE: (640, 480, 3), + SpotCamIds.HAND_COLOR: (640, 480, 3), + SpotCamIds.HAND_COLOR_IN_HAND_DEPTH_FRAME: (640, 480, 1), + SpotCamIds.HAND_DEPTH: (224, 171, 1), + SpotCamIds.HAND_DEPTH_IN_HAND_COLOR_FRAME: (224, 171, 1), + SpotCamIds.HAND: (224, 171, 3), + SpotCamIds.LEFT_DEPTH: (424, 240, 1), + SpotCamIds.LEFT_DEPTH_IN_VISUAL_FRAME: (640, 480, 1), + SpotCamIds.LEFT_FISHEYE: (640, 480, 3), + SpotCamIds.RIGHT_DEPTH: (424, 240, 1), + SpotCamIds.RIGHT_DEPTH_IN_VISUAL_FRAME: (640, 480, 1), + SpotCamIds.RIGHT_FISHEYE: (640, 480, 3), +}