Skip to content

Commit

Permalink
wrapping video generation in try except, adding confidence fusion abl…
Browse files Browse the repository at this point in the history
…ations
  • Loading branch information
naokiyokoyamabd committed Sep 12, 2023
1 parent f830958 commit 1231c78
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 1 deletion.
22 changes: 22 additions & 0 deletions zsos/mapping/value_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def __init__(
value_channels: int,
size: int = 1000,
use_max_confidence: bool = True,
fusion_type: str = "default",
):
"""
Args:
Expand All @@ -55,6 +56,7 @@ def __init__(
self._value_map = np.zeros((size, size, value_channels), np.float32)
self._value_channels = value_channels
self._use_max_confidence = use_max_confidence
self._fusion_type = fusion_type

if RECORDING:
if osp.isdir(RECORDING_DIR):
Expand Down Expand Up @@ -356,6 +358,26 @@ def _fuse_new_data(self, new_map: np.ndarray, values: np.ndarray) -> None:
f"({len(values)}). Expected {self._value_channels}."
)

if self._fusion_type == "replace":
# Ablation. The values from the current observation will overwrite any
# existing values
print("VALUE MAP ABLATION:", self._fusion_type)
new_value_map = np.zeros_like(self._value_map)
new_value_map[new_map > 0] = values
self._map[new_map > 0] = new_map[new_map > 0]
self._value_map[new_map > 0] = new_value_map[new_map > 0]
return
elif self._fusion_type == "equal_weighting":
# Ablation. Updated values will always be the mean of the current and
# new values, meaning that confidence scores are forced to be the same.
print("VALUE MAP ABLATION:", self._fusion_type)
self._map[self._map > 0] = 1
new_map[new_map > 0] = 1
else:
assert (
self._fusion_type == "default"
), f"Unknown fusion type {self._fusion_type}"

# Any values in the given map that are less confident than
# self._decision_threshold AND less than the new_map in the existing map
# will be silenced into 0s
Expand Down
5 changes: 4 additions & 1 deletion zsos/semexp_env/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,10 @@ def main():
"target_object": target_object,
}
if "VIDEO_DIR" in os.environ:
generate_video(vis_imgs, ep_id, scene_id, data)
try:
generate_video(vis_imgs, ep_id, scene_id, data)
except Exception:
print("Error generating video")
if "ZSOS_LOG_DIR" in os.environ and not is_evaluated(ep_id, scene_id):
log_episode(ep_id, scene_id, data)
break
Expand Down

0 comments on commit 1231c78

Please sign in to comment.