Debugging progress.
ashay-bdai committed Sep 16, 2024
1 parent 6e25b72 commit fe7da30
Showing 4 changed files with 15 additions and 14 deletions.
4 changes: 1 addition & 3 deletions predicators/approaches/bilevel_planning_approach.py
@@ -58,9 +58,7 @@ def _solve(self, task: Task, timeout: int) -> Callable[[State], Action]:
         seed = self._seed + self._num_calls
         nsrts = self._get_current_nsrts()
         preds = self._get_current_predicates()
-        utils.abstract(task.init, preds, self._vlm)
-        import pdb
-        pdb.set_trace()
+        # utils.abstract(task.init, preds, self._vlm)
         # utils.abstract(task.init, preds, self._vlm)
         # Run task planning only and then greedily sample and execute in the
         # policy.
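
Note: with both copies of the `utils.abstract(task.init, preds, self._vlm)` call now commented out, `_solve` skips abstracting the initial state entirely at this point. If the trace is still needed from run to run, a config-gated `breakpoint()` is a lighter pattern than commenting `pdb` lines in and out; a sketch (the `CFG.debug` flag is hypothetical, not an existing option):

    # Hypothetical flag-gated trace, instead of toggling pdb lines by hand.
    if CFG.debug:
        breakpoint()
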
2 changes: 0 additions & 2 deletions predicators/envs/spot_env.py
@@ -106,10 +106,8 @@ class _TruncatedSpotObservation:
     # # A placeholder until all predicates have classifiers
     # nonpercept_atoms: Set[GroundAtom]
     # nonpercept_predicates: Set[Predicate]
-    executed_skill: Optional[_Option] = None
     # Object detections per camera in self.rgbd_images.
     object_detections_per_camera: Dict[str, List[Tuple[ObjectDetectionID, SegmentedBoundingBox]]]
     # Last skill
     executed_skill: Optional[_Option] = None
-

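Note: dropping the first `executed_skill` declaration is likely more than cleanup. Assuming `_TruncatedSpotObservation` is a `@dataclass` and `object_detections_per_camera` has no default, the duplicate left a defaulted field declared ahead of a non-default one, which fails as soon as the class is defined. A minimal sketch of that failure mode (hypothetical names, not the repo's code):

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class Obs:
        executed_skill: Optional[str] = None  # defaulted field first...
        detections: dict  # ...raises TypeError: non-default argument
                          # 'detections' follows default argument
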
6 changes: 3 additions & 3 deletions predicators/perception/spot_perceiver.py
@@ -684,7 +684,7 @@ def reset(self, env_task: EnvironmentTask) -> Task:
         return Task(state, goal)

     def step(self, observation: Observation) -> State:
-        import pdb; pdb.set_trace()
+        # import pdb; pdb.set_trace()
         self._waiting_for_observation = False
         self._robot = observation.robot
         img_objects = observation.rgbd_images  # RGBDImage objects
@@ -714,7 +714,7 @@ def step(self, observation: Observation) -> State:
             draw.rectangle(text_bbox, fill='green')
             draw.text((x0 + 1, y0 - 1.5*text_height), text, fill='white', font=font)

-        import pdb; pdb.set_trace()
+        # import pdb; pdb.set_trace()
         import PIL
         from PIL import ImageDraw
         annotated_pil_imgs = []
@@ -730,8 +730,8 @@
         self._curr_state = self._create_state()
         self._curr_state.simulator_state["images"] = annotated_imgs
         ret_state = self._curr_state.copy()
-        ret_state.simulator_state["state_history"] = list(self._state_history)
         self._state_history.append(ret_state)
+        ret_state.simulator_state["state_history"] = list(self._state_history)
         self._executed_skill_history.append(observation.executed_skill)
         ret_state.simulator_state["skill_history"] = list(self._executed_skill_history)
         return ret_state
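
Note on the reorder above: `list(self._state_history)` snapshots the history at the moment it is called, so moving the copy below `self._state_history.append(ret_state)` means the returned state's `state_history` now includes that state itself. A minimal sketch of the difference:

    history = ["s0", "s1"]
    snapshot = list(history)  # old order: current state absent -> ["s0", "s1"]
    history.append("s2")
    snapshot = list(history)  # new order: current state present -> ["s0", "s1", "s2"]
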
17 changes: 11 additions & 6 deletions predicators/utils.py
@@ -2499,6 +2499,7 @@ def get_prompt_for_vlm_state_labelling(
         imgs_history: List[List[PIL.Image.Image]],
         cropped_imgs_history: List[List[PIL.Image.Image]],
         skill_history: List[Action]) -> Tuple[str, List[PIL.Image.Image]]:
+    # import pdb; pdb.set_trace()
     """Prompt for generating labels for an entire trajectory. Similar to the
     above prompting method, this outputs a list of prompts to label the state
     at each timestep of traj with atom values).
@@ -2508,14 +2509,15 @@
     """
     # Load the pre-specified prompt.
     filepath_prefix = get_path_to_predicators_root() + \
-        "/predicators/datasets/vlm_input_data_prompts/atom_proposal/"
+        "/predicators/datasets/vlm_input_data_prompts/atom_labelling/"
     try:
         with open(filepath_prefix +
                   CFG.grammar_search_vlm_atom_label_prompt_type + ".txt",
                   "r",
                   encoding="utf-8") as f:
             prompt = f.read()
     except FileNotFoundError:
+        import pdb; pdb.set_trace()
         raise ValueError("Unknown VLM prompting option " +
                          f"{CFG.grammar_search_vlm_atom_label_prompt_type}")
     # The prompt ends with a section for 'Predicates', so list these.
@@ -2583,9 +2585,9 @@ def query_vlm_for_atom_vals(
         state.simulator_state["images"] for state in previous_states
     ]
     vlm_atoms = sorted(vlm_atoms)
-    atom_queries_str = [atom.get_vlm_query_str() for atom in vlm_atoms]
+    atom_queries_list = [atom.get_vlm_query_str() for atom in vlm_atoms]
     vlm_query_str, imgs = get_prompt_for_vlm_state_labelling(
-        CFG.vlm_test_time_atom_label_prompt_type, atom_queries_str,
+        CFG.vlm_test_time_atom_label_prompt_type, atom_queries_list,
         state.simulator_state["vlm_atoms_history"], state_imgs_history, [],
         state.simulator_state["skill_history"])
     if vlm is None:
@@ -2600,21 +2602,24 @@
     assert len(vlm_output) == 1
     vlm_output_str = vlm_output[0]
     print(f"VLM output: {vlm_output_str}")
-    all_atom_queries = atom_queries_str.strip().split("\n")
     all_vlm_responses = vlm_output_str.strip().split("\n")
     # NOTE: this assumption is likely too brittle; if this is breaking, feel
     # free to remove/adjust this and change the below parsing loop accordingly!
-    assert len(all_atom_queries) == len(all_vlm_responses)
+    assert len(atom_queries_list) == len(all_vlm_responses)
     for i, (atom_query, curr_vlm_output_line) in enumerate(
-            zip(all_atom_queries, all_vlm_responses)):
+            zip(atom_queries_list, all_vlm_responses)):
         assert atom_query + ":" in curr_vlm_output_line
         assert "." in curr_vlm_output_line
         period_idx = curr_vlm_output_line.find(".")
         if curr_vlm_output_line[len(atom_query +
                                     ":"):period_idx].lower().strip() == "true":
             true_atoms.add(vlm_atoms[i])

+    breakpoint()
+    # Add the text of the VLM's response to the state, to be used in the future!
+    # REMOVE THIS -> AND PUT IT IN THE PERCEIVER
+    state.simulator_state["vlm_atoms_history"].append(all_vlm_responses)

     return true_atoms


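Note: the `atom_queries_str` -> `atom_queries_list` rename in the two hunks above also removes a latent crash: the deleted line called `str` methods on what is actually a list of query strings, so it would fail on first execution. A minimal illustration (values invented):

    atom_queries = ["Holding(robot, cup)", "On(cup, table)"]  # a list, not a str
    atom_queries.strip()  # AttributeError: 'list' object has no attribute 'strip'

Asserting against `len(atom_queries_list)` directly makes the dead `all_atom_queries` variable unnecessary.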
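Side note: `breakpoint()` (built into Python 3.7+) is equivalent to `import pdb; pdb.set_trace()` by default, and it can be disabled without editing code, e.g. by running with `PYTHONBREAKPOINT=0`, which makes temporary traces like the one added above cheaper to leave in while debugging.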
