From fe7da3016916a5dc9b57ea98d29edc2bb3bb49b5 Mon Sep 17 00:00:00 2001 From: Ashay Athalye Date: Mon, 16 Sep 2024 15:57:13 -0400 Subject: [PATCH] Debugging progress. --- .../approaches/bilevel_planning_approach.py | 4 +--- predicators/envs/spot_env.py | 2 -- predicators/perception/spot_perceiver.py | 6 +++--- predicators/utils.py | 17 +++++++++++------ 4 files changed, 15 insertions(+), 14 deletions(-) diff --git a/predicators/approaches/bilevel_planning_approach.py b/predicators/approaches/bilevel_planning_approach.py index 5cadda34c..fc45a70de 100644 --- a/predicators/approaches/bilevel_planning_approach.py +++ b/predicators/approaches/bilevel_planning_approach.py @@ -58,9 +58,7 @@ def _solve(self, task: Task, timeout: int) -> Callable[[State], Action]: seed = self._seed + self._num_calls nsrts = self._get_current_nsrts() preds = self._get_current_predicates() - utils.abstract(task.init, preds, self._vlm) - import pdb - pdb.set_trace() + # utils.abstract(task.init, preds, self._vlm) # utils.abstract(task.init, preds, self._vlm) # Run task planning only and then greedily sample and execute in the # policy. diff --git a/predicators/envs/spot_env.py b/predicators/envs/spot_env.py index 05e6eaf16..e84aa7ae7 100644 --- a/predicators/envs/spot_env.py +++ b/predicators/envs/spot_env.py @@ -106,10 +106,8 @@ class _TruncatedSpotObservation: # # A placeholder until all predicates have classifiers # nonpercept_atoms: Set[GroundAtom] # nonpercept_predicates: Set[Predicate] - executed_skill: Optional[_Option] = None # Object detections per camera in self.rgbd_images. 
object_detections_per_camera: Dict[str, List[Tuple[ObjectDetectionID, SegmentedBoundingBox]]] - # Last skill executed_skill: Optional[_Option] = None diff --git a/predicators/perception/spot_perceiver.py b/predicators/perception/spot_perceiver.py index 62394889b..201df39f2 100644 --- a/predicators/perception/spot_perceiver.py +++ b/predicators/perception/spot_perceiver.py @@ -684,7 +684,7 @@ def reset(self, env_task: EnvironmentTask) -> Task: return Task(state, goal) def step(self, observation: Observation) -> State: - import pdb; pdb.set_trace() + # import pdb; pdb.set_trace() self._waiting_for_observation = False self._robot = observation.robot img_objects = observation.rgbd_images # RGBDImage objects @@ -714,7 +714,7 @@ def step(self, observation: Observation) -> State: draw.rectangle(text_bbox, fill='green') draw.text((x0 + 1, y0 - 1.5*text_height), text, fill='white', font=font) - import pdb; pdb.set_trace() + # import pdb; pdb.set_trace() import PIL from PIL import ImageDraw annotated_pil_imgs = [] @@ -730,8 +730,8 @@ def step(self, observation: Observation) -> State: self._curr_state = self._create_state() self._curr_state.simulator_state["images"] = annotated_imgs ret_state = self._curr_state.copy() - ret_state.simulator_state["state_history"] = list(self._state_history) self._state_history.append(ret_state) + ret_state.simulator_state["state_history"] = list(self._state_history) self._executed_skill_history.append(observation.executed_skill) ret_state.simulator_state["skill_history"] = list(self._executed_skill_history) return ret_state diff --git a/predicators/utils.py b/predicators/utils.py index b13201b3c..f4e32a3b7 100644 --- a/predicators/utils.py +++ b/predicators/utils.py @@ -2499,6 +2499,7 @@ def get_prompt_for_vlm_state_labelling( imgs_history: List[List[PIL.Image.Image]], cropped_imgs_history: List[List[PIL.Image.Image]], skill_history: List[Action]) -> Tuple[str, List[PIL.Image.Image]]: + # import pdb; pdb.set_trace() """Prompt for generating 
labels for an entire trajectory. Similar to the above prompting method, this outputs a list of prompts to label the state at each timestep of traj with atom values). @@ -2508,7 +2509,7 @@ def get_prompt_for_vlm_state_labelling( """ # Load the pre-specified prompt. filepath_prefix = get_path_to_predicators_root() + \ - "/predicators/datasets/vlm_input_data_prompts/atom_proposal/" + "/predicators/datasets/vlm_input_data_prompts/atom_labelling/" try: with open(filepath_prefix + CFG.grammar_search_vlm_atom_label_prompt_type + ".txt", @@ -2516,6 +2517,7 @@ def get_prompt_for_vlm_state_labelling( encoding="utf-8") as f: prompt = f.read() except FileNotFoundError: + import pdb; pdb.set_trace() raise ValueError("Unknown VLM prompting option " + f"{CFG.grammar_search_vlm_atom_label_prompt_type}") # The prompt ends with a section for 'Predicates', so list these. @@ -2583,9 +2585,9 @@ def query_vlm_for_atom_vals( state.simulator_state["images"] for state in previous_states ] vlm_atoms = sorted(vlm_atoms) - atom_queries_str = [atom.get_vlm_query_str() for atom in vlm_atoms] + atom_queries_list = [atom.get_vlm_query_str() for atom in vlm_atoms] vlm_query_str, imgs = get_prompt_for_vlm_state_labelling( - CFG.vlm_test_time_atom_label_prompt_type, atom_queries_str, + CFG.vlm_test_time_atom_label_prompt_type, atom_queries_list, state.simulator_state["vlm_atoms_history"], state_imgs_history, [], state.simulator_state["skill_history"]) if vlm is None: @@ -2600,21 +2602,24 @@ def query_vlm_for_atom_vals( assert len(vlm_output) == 1 vlm_output_str = vlm_output[0] print(f"VLM output: {vlm_output_str}") - all_atom_queries = atom_queries_str.strip().split("\n") all_vlm_responses = vlm_output_str.strip().split("\n") # NOTE: this assumption is likely too brittle; if this is breaking, feel # free to remove/adjust this and change the below parsing loop accordingly! 
- assert len(all_atom_queries) == len(all_vlm_responses) + assert len(atom_queries_list) == len(all_vlm_responses) for i, (atom_query, curr_vlm_output_line) in enumerate( - zip(all_atom_queries, all_vlm_responses)): + zip(atom_queries_list, all_vlm_responses)): assert atom_query + ":" in curr_vlm_output_line assert "." in curr_vlm_output_line period_idx = curr_vlm_output_line.find(".") if curr_vlm_output_line[len(atom_query + ":"):period_idx].lower().strip() == "true": true_atoms.add(vlm_atoms[i]) + + breakpoint() # Add the text of the VLM's response to the state, to be used in the future! + # REMOVE THIS -> AND PUT IT IN THE PERCEIVER state.simulator_state["vlm_atoms_history"].append(all_vlm_responses) + return true_atoms