make a separate function for vlm predicate classifier evaluation
lf-zhao committed May 1, 2024
1 parent 1c82c44 commit eeb1583
Showing 1 changed file with 94 additions and 86 deletions.
180 changes: 94 additions & 86 deletions predicators/envs/spot_env.py
@@ -308,7 +308,7 @@ def percept_predicates(self) -> Set[Predicate]:
def action_space(self) -> Box:
# The action space is effectively empty because only the extra info
# part of actions is used.
return Box(0, 1, (0, ))
return Box(0, 1, (0,))

@abc.abstractmethod
def _get_dry_task(self, train_or_test: str,
@@ -346,7 +346,7 @@ def _get_next_dry_observation(
nonpercept_atoms)

if action_name in [
"MoveToReachObject", "MoveToReadySweep", "MoveToBodyViewObject"
"MoveToReachObject", "MoveToReadySweep", "MoveToBodyViewObject"
]:
robot_rel_se2_pose = action_args[1]
return _dry_simulate_move_to_reach_obj(obs, robot_rel_se2_pose,
@@ -713,7 +713,7 @@ def _build_realworld_observation(
for swept_object in swept_objects:
if swept_object not in all_objects_in_view:
if container is not None and container in \
all_objects_in_view:
all_objects_in_view:
while True:
msg = (
f"\nATTENTION! The {swept_object.name} was not "
@@ -998,7 +998,7 @@ def _actively_construct_initial_object_views(
return obj_to_se3_pose

def _run_init_search_for_objects(
self, detection_ids: Set[ObjectDetectionID]
self, detection_ids: Set[ObjectDetectionID]
) -> Dict[ObjectDetectionID, math_helpers.SE3Pose]:
"""Have the hand look down from high up at first."""
assert self._robot is not None
@@ -1066,10 +1066,39 @@ def _generate_goal_description(self) -> GoalDescription:

# Provide some visual examples when needed
vlm_predicate_eval_prompt_example = ""

# TODO: Next, try to include visual hints via segmentation ("Set of Masks")


def vlm_predicate_classify(question: str, state: State) -> bool:
"""Use VLM to evaluate (classify) a predicate in a given state."""
full_prompt = vlm_predicate_eval_prompt_prefix.format(
question=question
)
images_dict: Dict[str, RGBDImageWithContext] = state.camera_images
images = [PIL.Image.fromarray(v.rotated_rgb) for _, v in images_dict.items()]

logging.info(f"VLM predicate evaluation for: {question}")
logging.info(f"Prompt: {full_prompt}")

vlm_responses = vlm.sample_completions(
prompt=full_prompt,
imgs=images,
temperature=0.2,
seed=int(time.time()),
num_completions=1,
)
logging.info(f"VLM response 0: {vlm_responses[0]}")

vlm_response = vlm_responses[0].strip().lower()
if vlm_response == "yes":
return True
elif vlm_response == "no":
return False
else:
logging.error(f"VLM response not understood: {vlm_response}. Treat as False.")
return False
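
To make the intended call pattern concrete, here is a minimal sketch of a classifier delegating to the new helper. The NextTo predicate, its question wording, and the geometric fallback are hypothetical illustrations, not part of this commit; the real call sites (_on_classifier, _inside_classifier, _blocking_classifier) appear in the hunks below.

def _next_to_classifier(state: State, objects: Sequence[Object]) -> bool:
    """Hypothetical example classifier; not part of this commit."""
    obj_a, obj_b = objects
    currently_visible = all([o in state.visible_objects for o in objects])
    if CFG.spot_vlm_eval_predicate and not currently_visible:
        # Same open issue as the classifiers below: previous atoms are not
        # yet carried over, so this case is unhandled.
        raise NotImplementedError
    if CFG.spot_vlm_eval_predicate and currently_visible:
        # Build the natural-language question and defer to the shared helper.
        predicate_str = f"""
        NextTo({obj_a}, {obj_b})
        (Whether {obj_a} is next to {obj_b} in the image?)
        """
        return vlm_predicate_classify(predicate_str, state)
    # Geometric fallback when VLM evaluation is disabled.
    dist = np.sqrt((state.get(obj_a, "x") - state.get(obj_b, "x")) ** 2 +
                   (state.get(obj_a, "y") - state.get(obj_b, "y")) ** 2)
    return dist <= 0.5  # 0.5 m is an illustrative threshold only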


###############################################################################
# Shared Types, Predicates, Operators #
###############################################################################
@@ -1133,8 +1162,8 @@ def _object_in_xy_classifier(state: State,

spot, = state.get_objects(_robot_type)
if obj1.is_instance(_movable_object_type) and \
_is_placeable_classifier(state, [obj1]) and \
_holding_classifier(state, [spot, obj1]):
_is_placeable_classifier(state, [obj1]) and \
_holding_classifier(state, [spot, obj1]):
return False

# Check that the center of the object is contained within the surface in
@@ -1150,8 +1179,8 @@ def _object_in_xy_classifier(state: State,

def _on_classifier(state: State, objects: Sequence[Object]) -> bool:
obj_on, obj_surface = objects
currently_visible = all([o in state.visible_objects for o in objects])

currently_visible = all([o in state.visible_objects for o in objects])
# If the objects are not all visible and VLM evaluation is enabled,
# then use the predicate values from the previous time step
if CFG.spot_vlm_eval_predicate and not currently_visible:
@@ -1160,37 +1189,11 @@ def _on_classifier(state: State, objects: Sequence[Object]) -> bool:

# Call VLM to evaluate predicate value
elif CFG.spot_vlm_eval_predicate and currently_visible:
predicate_str = f"On({obj_on}, {obj_surface})"
full_prompt = vlm_predicate_eval_prompt_prefix.format(
question=predicate_str
)

images_dict: Dict[str, RGBDImageWithContext] = state.camera_images
images = [PIL.Image.fromarray(v.rotated_rgb) for _, v in images_dict.items()]

# Logging: prompt
logging.info(f"VLM predicate evaluation for: {predicate_str}")
logging.info(f"Prompt: {full_prompt}")

vlm_responses = vlm.sample_completions(
prompt=full_prompt,
imgs=images,
temperature=0.2,
seed=int(time.time()),
num_completions=1,
)

# Logging
logging.info(f"VLM response 0: {vlm_responses[0]}")

vlm_response = vlm_responses[0].strip().lower()
if vlm_response == "yes":
return True
elif vlm_response == "no":
return False
else:
logging.error(f"VLM response not understood: {vlm_response}. Treat as False.")
return False
predicate_str = f"""
On({obj_on}, {obj_surface})
(Whether {obj_on} is on {obj_surface} in the image?)
"""
return vlm_predicate_classify(predicate_str, state)

else:
# Check that the bottom of the object is close to the top of the surface.
@@ -1217,53 +1220,43 @@ def _top_above_classifier(state: State, objects: Sequence[Object]) -> bool:

def _inside_classifier(state: State, objects: Sequence[Object]) -> bool:
obj_in, obj_container = objects

currently_visible = all([o in state.visible_objects for o in objects])
# If the objects are not all visible and VLM evaluation is enabled,
# then use the predicate values from the previous time step
if CFG.spot_vlm_eval_predicate and not currently_visible:
# TODO: add all previous atoms to the state
raise NotImplementedError

print(currently_visible, state)

# if CFG.spot_vlm_eval_predicate and not currently_visible:
# # TODO: add all previous atoms to the state
# # TODO: then we just use the atom value from the last state
# raise NotImplementedError
# elif CFG.spot_vlm_eval_predicate and currently_visible:
# # TODO call VLM to evaluate predicate value
# full_prompt = vlm_predicate_eval_prompt_prefix.format(
# question=f"Inside({obj_in}, {obj_container})"
# )
# images = state.camera_images
#
# vlm_responses = vlm.sample_completions(
# prompt=full_prompt,
# imgs=images,
# temperature=0.2,
# seed=int(time.time()),
# num_completions=1,
# )
# vlm_response = vlm_responses[0].strip().lower()
# raise NotImplementedError
#
# else:

if not _object_in_xy_classifier(
state, obj_in, obj_container, buffer=_INSIDE_SURFACE_BUFFER):
return False
# Call VLM to evaluate predicate value
elif CFG.spot_vlm_eval_predicate and currently_visible:
predicate_str = f"""
Inside({obj_in}, {obj_container})
(Whether {obj_in} is inside {obj_container} in the image?)
"""
return vlm_predicate_classify(predicate_str, state)

obj_z = state.get(obj_in, "z")
obj_half_height = state.get(obj_in, "height") / 2
obj_bottom = obj_z - obj_half_height
obj_top = obj_z + obj_half_height
else:
if not _object_in_xy_classifier(
state, obj_in, obj_container, buffer=_INSIDE_SURFACE_BUFFER):
return False

container_z = state.get(obj_container, "z")
container_half_height = state.get(obj_container, "height") / 2
container_bottom = container_z - container_half_height
container_top = container_z + container_half_height
obj_z = state.get(obj_in, "z")
obj_half_height = state.get(obj_in, "height") / 2
obj_bottom = obj_z - obj_half_height
obj_top = obj_z + obj_half_height

# Check that the bottom is "above" the bottom of the container.
if obj_bottom < container_bottom - _INSIDE_Z_THRESHOLD:
return False
container_z = state.get(obj_container, "z")
container_half_height = state.get(obj_container, "height") / 2
container_bottom = container_z - container_half_height
container_top = container_z + container_half_height

# Check that the top is "below" the top of the container.
return obj_top < container_top + _INSIDE_Z_THRESHOLD
# Check that the bottom is "above" the bottom of the container.
if obj_bottom < container_bottom - _INSIDE_Z_THRESHOLD:
return False

# Check that the top is "below" the top of the container.
return obj_top < container_top + _INSIDE_Z_THRESHOLD
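
As a quick sanity check on the geometric branch above, here is a tiny numeric sketch of the two z-bound tests (the 0.05 m threshold is an assumed value for illustration only; the real _INSIDE_Z_THRESHOLD is defined elsewhere in this file):

# Container spans z in [0.20, 0.40] m; object spans z in [0.27, 0.37] m.
container_bottom, container_top = 0.20, 0.40
obj_bottom, obj_top = 0.27, 0.37
_assumed_threshold = 0.05  # stand-in for _INSIDE_Z_THRESHOLD
# Mirrors the two checks above: bottom not below the container bottom
# (minus the threshold), and top below the container top (plus the threshold).
is_inside_z = (obj_bottom >= container_bottom - _assumed_threshold
               and obj_top < container_top + _assumed_threshold)
assert is_inside_z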


def _not_inside_any_container_classifier(state: State,
@@ -1312,8 +1305,8 @@ def in_general_view_classifier(state: State,
def _obj_reachable_from_spot_pose(spot_pose: math_helpers.SE3Pose,
obj_position: math_helpers.Vec3) -> bool:
is_xy_near = np.sqrt(
(spot_pose.x - obj_position.x)**2 +
(spot_pose.y - obj_position.y)**2) <= _REACHABLE_THRESHOLD
(spot_pose.x - obj_position.x) ** 2 +
(spot_pose.y - obj_position.y) ** 2) <= _REACHABLE_THRESHOLD

# Compute angle between spot's forward direction and the line from
# spot to the object.
@@ -1355,6 +1348,21 @@ def _blocking_classifier(state: State, objects: Sequence[Object]) -> bool:
if blocker_obj == blocked_obj:
return False

currently_visible = all([o in state.visible_objects for o in objects])
# If the objects are not all visible and VLM evaluation is enabled,
# then use the predicate values from the previous time step
if CFG.spot_vlm_eval_predicate and not currently_visible:
# TODO: add all previous atoms to the state
raise NotImplementedError

# Call VLM to evaluate predicate value
elif CFG.spot_vlm_eval_predicate and currently_visible:
predicate_str = f"""
(Whether {blocker_obj} is blocking {blocked_obj} for further manipulation in the image?)
Blocking({blocker_obj}, {blocked_obj})
"""
return vlm_predicate_classify(predicate_str, state)

# Only consider draggable (non-placeable, movable) objects to be blockers.
if not blocker_obj.is_instance(_movable_object_type):
return False
@@ -1369,7 +1377,7 @@ def _blocking_classifier(state: State, objects: Sequence[Object]) -> bool:

spot, = state.get_objects(_robot_type)
if blocked_obj.is_instance(_movable_object_type) and \
_holding_classifier(state, [spot, blocked_obj]):
_holding_classifier(state, [spot, blocked_obj]):
return False

# Draw a line between blocked and the robot’s current pose.
@@ -1439,8 +1447,8 @@ def _container_adjacent_to_surface_for_sweeping(container: Object,
container_x = state.get(container, "x")
container_y = state.get(container, "y")

dist = np.sqrt((expected_x - container_x)**2 +
(expected_y - container_y)**2)
dist = np.sqrt((expected_x - container_x) ** 2 +
(expected_y - container_y) ** 2)

return dist <= _CONTAINER_SWEEP_READY_BUFFER

@@ -2389,7 +2397,7 @@ def _dry_simulate_sweep_into_container(
x = container_pose.x + dx
y = container_pose.y + dy
z = container_pose.z
dist_to_container = (dx**2 + dy**2)**0.5
dist_to_container = (dx ** 2 + dy ** 2) ** 0.5
assert dist_to_container > (container_radius +
_INSIDE_SURFACE_BUFFER)

