Skip to content

Commit

Permalink
fix everything except tests
Browse files Browse the repository at this point in the history
  • Loading branch information
nkumar-bdai committed Mar 7, 2024
1 parent 04f4082 commit 07c0472
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 27 deletions.
25 changes: 13 additions & 12 deletions predicators/envs/spot_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
# so we're going to instantiate a global variable here to indicate
# whether or not we've connected to pybullet before.

_SIMULATED_SPOT_ROBOT = None
_SIMULATED_SPOT_ROBOT: Optional[pbrspot.spot.Spot] = None
# Used to keep track of all simulated objects; also needs to be global
# since we make and reset the env multiple times!
_obj_name_to_sim_obj: Dict[str, pbrspot.body.Body] = {}
Expand Down Expand Up @@ -221,9 +221,11 @@ def __init__(self, use_gui: bool = True) -> None:
"Must use spot wrapper in spot envs!"
# If we're doing proper bilevel planning, then we need to instantiate
# a simulator!
global _SIMULATED_SPOT_ROBOT
if not CFG.bilevel_plan_without_sim:
if _SIMULATED_SPOT_ROBOT is None:
global _SIMULATED_SPOT_ROBOT # pylint:disable=global-statement
if _SIMULATED_SPOT_ROBOT is not None:
self.sim_robot = _SIMULATED_SPOT_ROBOT
else:
# First, launch pybullet.
pbrspot.utils.connect(use_gui=True)
pbrspot.utils.disable_real_time()
Expand All @@ -241,9 +243,6 @@ def __init__(self, use_gui: bool = True) -> None:
pbrspot.placements.stable_z(self.sim_robot, floor_obj)
])
_SIMULATED_SPOT_ROBOT = self.sim_robot
else:
self.sim_robot = _SIMULATED_SPOT_ROBOT

robot, localizer, lease_client = get_robot()
self._robot = robot
self._localizer = localizer
Expand Down Expand Up @@ -434,17 +433,19 @@ def reset(self, train_or_test: str, task_idx: int) -> Observation:

# Start by modifying the simulated robot to be in the right
# position and configuration.
# TODO: probably also want to reset the pybullet sim to only
# have the robot and floor as well here.
if not CFG.bilevel_plan_without_sim:
global _obj_name_to_sim_obj
global _obj_name_to_sim_obj # pylint:disable=global-statement
# Start by removing all previously-known objects from
# the simulation.
if len(_obj_name_to_sim_obj) > 0:
for sim_obj in _obj_name_to_sim_obj.values():
sim_obj.remove_body()
_obj_name_to_sim_obj = {}
# If we're connected to a real-world robot, then update the
# simulated robot to be in exactly the sasme joint
# configuration as the real robot.
if self._robot is not None:
update_pbrspot_robot_conf(self._robot, self.sim_robot)

# Find the relevant object urdfs and then put them at the
# right places in the world. Importantly note that we
# expect the name of the object to be the same as the name
Expand Down Expand Up @@ -747,7 +748,6 @@ def _get_next_nonpercept_atoms(self, obs: _SpotObservation,
}

def simulate(self, state: State, action: Action) -> State:
global _obj_name_to_sim_obj
assert isinstance(action.extra_info, (list, tuple))
action_name, action_objs, _, _, action_fn, action_fn_args = \
action.extra_info
Expand Down Expand Up @@ -897,7 +897,8 @@ def _load_task_from_json(self, json_file: Path) -> EnvironmentTask:
}
for obj, init_val in init_dict.items():
if obj in obj_to_detection_id:
init_val["object_id"] = obj_to_detection_id[obj]
init_val["object_id"] = obj_to_detection_id[
obj] # type: ignore
init_state = utils.create_state_from_dict(init_dict)
goal = self._parse_goal_from_json_dict(json_dict,
object_name_to_object,
Expand Down
2 changes: 1 addition & 1 deletion predicators/ground_truth_models/spot_env/nsrts.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ def get_nsrts(env_name: str, types: Dict[str, Type],
# If we're doing proper bilevel planning with a simulator, then
# we need to replace some of the samplers.
if not CFG.bilevel_plan_without_sim:
operator_name_to_sampler["PickObjectFromTop"]: utils.null_sampler
operator_name_to_sampler["PickObjectFromTop"] = utils.null_sampler
# NOTE: will probably have to replace all other pick ops
# similarly in the future.

Expand Down
25 changes: 11 additions & 14 deletions predicators/ground_truth_models/spot_env/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,21 +323,18 @@ def _move_to_target_policy(name: str, distance_param_idx: int,
target_height = state.get(target_obj, "height")
gaze_target = math_helpers.Vec3(target_pose.x, target_pose.y,
target_pose.z + target_height / 2)

fn = navigate_to_relative_pose_and_gaze
fn_args = (robot, rel_pose, localizer, gaze_target)
sim_fn: Callable = simulated_navigate_to_relative_pose_and_gaze
sim_fn_args: Tuple = (sim_robot,
robot_pose.get_closest_se2_transform() * rel_pose,
gaze_target)
if not do_gaze:
fn: Callable = navigate_to_relative_pose
fn_args: Tuple = (robot, rel_pose)
sim_fn: Callable = simulated_navigate_to_relative_pose
sim_fn_args: Tuple = (sim_robot,
robot_pose.get_closest_se2_transform() *
rel_pose)
else:
fn = navigate_to_relative_pose_and_gaze
fn_args = (robot, rel_pose, localizer, gaze_target)
sim_fn: Callable = simulated_navigate_to_relative_pose_and_gaze
sim_fn_args: Tuple = (sim_robot,
robot_pose.get_closest_se2_transform() *
rel_pose, gaze_target)
fn = navigate_to_relative_pose # type: ignore
fn_args = (robot, rel_pose) # type: ignore
sim_fn = simulated_navigate_to_relative_pose
sim_fn_args = (sim_robot,
robot_pose.get_closest_se2_transform() * rel_pose)

return utils.create_spot_env_action(name, objects, fn, fn_args, sim_fn,
sim_fn_args)
Expand Down

0 comments on commit 07c0472

Please sign in to comment.