From eeb1583ac9fb9e2f9334f9b3a3ba9eebf279de8a Mon Sep 17 00:00:00 2001 From: Linfeng Date: Wed, 1 May 2024 15:44:58 -0400 Subject: [PATCH] make a separate function for vlm predicate classifier evaluation --- predicators/envs/spot_env.py | 180 ++++++++++++++++++----------------- 1 file changed, 94 insertions(+), 86 deletions(-) diff --git a/predicators/envs/spot_env.py b/predicators/envs/spot_env.py index 7b69f13fff..98c431bb22 100644 --- a/predicators/envs/spot_env.py +++ b/predicators/envs/spot_env.py @@ -308,7 +308,7 @@ def percept_predicates(self) -> Set[Predicate]: def action_space(self) -> Box: # The action space is effectively empty because only the extra info # part of actions are used. - return Box(0, 1, (0, )) + return Box(0, 1, (0,)) @abc.abstractmethod def _get_dry_task(self, train_or_test: str, @@ -346,7 +346,7 @@ def _get_next_dry_observation( nonpercept_atoms) if action_name in [ - "MoveToReachObject", "MoveToReadySweep", "MoveToBodyViewObject" + "MoveToReachObject", "MoveToReadySweep", "MoveToBodyViewObject" ]: robot_rel_se2_pose = action_args[1] return _dry_simulate_move_to_reach_obj(obs, robot_rel_se2_pose, @@ -713,7 +713,7 @@ def _build_realworld_observation( for swept_object in swept_objects: if swept_object not in all_objects_in_view: if container is not None and container in \ - all_objects_in_view: + all_objects_in_view: while True: msg = ( f"\nATTENTION! The {swept_object.name} was not " @@ -998,7 +998,7 @@ def _actively_construct_initial_object_views( return obj_to_se3_pose def _run_init_search_for_objects( - self, detection_ids: Set[ObjectDetectionID] + self, detection_ids: Set[ObjectDetectionID] ) -> Dict[ObjectDetectionID, math_helpers.SE3Pose]: """Have the hand look down from high up at first.""" assert self._robot is not None @@ -1066,10 +1066,39 @@ def _generate_goal_description(self) -> GoalDescription: # Provide some visual examples when needed vlm_predicate_eval_prompt_example = "" - # TODO: Next, try include visual hints via segmentation ("Set of Masks") +def vlm_predicate_classify(question: str, state: State) -> bool: + """Use VLM to evaluate (classify) a predicate in a given state.""" + full_prompt = vlm_predicate_eval_prompt_prefix.format( + question=question + ) + images_dict: Dict[str, RGBDImageWithContext] = state.camera_images + images = [PIL.Image.fromarray(v.rotated_rgb) for _, v in images_dict.items()] + + logging.info(f"VLM predicate evaluation for: {question}") + logging.info(f"Prompt: {full_prompt}") + + vlm_responses = vlm.sample_completions( + prompt=full_prompt, + imgs=images, + temperature=0.2, + seed=int(time.time()), + num_completions=1, + ) + logging.info(f"VLM response 0: {vlm_responses[0]}") + + vlm_response = vlm_responses[0].strip().lower() + if vlm_response == "yes": + return True + elif vlm_response == "no": + return False + else: + logging.error(f"VLM response not understood: {vlm_response}. Treat as False.") + return False + + ############################################################################### # Shared Types, Predicates, Operators # ############################################################################### @@ -1133,8 +1162,8 @@ def _object_in_xy_classifier(state: State, spot, = state.get_objects(_robot_type) if obj1.is_instance(_movable_object_type) and \ - _is_placeable_classifier(state, [obj1]) and \ - _holding_classifier(state, [spot, obj1]): + _is_placeable_classifier(state, [obj1]) and \ + _holding_classifier(state, [spot, obj1]): return False # Check that the center of the object is contained within the surface in @@ -1150,8 +1179,8 @@ def _object_in_xy_classifier(state: State, def _on_classifier(state: State, objects: Sequence[Object]) -> bool: obj_on, obj_surface = objects - currently_visible = all([o in state.visible_objects for o in objects]) + currently_visible = all([o in state.visible_objects for o in objects]) # If object not all visible and choose to use VLM, # then use predicate values of previous time step if CFG.spot_vlm_eval_predicate and not currently_visible: @@ -1160,37 +1189,11 @@ def _on_classifier(state: State, objects: Sequence[Object]) -> bool: # Call VLM to evaluate predicate value elif CFG.spot_vlm_eval_predicate and currently_visible: - predicate_str = f"On({obj_on}, {obj_surface})" - full_prompt = vlm_predicate_eval_prompt_prefix.format( - question=predicate_str - ) - - images_dict: Dict[str, RGBDImageWithContext] = state.camera_images - images = [PIL.Image.fromarray(v.rotated_rgb) for _, v in images_dict.items()] - - # Logging: prompt - logging.info(f"VLM predicate evaluation for: {predicate_str}") - logging.info(f"Prompt: {full_prompt}") - - vlm_responses = vlm.sample_completions( - prompt=full_prompt, - imgs=images, - temperature=0.2, - seed=int(time.time()), - num_completions=1, - ) - - # Logging - logging.info(f"VLM response 0: {vlm_responses[0]}") - - vlm_response = vlm_responses[0].strip().lower() - if vlm_response == "yes": - return True - elif vlm_response == "no": - return False - else: - logging.error(f"VLM response not understood: {vlm_response}. Treat as False.") - return False + predicate_str = f""" + On({obj_on}, {obj_surface}) + (Whether {obj_on} is on {obj_surface} in the image?) + """ + return vlm_predicate_classify(predicate_str, state) else: # Check that the bottom of the object is close to the top of the surface. @@ -1217,53 +1220,43 @@ def _top_above_classifier(state: State, objects: Sequence[Object]) -> bool: def _inside_classifier(state: State, objects: Sequence[Object]) -> bool: obj_in, obj_container = objects + currently_visible = all([o in state.visible_objects for o in objects]) + # If object not all visible and choose to use VLM, + # then use predicate values of previous time step + if CFG.spot_vlm_eval_predicate and not currently_visible: + # TODO: add all previous atoms to the state + raise NotImplementedError - print(currently_visible, state) - - # if CFG.spot_vlm_eval_predicate and not currently_visible: - # # TODO: add all previous atoms to the state - # # TODO: then we just use the atom value from the last state - # raise NotImplementedError - # elif CFG.spot_vlm_eval_predicate and currently_visible: - # # TODO call VLM to evaluate predicate value - # full_prompt = vlm_predicate_eval_prompt_prefix.format( - # question=f"Inside({obj_in}, {obj_container})" - # ) - # images = state.camera_images - # - # vlm_responses = vlm.sample_completions( - # prompt=full_prompt, - # imgs=images, - # temperature=0.2, - # seed=int(time.time()), - # num_completions=1, - # ) - # vlm_response = vlm_responses[0].strip().lower() - # raise NotImplementedError - # - # else: - - if not _object_in_xy_classifier( - state, obj_in, obj_container, buffer=_INSIDE_SURFACE_BUFFER): - return False + # Call VLM to evaluate predicate value + elif CFG.spot_vlm_eval_predicate and currently_visible: + predicate_str = f""" + Inside({obj_in}, {obj_container}) + (Whether {obj_in} is inside {obj_container} in the image?) + """ + return vlm_predicate_classify(predicate_str, state) - obj_z = state.get(obj_in, "z") - obj_half_height = state.get(obj_in, "height") / 2 - obj_bottom = obj_z - obj_half_height - obj_top = obj_z + obj_half_height + else: + if not _object_in_xy_classifier( + state, obj_in, obj_container, buffer=_INSIDE_SURFACE_BUFFER): + return False - container_z = state.get(obj_container, "z") - container_half_height = state.get(obj_container, "height") / 2 - container_bottom = container_z - container_half_height - container_top = container_z + container_half_height + obj_z = state.get(obj_in, "z") + obj_half_height = state.get(obj_in, "height") / 2 + obj_bottom = obj_z - obj_half_height + obj_top = obj_z + obj_half_height - # Check that the bottom is "above" the bottom of the container. - if obj_bottom < container_bottom - _INSIDE_Z_THRESHOLD: - return False + container_z = state.get(obj_container, "z") + container_half_height = state.get(obj_container, "height") / 2 + container_bottom = container_z - container_half_height + container_top = container_z + container_half_height - # Check that the top is "below" the top of the container. - return obj_top < container_top + _INSIDE_Z_THRESHOLD + # Check that the bottom is "above" the bottom of the container. + if obj_bottom < container_bottom - _INSIDE_Z_THRESHOLD: + return False + + # Check that the top is "below" the top of the container. + return obj_top < container_top + _INSIDE_Z_THRESHOLD def _not_inside_any_container_classifier(state: State, @@ -1312,8 +1305,8 @@ def in_general_view_classifier(state: State, def _obj_reachable_from_spot_pose(spot_pose: math_helpers.SE3Pose, obj_position: math_helpers.Vec3) -> bool: is_xy_near = np.sqrt( - (spot_pose.x - obj_position.x)**2 + - (spot_pose.y - obj_position.y)**2) <= _REACHABLE_THRESHOLD + (spot_pose.x - obj_position.x) ** 2 + + (spot_pose.y - obj_position.y) ** 2) <= _REACHABLE_THRESHOLD # Compute angle between spot's forward direction and the line from # spot to the object. @@ -1355,6 +1348,21 @@ def _blocking_classifier(state: State, objects: Sequence[Object]) -> bool: if blocker_obj == blocked_obj: return False + currently_visible = all([o in state.visible_objects for o in objects]) + # If object not all visible and choose to use VLM, + # then use predicate values of previous time step + if CFG.spot_vlm_eval_predicate and not currently_visible: + # TODO: add all previous atoms to the state + raise NotImplementedError + + # Call VLM to evaluate predicate value + elif CFG.spot_vlm_eval_predicate and currently_visible: + predicate_str = f""" + (Whether {blocker_obj} is blocking {blocked_obj} for further manipulation in the image?) + Blocking({blocker_obj}, {blocked_obj}) + """ + return vlm_predicate_classify(predicate_str, state) + # Only consider draggable (non-placeable, movable) objects to be blockers. if not blocker_obj.is_instance(_movable_object_type): return False @@ -1369,7 +1377,7 @@ def _blocking_classifier(state: State, objects: Sequence[Object]) -> bool: spot, = state.get_objects(_robot_type) if blocked_obj.is_instance(_movable_object_type) and \ - _holding_classifier(state, [spot, blocked_obj]): + _holding_classifier(state, [spot, blocked_obj]): return False # Draw a line between blocked and the robot’s current pose. @@ -1439,8 +1447,8 @@ def _container_adjacent_to_surface_for_sweeping(container: Object, container_x = state.get(container, "x") container_y = state.get(container, "y") - dist = np.sqrt((expected_x - container_x)**2 + - (expected_y - container_y)**2) + dist = np.sqrt((expected_x - container_x) ** 2 + + (expected_y - container_y) ** 2) return dist <= _CONTAINER_SWEEP_READY_BUFFER @@ -2389,7 +2397,7 @@ def _dry_simulate_sweep_into_container( x = container_pose.x + dx y = container_pose.y + dy z = container_pose.z - dist_to_container = (dx**2 + dy**2)**0.5 + dist_to_container = (dx ** 2 + dy ** 2) ** 0.5 assert dist_to_container > (container_radius + _INSIDE_SURFACE_BUFFER)