working PointNav in the real world

naokiyokoyamabd committed Aug 2, 2023
1 parent d928793 commit 01ae4e3
Showing 7 changed files with 412 additions and 23 deletions.
36 changes: 29 additions & 7 deletions zsos/policy/utils/non_habitat_policy/nh_pointnav_policy.py
@@ -1,8 +1,8 @@
 from typing import Dict, Optional, Tuple

 import torch
-import torch.functional as F
 import torch.nn as nn
+import torch.nn.functional as F
 from torch import Size

 from .resnet import resnet18
@@ -15,7 +15,7 @@ class ResNetEncoder(nn.Module):
     def __init__(self):
         super().__init__()
         self.running_mean_and_var = nn.Sequential()
-        self.backbone = resnet18(1, 32, 32)
+        self.backbone = resnet18(1, 32, 16)
         self.compression = nn.Sequential(
             nn.Conv2d(
                 256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
@@ -121,7 +121,7 @@ def forward(self, x: torch.Tensor) -> CustomNormal:

         mu = torch.tanh(mu)

-        std = torch.clamp(std, self.min_std, self.max_std)
+        std = torch.clamp(std, self.min_log_std, self.max_log_std)
         std = torch.exp(std)

         return CustomNormal(mu, std, validate_args=False)
@@ -162,8 +162,30 @@ def act(
     args = parser.parse_args()

     ckpt = torch.load(args.state_dict_path, map_location="cpu")
-    model = PointNavResNetPolicy()
-    print(model)
-    current_state_dict = model.state_dict()
-    model.load_state_dict({k: v for k, v in ckpt.items() if k in current_state_dict})
+    policy = PointNavResNetPolicy()
+    print(policy)
+    current_state_dict = policy.state_dict()
+    policy.load_state_dict({k: v for k, v in ckpt.items() if k in current_state_dict})
     print("Loaded model from checkpoint successfully!")
+
+    policy = policy.to(torch.device("cuda"))
+    print("Successfully moved model to GPU!")
+
+    observations = {
+        "depth": torch.ones(1, 212, 240, 1, device=torch.device("cuda")),
+        "pointgoal_with_gps_compass": torch.zeros(1, 2, device=torch.device("cuda")),
+    }
+    mask = torch.zeros(1, 1, device=torch.device("cuda"), dtype=torch.bool)
+
+    rnn_state = torch.zeros(1, 4, 512, device=torch.device("cuda"), dtype=torch.float32)
+
+    action = policy.act(
+        observations,
+        rnn_state,
+        torch.zeros(1, 2, device=torch.device("cuda"), dtype=torch.float32),
+        mask,
+        deterministic=True,
+    )
+
+    print("Forward pass successful!")
+    print(action[0].detach().cpu().numpy())
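A note on the std clamp fix above: the renamed bounds make explicit that the network's raw output is treated as a log standard deviation, clamped in log space, and only then exponentiated. A minimal standalone sketch of that pattern, assuming illustrative bounds of -5 and 2 (the real min_log_std/max_log_std values are defined elsewhere in the class and are not shown in this diff):

import torch

# Illustrative bounds only; the actual min_log_std/max_log_std live on the policy class.
MIN_LOG_STD, MAX_LOG_STD = -5.0, 2.0

def gaussian_params(mu_raw: torch.Tensor, log_std_raw: torch.Tensor):
    # Squash the mean into [-1, 1]; clamp the log-std, then exponentiate so std > 0.
    mu = torch.tanh(mu_raw)
    log_std = torch.clamp(log_std_raw, MIN_LOG_STD, MAX_LOG_STD)
    return mu, torch.exp(log_std)

mu, std = gaussian_params(torch.randn(1, 2), torch.randn(1, 2))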
79 changes: 63 additions & 16 deletions zsos/policy/utils/pointnav_policy.py
@@ -27,21 +27,35 @@ class WrappedPointNavResNetPolicy:
     and previous action for the policy.
     """

-    def __init__(self, ckpt_path: str):
+    def __init__(
+        self,
+        ckpt_path: str,
+        device: Union[str, torch.device] = "cuda",
+        discrete_actions: bool = True,
+    ):
+        if isinstance(device, str):
+            device = torch.device(device)
         self.policy = load_pointnav_policy(ckpt_path)
-        self.policy.to(torch.device("cuda"))
+        self.policy.to(device)
         self.pointnav_test_recurrent_hidden_states = torch.zeros(
             1,  # The number of environments.
             self.policy.net.num_recurrent_layers,
             512,  # hidden state size
-            device=torch.device("cuda"),
+            device=device,
         )
+        if discrete_actions:
+            num_actions = 1
+            action_dtype = torch.long
+        else:
+            num_actions = 2
+            action_dtype = torch.float32
         self.pointnav_prev_actions = torch.zeros(
-            1,  # The number of environments.
-            1,  # The number of actions.
-            device=torch.device("cuda"),
-            dtype=torch.long,
+            1,  # number of environments
+            num_actions,
+            device=device,
+            dtype=action_dtype,
         )
+        self.device = device

     def act(
         self,
@@ -68,6 +82,19 @@ def act(
             Tensor (torch.dtype.long): A tensor denoting the action to take:
             (0: STOP, 1: FWD, 2: LEFT, 3: RIGHT).
         """
+        # Convert numpy arrays to torch tensors for each dict value
+        for k, v in observations.items():
+            if isinstance(v, np.ndarray):
+                observations[k] = torch.from_numpy(v).to(
+                    device=self.device, dtype=torch.float32
+                )
+            if k == "depth" and len(observations[k].shape) == 3:
+                observations[k] = observations[k].unsqueeze(0)
+            elif (
+                k == "pointgoal_with_gps_compass"
+                and len(observations[k].shape) == 1
+            ):
+                observations[k] = observations[k].unsqueeze(0)
         pointnav_action = self.policy.act(
             observations,
             self.pointnav_test_recurrent_hidden_states,
@@ -76,10 +103,16 @@
             deterministic=deterministic,
         )

-        self.pointnav_test_recurrent_hidden_states = pointnav_action.rnn_hidden_states
-        self.pointnav_prev_actions = pointnav_action.actions.clone()
-
-        return pointnav_action.actions
+        if HABITAT_BASELINES_AVAILABLE:
+            self.pointnav_prev_actions = pointnav_action.actions.clone()
+            self.pointnav_test_recurrent_hidden_states = (
+                pointnav_action.rnn_hidden_states
+            )
+            return pointnav_action.actions
+        else:
+            self.pointnav_prev_actions = pointnav_action[0].clone()
+            self.pointnav_test_recurrent_hidden_states = pointnav_action[1]
+            return pointnav_action[0]

     def reset(self) -> None:
         """
@@ -92,7 +125,9 @@ def reset(self) -> None:


 def rho_theta_from_gps_compass_goal(
-    observations: "TensorDict", goal: np.ndarray  # noqa: F821
+    observations: "TensorDict",  # noqa: F821
+    goal: np.ndarray,
+    device: Union[str, torch.device] = "cuda",  # noqa: F821
 ) -> Tensor:
     """
     Calculates polar coordinates (rho, theta) relative to the agent's current position
@@ -109,6 +144,7 @@
         the agent must turn to the left (CCW from above) from its initial heading to
         reach its current heading.
         goal (np.ndarray): Array of shape (2,) representing the goal position.
+        device (Union[str, torch.device]): The device to use for the tensor.

     Returns:
         Tensor: A tensor of shape (2,) representing the polar coordinates (rho, theta).
@@ -120,9 +156,7 @@
     heading = observations["compass"].squeeze(1).cpu().numpy()[0]
     gps_numpy[1] *= -1  # Flip y-axis to match habitat's coordinate system.
     rho, theta = rho_theta(gps_numpy, heading, goal)
-    rho_theta_tensor = torch.tensor(
-        [rho, theta], device=torch.device("cuda"), dtype=torch.float32
-    )
+    rho_theta_tensor = torch.tensor([rho, theta], device=device, dtype=torch.float32)

     return rho_theta_tensor

@@ -233,4 +267,17 @@ def wrap_heading(theta: Union[float, np.ndarray]) -> Union[float, np.ndarray]:
     args = parser.parse_args()

     policy = load_pointnav_policy(args.ckpt_path)
-    print(policy)
+    print("Loaded model from checkpoint successfully!")
+    mask = torch.zeros(1, 1, device=torch.device("cuda"), dtype=torch.bool)
+    observations = {
+        "depth": torch.zeros(1, 224, 224, 1, device=torch.device("cuda")),
+        "pointgoal_with_gps_compass": torch.zeros(1, 2, device=torch.device("cuda")),
+    }
+    policy.to(torch.device("cuda"))
+    action = policy.act(
+        observations,
+        torch.zeros(1, 4, 512, device=torch.device("cuda"), dtype=torch.float32),
+        torch.zeros(1, 1, device=torch.device("cuda"), dtype=torch.long),
+        mask,
+    )
+    print("Forward pass successful!")
49 changes: 49 additions & 0 deletions zsos/reality/bdsw_nav_env.py
@@ -0,0 +1,49 @@
import numpy as np
import torch
from spot_wrapper.spot import Spot

from zsos.policy.utils.pointnav_policy import WrappedPointNavResNetPolicy
from zsos.reality.pointnav_env import PointNavEnv
from zsos.reality.robots.bdsw_robot import BDSWRobot


def run_env(env: PointNavEnv, policy: WrappedPointNavResNetPolicy, goal: np.ndarray):
    observations = env.reset(goal)
    done = False
    mask = torch.zeros(1, 1, device=policy.device, dtype=torch.bool)
    action = policy.act(observations, mask)
    while not done:
        observations, _, done, info = env.step(action)
        action = policy.act(observations, mask, deterministic=True)
        mask = torch.ones_like(mask)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "pointnav_ckpt_path",
        type=str,
        default="pointnav_resnet_18.pth",
        help="Path to the pointnav model checkpoint",
    )
    parser.add_argument(
        "-g",
        "--goal",
        type=str,
        default="3.5,0.0",
        help="Goal location in the form x,y",
    )
    args = parser.parse_args()
    pointnav_ckpt_path = args.pointnav_ckpt_path
    policy = WrappedPointNavResNetPolicy(pointnav_ckpt_path, discrete_actions=False)
    goal = np.array([float(x) for x in args.goal.split(",")])

    spot = Spot("BDSW_env")  # just a name, can be anything
    with spot.get_lease():  # turns the robot on, and off upon any errors or completion
        spot.power_on()
        spot.blocking_stand()
        robot = BDSWRobot(spot)
        env = PointNavEnv(robot)
        run_env(env, policy, goal)
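One subtlety in run_env: policy.act is called once before the loop with an all-zeros mask, which signals episode start so the policy resets its recurrent state, and the mask is set to ones afterward so the hidden state carries over between steps. A stand-in illustration of that convention (DummyPolicy is hypothetical and mirrors only the wrapper's act signature):

import torch

class DummyPolicy:
    # Stand-in with the same act(observations, mask) surface as the wrapper.
    def act(self, observations, mask, deterministic=False):
        # A zero mask means "episode start"; a real policy resets its GRU state here.
        return torch.zeros(1, 2)  # placeholder (angular, linear) action

policy = DummyPolicy()
mask = torch.zeros(1, 1, dtype=torch.bool)  # first step: reset recurrent state
action = policy.act({}, mask)
for _ in range(3):  # stand-in for the `while not done` loop
    action = policy.act({}, mask, deterministic=True)
    mask = torch.ones_like(mask)  # later steps: keep the hidden state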
108 changes: 108 additions & 0 deletions zsos/reality/pointnav_env.py
@@ -0,0 +1,108 @@
import time
from typing import Dict, Tuple, Union

import cv2
import numpy as np
import torch
from depth_camera_filtering import filter_depth

from zsos.mapping.object_map import convert_to_global_frame
from zsos.policy.utils.pointnav_policy import rho_theta
from zsos.reality.robots.base_robot import BaseRobot
from zsos.reality.robots.camera_ids import SpotCamIds


class PointNavEnv:
    """Gym environment for doing the PointNav task."""

    max_depth: float = 3.5
    success_radius: float = 0.425
    goal: np.ndarray = np.array([0.0, 0.0])
    max_lin_dist: float = 0.25
    max_ang_dist: float = np.deg2rad(30)
    time_step: float = 0.5
    depth_shape: Tuple[int, int] = (212, 240)  # height, width
    info: Dict = {}

    def __init__(self, robot: BaseRobot):
        self.robot = robot

    def reset(self, goal: np.ndarray, relative=True) -> Dict[str, np.ndarray]:
        if relative:
            # Transform (x,y) goal from robot frame to global frame
            pos, yaw = self.robot.xy_yaw
            pos_w_z = np.array([pos[0], pos[1], 0.0])  # inject dummy z value
            goal_w_z = np.array([goal[0], goal[1], 0.0])  # inject dummy z value
            goal = convert_to_global_frame(pos_w_z, yaw, goal_w_z)[:2]  # drop z
        self.goal = goal
        return self._get_obs()

    def step(
        self, action: Union[np.ndarray, torch.Tensor]
    ) -> Tuple[Dict, float, bool, Dict]:
        self.info = {}
        if isinstance(action, torch.Tensor):
            action = action.detach().cpu().numpy()
        ang_vel, lin_vel = self._compute_velocities(action)
        self.robot.command_base_velocity(ang_vel, lin_vel)
        time.sleep(self.time_step)
        self.robot.command_base_velocity(0.0, 0.0)
        r_t = self._get_rho_theta()
        print("rho: ", r_t[0], "theta: ", np.rad2deg(r_t[1]))
        return self._get_obs(), 0.0, self.done, self.info

    @property
    def done(self) -> bool:
        rho = self._get_rho_theta()[0]
        return rho < self.success_radius

    def _compute_velocities(self, action: np.ndarray) -> Tuple[float, float]:
        ang_dist, lin_dist = np.clip(
            action[0],
            -1.0,
            1.0,
        )
        ang_dist *= self.max_ang_dist
        lin_dist *= self.max_lin_dist
        ang_vel = ang_dist / self.time_step
        lin_vel = lin_dist / self.time_step
        print("action: ", action[0])
        print("ang_vel: ", np.rad2deg(ang_vel), "lin_vel: ", lin_vel)
        print("ang_dist: ", np.rad2deg(ang_dist), "lin_dist: ", lin_dist)
        return ang_vel, lin_vel

    def _get_obs(self) -> Dict[str, np.ndarray]:
        return {
            "depth": self._get_depth(),
            "pointgoal_with_gps_compass": self._get_rho_theta(),
        }

    def _get_depth(self) -> np.ndarray:
        images = self.robot.get_camera_images(
            [SpotCamIds.FRONTRIGHT_DEPTH, SpotCamIds.FRONTLEFT_DEPTH]
        )
        # Spot is cross-eyed, so right eye is on the left, and vice versa
        img = np.hstack(
            [images[SpotCamIds.FRONTRIGHT_DEPTH], images[SpotCamIds.FRONTLEFT_DEPTH]]
        )
        img = img.astype(np.float32) / 1000.0  # Convert to meters from mm (uint16)
        # Filter the image and re-scale based on max depth limit (self.max_depth)
        img = filter_depth(
            img, clip_far_thresh=self.max_depth, set_black_value=self.max_depth
        )
        img = img / self.max_depth  # Normalize to [0, 1]
        # Down-sample to policy input shape
        img = cv2.resize(
            img,
            (self.depth_shape[1], self.depth_shape[0]),
            interpolation=cv2.INTER_AREA,
        )
        # Add a channel dimension
        img = img.reshape(img.shape + (1,))

        return img

    def _get_rho_theta(self) -> np.ndarray:
        curr_pos, yaw = self.robot.xy_yaw
        r_t = rho_theta(curr_pos, yaw, self.goal)
        return np.array(r_t)
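To make the action scaling in _compute_velocities concrete: the policy outputs normalized (angular, linear) distances in [-1, 1], which are scaled by the per-step caps and divided by the 0.5 s step time to produce velocity commands. A worked example with the class defaults above:

import numpy as np

max_ang_dist = np.deg2rad(30)  # per-step caps, matching the class attributes above
max_lin_dist = 0.25            # meters
time_step = 0.5                # seconds

action = np.array([[0.5, 1.0]])  # batch of one: (angular, linear) in [-1, 1]
ang_dist, lin_dist = np.clip(action[0], -1.0, 1.0)
ang_dist *= max_ang_dist  # 0.5 * 30 deg = 15 deg this step
lin_dist *= max_lin_dist  # 1.0 * 0.25 m = 0.25 m this step
ang_vel = ang_dist / time_step  # -> 30 deg/s
lin_vel = lin_dist / time_step  # -> 0.5 m/s
print(np.rad2deg(ang_vel), lin_vel)  # 30.0 0.5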