diff --git a/.gitignore b/.gitignore index 50b9875e..65ed4224 100644 --- a/.gitignore +++ b/.gitignore @@ -1,10 +1,371 @@ -.vscode/ +# READ THIS BEFORE YOU REFACTOR ME +# +# setup.py uses the list of patterns in this file to decide +# what to delete, but it's not 100% sound. So, for example, +# if you delete aten/build/ because it's redundant with build/, +# aten/build/ will stop being cleaned. So be careful when +# refactoring this file! + +experiments/experiments_data/checkpoints/* +experiments/experiments_data/tmp/* +experiments/experiments_data/datasets/* +amg_example/checkpoints/* + +## PyTorch + +.coverage +coverage.xml +.dmypy.json +.gradle +.hypothesis +.mypy_cache +/.extracted_scripts/ +**/.pytorch_specified_test_cases.csv +**/.pytorch-disabled-tests.json +**/.pytorch-slow-tests.json +**/.pytorch-test-times.json +**/.pytorch-test-file-ratings.json +*/*.pyc +*/*.so* +*/**/__pycache__ +*/**/*.dylib* +*/**/*.pyc +*/**/*.pyd +*/**/*.so* +*/**/**/*.pyc +*/**/**/**/*.pyc +*/**/**/**/**/*.pyc +aten/build/ +aten/src/ATen/Config.h +aten/src/ATen/cuda/CUDAConfig.h +benchmarks/.data +caffe2/cpp_test/ +dist/ +docs/build/ +docs/cpp/src +docs/src/**/* +docs/cpp/build +docs/cpp/source/api +docs/cpp/source/html/ +docs/cpp/source/latex/ +docs/source/compile/generated/ +docs/source/generated/ +docs/source/compile/generated/ +log +usage_log.txt +test-reports/ +test/*.bak +test/**/*.bak +test/.coverage +test/.hypothesis/ +test/cpp/api/mnist +test/custom_operator/model.pt +test/jit_hooks/*.pt +test/data/legacy_modules.t7 +test/data/*.pt +test/forward_backward_compatibility/nightly_schemas.txt +dropout_model.pt +test/generated_type_hints_smoketest.py +test/htmlcov +test/cpp_extensions/install/ +third_party/build/ +tools/coverage_plugins_package/pip-wheel-metadata/ +tools/shared/_utils_internal.py +tools/fast_nvcc/wrap_nvcc.sh +tools/fast_nvcc/wrap_nvcc.bat +tools/fast_nvcc/tmp/ +torch.egg-info/ +torch/_C/__init__.pyi +torch/_C/_nn.pyi +torch/_C/_VariableFunctions.pyi 
+torch/_VF.pyi +torch/return_types.pyi +torch/nn/functional.pyi +torch/utils/data/datapipes/datapipe.pyi +torch/csrc/autograd/generated/* +torch/csrc/lazy/generated/*.[!m]* +torch_compile_debug/ +# Listed manually because some files in this directory are not generated +torch/testing/_internal/generated/annotated_fn_args.py +torch/testing/_internal/data/*.pt +torch/csrc/api/include/torch/version.h +torch/csrc/cudnn/cuDNN.cpp +torch/csrc/generated +torch/csrc/generic/TensorMethods.cpp +torch/csrc/jit/generated/* +torch/csrc/jit/fuser/config.h +torch/csrc/nn/THCUNN.cpp +torch/csrc/nn/THCUNN.cwrap +torch/bin/ +torch/cmake/ +torch/lib/*.a* +torch/lib/*.dll* +torch/lib/*.exe* +torch/lib/*.dylib* +torch/lib/*.h +torch/lib/*.lib +torch/lib/*.pdb +torch/lib/*.so* +torch/lib/protobuf*.pc +torch/lib/build +torch/lib/caffe2/ +torch/lib/cmake +torch/lib/include +torch/lib/pkgconfig +torch/lib/protoc +torch/lib/protobuf/ +torch/lib/tmp_install +torch/lib/torch_shm_manager +torch/lib/site-packages/ +torch/lib/python* +torch/lib64 +torch/include/ +torch/share/ +torch/test/ +torch/utils/benchmark/utils/valgrind_wrapper/callgrind.h +torch/utils/benchmark/utils/valgrind_wrapper/valgrind.h +torch/version.py +minifier_launcher.py +# Root level file used in CI to specify certain env configs. 
+# E.g., see .circleci/config.yaml +env +.circleci/scripts/COMMIT_MSG +scripts/release_notes/*.json +sccache-stats*.json + +# These files get copied over on invoking setup.py +torchgen/packaged/* +!torchgen/packaged/README.md + +# IPython notebook checkpoints +.ipynb_checkpoints + +# Editor temporaries +*.swa +*.swb +*.swc +*.swd +*.swe +*.swf +*.swg +*.swh +*.swi +*.swj +*.swk +*.swl +*.swm +*.swn +*.swo +*.swp +*~ +.~lock.* + +# macOS dir files .DS_Store -__pycache__/ -*-checkpoint.ipynb -.venv -*.egg* -build/* -_C.* -outputs/* -checkpoints/*.pt + +# Ninja files +.ninja_deps +.ninja_log +compile_commands.json +*.egg-info/ +docs/source/scripts/activation_images/ +docs/source/scripts/quantization_backend_configs/ + +## General + +# Compiled Object files +*.slo +*.lo +*.o +*.cuo +*.obj + +# Compiled Dynamic libraries +*.so +*.dylib +*.dll + +# Compiled Static libraries +*.lai +*.la +*.a +*.lib + +# Compiled protocol buffers +*.pb.h +*.pb.cc +*_pb2.py + +# Compiled python +*.pyc +*.pyd + +# Compiled MATLAB +*.mex* + +# IPython notebook checkpoints +.ipynb_checkpoints + +# Editor temporaries +*.swn +*.swo +*.swp +*~ + +# NFS handle files +**/.nfs* + +# Sublime Text settings +*.sublime-workspace +*.sublime-project + +# Eclipse Project settings +*.*project +.settings + +# QtCreator files +*.user + +# PyCharm files +.idea + +# GDB history +.gdb_history + +## Caffe2 + +# build, distribute, and bins (+ python proto bindings) +build/ +# Allow tools/build/ for build support. 
+!tools/build/ +build_host_protoc +build_android +build_ios +.build_debug/* +.build_release/* +.build_profile/* +distribute/* +*.testbin +*.bin +cmake_build +.cmake_build +gen +.setuptools-cmake-build +.pytest_cache +aten/build/* + +# Bram +plsdontbreak + +# Generated documentation +docs/_site +docs/gathered +_site +doxygen +docs/dev + +# LevelDB files +*.sst +*.ldb +LOCK +CURRENT +MANIFEST-* + +# generated version file +caffe2/version.py + +# setup.py intermediates +.eggs +caffe2.egg-info +MANIFEST + +# Atom/Watchman required file +.watchmanconfig + +# Files generated by CLion +cmake-build-debug + +# BEGIN NOT-CLEAN-FILES (setup.py handles this marker. Do not change.) +# +# Below files are not deleted by "setup.py clean". + +# Downloaded bazel +tools/bazel + +# Visual Studio Code files +.vs +/.vscode/* +!/.vscode/extensions.json +!/.vscode/settings_recommended.json + +# YouCompleteMe config file +.ycm_extra_conf.py + +# Files generated when a patch is rejected +*.orig +*.rej + +# Files generated by ctags +CTAGS +GTAGS +GRTAGS +GSYMS +GPATH +tags +TAGS + + +# ccls file +.ccls-cache/ + +# clang tooling storage location +.clang-format-bin +.clang-tidy-bin +.lintbin + +# clangd background index +.clangd/ +.cache/ + +# bazel symlinks +bazel-* + +# xla repo +xla/ + +# direnv, posh-direnv +.env +.envrc +.psenvrc + +# generated shellcheck directories +.shellcheck_generated*/ + +# zip archives +*.zip + +# core dump files +**/core.[1-9]* + +# Generated if you use the pre-commit script for clang-tidy +pr.diff + +# coverage files +*/**/.coverage.* + +# buck generated files +.buckd/ +.lsp-buck-out/ +.lsp.buckd/ +buck-out/ + +# Downloaded libraries +third_party/ruy/ +third_party/glog/ + +# Virtualenv +venv/ + +# Log files +*.log +sweep/ diff --git a/amg_example/README.md b/amg_example/README.md new file mode 100644 index 00000000..776c8539 --- /dev/null +++ b/amg_example/README.md @@ -0,0 +1,5 @@ +To run this example you need to download the vit_h checkpoint and put it into a 
"""Run the SAM2 automatic mask generator on ``dog.jpg``.

The script:
  1. builds the SAM2 model and a ``SAM2AutomaticMaskGenerator``,
  2. generates masks for the image and compares them with the stored
     reference ``dog_mask_fast.pt``,
  3. saves a visualisation to ``dog_mask_fast.png``,
  4. times repeated ``generate`` calls with CUDA events,
  5. exports a profiler trace and prints peak CUDA memory usage.

All paths are relative to the current working directory.
"""
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
import torch.utils.benchmark as benchmark

from sam2.build_sam import build_sam2
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator

IMAGE_PATH = 'dog.jpg'
REFERENCE_MASK_PATH = "dog_mask_fast.pt"
TRACE_PATH = "amg_example_trace.json.gz"
NUM_WARMUP_RUNS = 3
NUM_BENCHMARK_RUNS = 10


def profiler_runner(path, fn, *args, **kwargs):
    """Call ``fn(*args, **kwargs)`` under the torch profiler.

    The CPU and CUDA activity recorded during the call is exported as a
    Chrome trace to ``path``. Returns whatever ``fn`` returned.
    """
    with torch.profiler.profile(
            activities=[torch.profiler.ProfilerActivity.CPU,
                        torch.profiler.ProfilerActivity.CUDA],
            record_shapes=True) as prof:
        result = fn(*args, **kwargs)
    print(f"Saving trace under {path}")
    prof.export_chrome_trace(path)
    return result


def show_anns(anns):
    """Draw every annotation's mask on the current matplotlib axes.

    Annotations are drawn largest ``'area'`` first, each with a random
    colour at alpha 0.35. Returns the ``'segmentation'`` masks stacked into
    one tensor (in drawing order), or ``None`` when ``anns`` is empty.
    """
    if len(anns) == 0:
        return
    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
    ax = plt.gca()
    ax.set_autoscale_on(False)

    height, width = sorted_anns[0]['segmentation'].shape[:2]
    # RGBA overlay that starts fully transparent.
    img = np.ones((height, width, 4))
    img[:, :, 3] = 0
    ms = []
    for ann in sorted_anns:
        m = ann['segmentation']
        ms.append(torch.as_tensor(m))
        color_mask = np.concatenate([np.random.random(3), [0.35]])
        img[m] = color_mask
    ax.imshow(img)
    return torch.stack(ms)


image = cv2.imread(IMAGE_PATH)
if image is None:
    # cv2.imread signals a missing/unreadable file by returning None; fail
    # here with a clear message instead of inside cvtColor.
    raise FileNotFoundError(f"Could not read image {IMAGE_PATH!r}")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

device = "cuda"

sam2_checkpoint = "checkpoints/sam2_hiera_large.pt"
model_cfg = "sam2_hiera_l.yaml"

sam2 = build_sam2(model_cfg, sam2_checkpoint, device=device, apply_postprocessing=False)
sam2.to(device=device)

mask_generator = SAM2AutomaticMaskGenerator(sam2)

# ### ---
# TODO: Causes a numerical mismatch. CUDA graphs?

# torch.set_float32_matmul_precision('high')
# torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
# mask_generator.predictor.model.image_encoder = torch.compile(
#     mask_generator.predictor.model.image_encoder,
#     mode="max-autotune-no-cudagraphs",
#     fullgraph=True,
#     dynamic=False,
# )
#
# mask_generator.predictor._predict = torch.compile(
#     mask_generator.predictor._predict,
#     mode="max-autotune-no-cudagraphs",
#     fullgraph=True,
#     dynamic=False,
# )

# ### ---


with torch.no_grad():
    # Warm up before checking results and timing.
    for _ in range(NUM_WARMUP_RUNS):
        masks = mask_generator.generate(image)

    # Save an example
    plt.figure(figsize=(image.shape[1]/100., image.shape[0]/100.), dpi=100)
    plt.imshow(image)
    ms = show_anns(masks)
    if ms is None:
        raise RuntimeError("mask generator returned no masks; nothing to compare")
    # NOTE(review): torch.load unpickles the file; only load trusted files.
    ms_ref = torch.load(REFERENCE_MASK_PATH)
    # # TODO: USE mIoU!
    # assert_allclose is deprecated; assert_close with check_dtype=False
    # keeps its dtype-agnostic comparison.
    torch.testing.assert_close(ms, ms_ref, check_dtype=False)
    print("Masks match reference")
    # torch.save(ms, "dog_mask_fast.pt")
    plt.axis('off')
    plt.tight_layout()
    plt.savefig('dog_mask_fast.png', format='png')

    # Benchmark: average milliseconds per generate() call.
    torch.cuda.synchronize()
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    for _ in range(NUM_BENCHMARK_RUNS):
        masks = mask_generator.generate(image)
    end_event.record()
    torch.cuda.synchronize()
    print(start_event.elapsed_time(end_event) / float(NUM_BENCHMARK_RUNS))

    # Save a GPU trace
    profiler_runner(TRACE_PATH, mask_generator.generate, image)

    # Write out memory usage
    max_memory_allocated_bytes = torch.cuda.max_memory_allocated()
    _, total_memory = torch.cuda.mem_get_info()
    max_memory_allocated_percentage = int(100 * (max_memory_allocated_bytes / total_memory))
    max_memory_allocated_mib = max_memory_allocated_bytes >> 20
    print(f"memory(MiB): {max_memory_allocated_mib} memory(%): {max_memory_allocated_percentage}")
neck: + _target_: sam2.modeling.backbones.image_encoder.FpnNeck + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 256 + normalize: true + scale: null + temperature: 10000 + d_model: 256 + backbone_channel_list: [1152, 576, 288, 144] + fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features + fpn_interp_model: nearest + + memory_attention: + _target_: sam2.modeling.memory_attention.MemoryAttention + d_model: 256 + pos_enc_at_input: true + layer: + _target_: sam2.modeling.memory_attention.MemoryAttentionLayer + activation: relu + dim_feedforward: 2048 + dropout: 0.1 + pos_enc_at_attn: false + self_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + d_model: 256 + pos_enc_at_cross_attn_keys: true + pos_enc_at_cross_attn_queries: false + cross_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + rope_k_repeat: True + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + kv_in_dim: 64 + num_layers: 4 + + memory_encoder: + _target_: sam2.modeling.memory_encoder.MemoryEncoder + out_dim: 64 + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 64 + normalize: true + scale: null + temperature: 10000 + mask_downsampler: + _target_: sam2.modeling.memory_encoder.MaskDownSampler + kernel_size: 3 + stride: 2 + padding: 1 + fuser: + _target_: sam2.modeling.memory_encoder.Fuser + layer: + _target_: sam2.modeling.memory_encoder.CXBlock + dim: 256 + kernel_size: 7 + padding: 3 + layer_scale_init_value: 1e-6 + use_dwconv: True # depth-wise convs + num_layers: 2 + + num_maskmem: 7 + image_size: 1024 + # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask + sigmoid_scale_for_mem_enc: 20.0 + 
sigmoid_bias_for_mem_enc: -10.0 + use_mask_input_as_output_without_sam: true + # Memory + directly_add_no_mem_embed: true + # use high-resolution feature map in the SAM mask decoder + use_high_res_features_in_sam: true + # output 3 masks on the first click on initial conditioning frames + multimask_output_in_sam: true + # SAM heads + iou_prediction_use_sigmoid: True + # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder: true + add_tpos_enc_to_obj_ptrs: false + only_obj_ptrs_in_the_past_for_eval: true + # object occlusion prediction + pred_obj_scores: true + pred_obj_scores_mlp: true + fixed_no_obj_ptr: true + # multimask tracking settings + multimask_output_for_tracking: true + use_multimask_token_for_obj_ptr: true + multimask_min_pt_num: 0 + multimask_max_pt_num: 1 + use_mlp_for_obj_ptr_proj: true + # Compilation flag + compile_image_encoder: False diff --git a/sam2/automatic_mask_generator.py b/sam2/automatic_mask_generator.py index 065e469e..3cd044d3 100644 --- a/sam2/automatic_mask_generator.py +++ b/sam2/automatic_mask_generator.py @@ -12,6 +12,7 @@ from torchvision.ops.boxes import batched_nms, box_area # type: ignore from sam2.modeling.sam2_base import SAM2Base +from sam2.map_tensor import to_map_tensor from sam2.sam2_image_predictor import SAM2ImagePredictor from sam2.utils.amg import ( area_from_rle, @@ -24,6 +25,7 @@ generate_crop_boxes, is_box_near_crop_edge, mask_to_rle_pytorch, + mask_to_rle_pytorch_2, MaskData, remove_small_regions, rle_to_mask, @@ -267,29 +269,33 @@ def _process_crop( # Generate masks for this crop in batches data = MaskData() - for (points,) in batch_iterator(self.points_per_batch, points_for_image): - batch_data = self._process_batch( - points, cropped_im_size, crop_box, orig_size, normalize=True - ) - data.cat(batch_data) - del batch_data - self.predictor.reset_predictor() - - # Remove duplicates within this crop. 
- keep_by_nms = batched_nms( - data["boxes"].float(), - data["iou_preds"], - torch.zeros_like(data["boxes"][:, 0]), # categories - iou_threshold=self.box_nms_thresh, + pointss = [points for (points,) in batch_iterator(self.points_per_batch, points_for_image)] + pointss = [torch.as_tensor(points, dtype=torch.float32, device=self.predictor.device) for points in pointss] + pointss = torch.stack(pointss) + batch_datas = self._process_batches( + pointss, cropped_im_size, crop_box, orig_size, normalize=True ) - data.filter(keep_by_nms) + with torch.autograd.profiler.record_function("rest of crop 0"): + for batch_data in batch_datas: + data.cat(batch_data) + del batch_data + self.predictor.reset_predictor() + + # Remove duplicates within this crop. + keep_by_nms = batched_nms( + data["boxes"].float(), + data["iou_preds"], + torch.zeros_like(data["boxes"][:, 0]), # categories + iou_threshold=self.box_nms_thresh, + ) + data.filter(keep_by_nms) - # Return to the original image frame - data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box) - data["points"] = uncrop_points(data["points"], crop_box) - data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))]) + # Return to the original image frame + data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box) + data["points"] = uncrop_points(data["points"], crop_box) + data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))]) - return data + return data def _process_batch( self, @@ -311,12 +317,13 @@ def _process_batch( in_labels = torch.ones( in_points.shape[0], dtype=torch.int, device=in_points.device ) - masks, iou_preds, low_res_masks = self.predictor._predict( - in_points[:, None, :], - in_labels[:, None], - multimask_output=self.multimask_output, - return_logits=True, - ) + with torch.autograd.profiler.record_function("_predict"): + masks, iou_preds, low_res_masks = self.predictor._predict( + in_points[:, None, :], + in_labels[:, None], + multimask_output=self.multimask_output, 
def _process_batches(
    self,
    pointss: torch.Tensor,
    im_size: Tuple[int, ...],
    crop_box: List[int],
    orig_size: Tuple[int, ...],
    normalize: bool = False,
) -> List[MaskData]:
    """Predict and post-process masks for several point batches at once.

    All batches go through a single ``self.predictor._predict`` call
    (inputs wrapped with ``to_map_tensor``); each batch's predictions are
    then filtered and converted to RLE separately.

    Args:
        pointss: Point batches stacked along dim 0; iterating it yields one
            batch of points at a time.
        im_size: Passed to ``transform_coords`` as ``orig_hw``.
        crop_box: Crop box used for edge filtering and for uncropping masks.
        orig_size: ``(height, width)`` of the original image.
        normalize: Forwarded to ``transform_coords``.

    Returns:
        One ``MaskData`` per batch, in the order of ``pointss``. Each has
        its ``"masks"`` entry removed after ``"rles"`` is computed.
    """
    orig_h, orig_w = orig_size

    # Transform every batch of points and build matching all-ones labels.
    in_pointss = []
    in_labelss = []
    for points in pointss:
        in_points = self.predictor._transforms.transform_coords(
            points, normalize=normalize, orig_hw=im_size
        )
        in_labels = torch.ones(
            in_points.shape[0], dtype=torch.int, device=in_points.device
        )
        in_pointss.append(in_points)
        in_labelss.append(in_labels)

    in_pointss = torch.stack(in_pointss)
    in_labelss = torch.stack(in_labelss)

    with torch.autograd.profiler.record_function("_predict"):
        maskss, iou_predss, low_res_maskss = self.predictor._predict(
            to_map_tensor(in_pointss[:, :, None, :]),
            to_map_tensor(in_labelss[:, :, None]),
            multimask_output=self.multimask_output,
            return_logits=True,
        )

    with torch.autograd.profiler.record_function("other 0"):
        datas = []
        # Iterate the points together with the predictions so that each
        # MaskData records the points of ITS OWN batch. Reading ``points``
        # left over from the loop above would attach the last batch's
        # points to every batch.
        for points, masks, iou_preds, low_res_masks in zip(
            pointss, maskss, iou_predss, low_res_maskss
        ):
            # Serialize predictions and store in MaskData
            data = MaskData(
                masks=masks.flatten(0, 1),
                iou_preds=iou_preds.flatten(0, 1),
                points=points.repeat_interleave(masks.shape[1], dim=0),
                low_res_masks=low_res_masks.flatten(0, 1),
            )
            del masks

            if self.use_m2m:
                # One step refinement using previous mask predictions
                in_points = self.predictor._transforms.transform_coords(
                    data["points"], normalize=normalize, orig_hw=im_size
                )
                labels = torch.ones(
                    in_points.shape[0], dtype=torch.int, device=in_points.device
                )
                refined_masks, ious = self.refine_with_m2m(
                    in_points, labels, data["low_res_masks"], self.points_per_batch
                )
                data["masks"] = refined_masks.squeeze(1)
                data["iou_preds"] = ious.squeeze(1)

            # Filter by predicted IoU
            if self.pred_iou_thresh > 0.0:
                keep_mask = data["iou_preds"] > self.pred_iou_thresh
                data.filter(keep_mask)

            # Calculate and filter by stability score
            data["stability_score"] = calculate_stability_score(
                data["masks"], self.mask_threshold, self.stability_score_offset
            )
            if self.stability_score_thresh > 0.0:
                keep_mask = data["stability_score"] >= self.stability_score_thresh
                data.filter(keep_mask)

            # Threshold masks and calculate boxes
            data["masks"] = data["masks"] > self.mask_threshold
            data["boxes"] = batched_mask_to_box(data["masks"])

            # Filter boxes that touch crop boundaries
            keep_mask = ~is_box_near_crop_edge(
                data["boxes"], crop_box, [0, 0, orig_w, orig_h]
            )
            if not torch.all(keep_mask):
                data.filter(keep_mask)

            # Compress to RLE
            data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
            data["rles"] = mask_to_rle_pytorch_2(data["masks"])
            del data["masks"]

            datas.append(data)
    return datas
+ if func == torch.ops.aten.add.Tensor: + assert len(unwrapped_kwargs) == 0 + assert len(unwrapped_args) == 2, f"args: {unwrapped_args}" + # print("unwrapped_args") + # print([type(a) for a in unwrapped_args]) + if not isinstance(args[0], MapTensor) and isinstance(args[1], MapTensor): + if args[0].dim() == (args[1].dim() + 1): + return NotImplemented + # return wrap(func(unwrapped_args[0], unwrapped_args[1].unsqueeze(1))) + # print("args[0].dim(): ", args[0].dim()) + # print("args[1].dim(): ", args[1].dim()) + # print("type(args[0]): ", type(args[0])) + # print("type(args[1]): ", type(args[1])) + # TODO: THIS GETS CALLED??? + return NotImplemented + pass + + if func in [torch.ops.aten.cat.default, torch.ops.aten.stack.default]: + assert len(unwrapped_kwargs) == 0 + assert len(unwrapped_args) == 2, f"args: {unwrapped_args}" + # TODO: Use MapTensor type for filter + # First argument's dim + dim = unwrapped_args[0][0].dim() + size = unwrapped_args[0][0].size() + for a in unwrapped_args[0]: + if a.dim() > dim: + dim = a.dim() + size = a.size() + new_args = [] + for a in unwrapped_args[0]: + if a.dim() == dim: + new_args.append(a) + else: + assert a.dim() + 1 == dim + new_args.append(a.unsqueeze(0).expand((size[0],) + a.size())) + return wrap(func(new_args, wrap_dim(unwrapped_args[1], dim - 1) + 1)) + + if func == torch.ops.aten.select.int: + assert len(unwrapped_kwargs) == 0 + assert len(unwrapped_args) == 3, f"args: {unwrapped_args}" + return wrap(func(unwrapped_args[0], unwrapped_args[1] + 1, unwrapped_args[2])) + + if func == torch.ops.aten.slice.Tensor: + assert len(unwrapped_kwargs) == 0 + assert len(unwrapped_args) == 4, f"args: {unwrapped_args}" + dim = unwrapped_args[0].dim() + return wrap(func(unwrapped_args[0], + wrap_dim(unwrapped_args[1], dim - 1) + 1, + unwrapped_args[2], + unwrapped_args[3])) + + if func == torch.ops.aten.mean.dim: + # TODO: THIS MIGHT BE WRONG + assert len(unwrapped_kwargs) == 0 + assert len(unwrapped_args) == 3, f"args: 
{unwrapped_args}" + assert len(unwrapped_args[1]) == 1 + dim = unwrapped_args[0].dim() + return wrap(func(unwrapped_args[0], + [wrap_dim(unwrapped_args[1][0], dim - 1) + 1], + unwrapped_args[2])) + + view_ops = [torch.ops.aten._unsafe_view.default, + torch.ops.aten.expand.default] + if func in view_ops: + assert len(unwrapped_kwargs) == 0 + assert len(unwrapped_args) == 2, f"args: {unwrapped_args}" + input_size = unwrapped_args[0].size() + bigger_size = list(input_size[:1]) + unwrapped_args[1] + return wrap(func(unwrapped_args[0], bigger_size)) + + if func is torch.ops.aten.view.default: + assert len(unwrapped_kwargs) == 0 + assert len(unwrapped_args) == 2, f"args: {unwrapped_args}" + input_size = unwrapped_args[0].size() + bigger_size = list(input_size[:1]) + unwrapped_args[1] + return wrap(unwrapped_args[0].reshape(bigger_size)) + + if func in [torch.ops.aten.mm.default, torch.ops.aten.bmm.default]: + assert len(unwrapped_kwargs) == 0 + assert len(unwrapped_args) == 2, f"args: {unwrapped_args}" + return wrap(torch.matmul(*unwrapped_args)) + + if func in [torch.ops.aten.unsqueeze.default]: + assert len(unwrapped_kwargs) == 0 + assert len(unwrapped_args) == 2, f"args: {unwrapped_args}" + dim = unwrapped_args[0].dim() + new_i = unwrapped_args[1] + if new_i >= 0: + new_i += 1 + return wrap(func(unwrapped_args[0], new_i)) + + if func == torch.ops.aten.addmm.default: + assert len(unwrapped_kwargs) == 0 + assert len(unwrapped_args) == 3, f"args: {unwrapped_args}" + return wrap(torch.matmul(unwrapped_args[1], unwrapped_args[2]) + unwrapped_args[0]) + + if func == torch.ops.aten.convolution.default: + assert len(unwrapped_kwargs) == 0 + assert len(unwrapped_args) == 9, f"args: {unwrapped_args}" + a = unwrapped_args[0] + # print("0 a.size(): ", a.size()) + a = unwrapped_args[0].flatten(0, 1) + # print("1 a.size(): ", a.size()) + # TODO: It's scary that this .contiguous seems necessary, but I guess we're below composite conv + # which might expected contiguous output + resa 
= func(*((a,) + unwrapped_args[1:])).contiguous() + # print("0 resa.size(): ", resa.size()) + resb = resa.view((unwrapped_args[0].size(0), unwrapped_args[0].size(1)) + resa.size()[1:]) + # print("1 resb.size(): ", resb.size()) + # res_0 = func(*((unwrapped_args[0][0],) + unwrapped_args[1:])) + # if not torch.allclose(resb[0], res_0): + # print("139203") + # import pdb; pdb.set_trace() + # pass + return wrap(resb) + + if func == torch.ops.aten.upsample_bilinear2d.default: + assert len(unwrapped_kwargs) == 0 + assert len(unwrapped_args) == 3, f"args: {unwrapped_args}" + a = unwrapped_args[0] + # print("0 a.size(): ", a.size()) + a = unwrapped_args[0].flatten(0, 1) + # print("1 a.size(): ", a.size()) + # TODO: It's scary that this .contiguous seems necessary, but I guess we're below composite conv + # which might expected contiguous output + resa = func(*((a,) + unwrapped_args[1:])).contiguous() + # print("0 resa.size(): ", resa.size()) + resb = resa.view((unwrapped_args[0].size(0), unwrapped_args[0].size(1)) + resa.size()[1:]) + # print("1 resb.size(): ", resb.size()) + # res_0 = func(*((unwrapped_args[0][0],) + unwrapped_args[1:])) + # if not torch.allclose(resb[0], res_0): + # print("139203") + # import pdb; pdb.set_trace() + # pass + return wrap(resb) + + if func == torch.ops.aten.transpose.int: + assert len(unwrapped_kwargs) == 0 + assert len(unwrapped_args) == 3, f"args: {unwrapped_args}" + dim = unwrapped_args[0].dim() + return wrap(func(unwrapped_args[0], + wrap_dim(unwrapped_args[1], dim - 1) + 1, + wrap_dim(unwrapped_args[2], dim - 1) + 1)) + + if func == torch.ops.aten._scaled_dot_product_efficient_attention.default: + assert len(args) == 5 + if all(isinstance(a, MapTensor) for a in args[:3]): + assert len(unwrapped_kwargs) == 0 + assert len(unwrapped_args) == 5, f"args: {unwrapped_args}" + assert unwrapped_args[0].dim() == 5 + assert unwrapped_args[1].dim() == 5 + assert unwrapped_args[2].dim() == 5 + sdpa_res = wrap(func(unwrapped_args[0].flatten(0, 1), + 
unwrapped_args[1].flatten(0, 1), + unwrapped_args[2].flatten(0, 1), + unwrapped_args[3], + unwrapped_args[4])) + return (wrap(sdpa_res[0].view(unwrapped_args[0].size())),) + sdpa_res[1:] + if isinstance(args[0], MapTensor) and not any(isinstance(a, MapTensor) for a in args[1:]): + assert len(unwrapped_kwargs) == 0 + assert len(unwrapped_args) == 5, f"args: {unwrapped_args}" + assert unwrapped_args[0].dim() == 5 + assert unwrapped_args[1].dim() == 4 + assert unwrapped_args[2].dim() == 4 + a0 = unwrapped_args[0] + a1_size = unwrapped_args[1].size() + a1 = unwrapped_args[1].unsqueeze(0).expand((a0.size(0),) + a1_size) + a2 = unwrapped_args[2].unsqueeze(0).expand((a0.size(0),) + a1_size) + sdpa_res = wrap(func(a0.flatten(0, 1), + a1.flatten(0, 1), + a2.flatten(0, 1), + unwrapped_args[3], + unwrapped_args[4])) + return (wrap(sdpa_res[0].view(unwrapped_args[0].size())),) + sdpa_res[1:] + if ((not isinstance(args[0], MapTensor)) and isinstance(args[1], MapTensor) and (not isinstance(args[2], MapTensor))): + assert len(unwrapped_kwargs) == 0 + assert len(unwrapped_args) == 5, f"args: {unwrapped_args}" + assert unwrapped_args[0].dim() == 4 + assert unwrapped_args[1].dim() == 5 + assert unwrapped_args[2].dim() == 4 + a1_size = unwrapped_args[1].size() + a0 = unwrapped_args[0].unsqueeze(0).expand((a1_size[0],) + unwrapped_args[0].size()[1:]) + a2 = unwrapped_args[2].unsqueeze(0).expand((a1_size[0],) + unwrapped_args[2].size()[1:]) + sdpa_res = wrap(func(a0.flatten(0, 1), + a1.flatten(0, 1), + a2.flatten(0, 1), + unwrapped_args[3], + unwrapped_args[4])) + return (wrap(sdpa_res[0].view(unwrapped_args[0].size())),) + sdpa_res[1:] + if ((not isinstance(args[0], MapTensor)) and isinstance(args[1], MapTensor) and isinstance(args[2], MapTensor)): + assert len(unwrapped_kwargs) == 0 + assert len(unwrapped_args) == 5, f"args: {unwrapped_args}" + assert unwrapped_args[0].dim() == 4 + assert unwrapped_args[1].dim() == 5 + assert unwrapped_args[2].dim() == 5 + a0_size = 
unwrapped_args[0].size() + a1_size = unwrapped_args[1].size() + a0 = unwrapped_args[0].unsqueeze(0).expand((a1_size[0],) + a0_size) + a1 = unwrapped_args[1] + a2 = unwrapped_args[2] + sdpa_res = wrap(func(a0.flatten(0, 1), + a1.flatten(0, 1), + a2.flatten(0, 1), + unwrapped_args[3], + unwrapped_args[4])) + return (wrap(sdpa_res[0].view((a1_size[0],) + a0_size)),) + sdpa_res[1:] + return NotImplemented + + if func == torch.ops.aten._scaled_dot_product_flash_attention.default: + assert len(args) == 3 + assert len(unwrapped_kwargs) == 1 + assert len(unwrapped_args) == 3, f"args: {unwrapped_args}" + if all(isinstance(a, MapTensor) for a in args[:3]): + assert unwrapped_args[0].dim() == 5 + assert unwrapped_args[1].dim() == 5 + assert unwrapped_args[2].dim() == 5 + sdpa_res = wrap(func(unwrapped_args[0].flatten(0, 1), + unwrapped_args[1].flatten(0, 1), + unwrapped_args[2].flatten(0, 1), + **unwrapped_kwargs)) + return (wrap(sdpa_res[0].view(unwrapped_args[0].size())),) + sdpa_res[1:] + if isinstance(args[0], MapTensor) and not any(isinstance(a, MapTensor) for a in args[1:]): + assert unwrapped_args[0].dim() == 5 + assert unwrapped_args[1].dim() == 4 + assert unwrapped_args[2].dim() == 4 + a0 = unwrapped_args[0] + a1_size = unwrapped_args[1].size() + a1 = unwrapped_args[1].unsqueeze(0).expand((a0.size(0),) + a1_size) + a2 = unwrapped_args[2].unsqueeze(0).expand((a0.size(0),) + a1_size) + sdpa_res = wrap(func(a0.flatten(0, 1), + a1.flatten(0, 1), + a2.flatten(0, 1), + **unwrapped_kwargs)) + return (wrap(sdpa_res[0].view(unwrapped_args[0].size())),) + sdpa_res[1:] + if ((not isinstance(args[0], MapTensor)) and isinstance(args[1], MapTensor) and (not isinstance(args[2], MapTensor))): + assert unwrapped_args[0].dim() == 4 + assert unwrapped_args[1].dim() == 5 + assert unwrapped_args[2].dim() == 4 + a1_size = unwrapped_args[1].size() + a0 = unwrapped_args[0].unsqueeze(0).expand((a1_size[0],) + unwrapped_args[0].size()[1:]) + a2 = 
unwrapped_args[2].unsqueeze(0).expand((a1_size[0],) + unwrapped_args[2].size()[1:]) + sdpa_res = wrap(func(a0.flatten(0, 1), + a1.flatten(0, 1), + a2.flatten(0, 1), + **unwrapped_kwargs)) + return (wrap(sdpa_res[0].view(unwrapped_args[0].size())),) + sdpa_res[1:] + if ((not isinstance(args[0], MapTensor)) and isinstance(args[1], MapTensor) and isinstance(args[2], MapTensor)): + assert unwrapped_args[0].dim() == 4 + assert unwrapped_args[1].dim() == 5 + assert unwrapped_args[2].dim() == 5 + a0_size = unwrapped_args[0].size() + a1_size = unwrapped_args[1].size() + a0 = unwrapped_args[0].unsqueeze(0).expand((a1_size[0],) + a0_size) + a1 = unwrapped_args[1] + a2 = unwrapped_args[2] + sdpa_res = wrap(func(a0.flatten(0, 1), + a1.flatten(0, 1), + a2.flatten(0, 1), + **unwrapped_kwargs)) + return (wrap(sdpa_res[0].view((a1_size[0],) + a0_size)),) + sdpa_res[1:] + return NotImplemented + + # Only needed by inductor for compile + if func == torch.ops.aten._unsafe_index.Tensor: + assert len(unwrapped_kwargs) == 0 + assert len(unwrapped_args) == 2, f"args: {unwrapped_args}" + a = unwrapped_args[0] + a = unwrapped_args[0].flatten(0, 1) + resa = func(*((a,) + unwrapped_args[1:])) + resb = resa.view((unwrapped_args[0].size(0), unwrapped_args[0].size(1)) + resa.size()[1:]) + return wrap(resb) + + forwardables = [ + torch.ops.aten.add.Tensor, + torch.ops.aten.clamp.default, + torch.ops.aten.clone.default, + torch.ops.aten.copy_.default, + torch.ops.aten.cos.default, + torch.ops.aten.div.Tensor, + torch.ops.aten.eq.Scalar, + torch.ops.aten.gelu.default, + torch.ops.aten.mul.Tensor, + torch.ops.aten.pow.Tensor_Scalar, + torch.ops.aten.relu.default, + torch.ops.aten.sigmoid.default, + torch.ops.aten.sin.default, + torch.ops.aten.sqrt.default, + torch.ops.aten.sub.Tensor, + torch.ops.aten.unbind.int, + torch.ops.aten.where.self, + torch.ops.aten.zeros_like.default, + torch.ops.aten._to_copy.default, + ] + if func in forwardables: + return wrap(func(*unwrapped_args, **unwrapped_kwargs)) 
+ print("WARNING! Not officially marked as forwardable: torch.ops.", func) + return wrap(func(*unwrapped_args, **unwrapped_kwargs)) + +class MapTensor(torch.Tensor): + @staticmethod + def __new__(cls, elems): + elem = elems[0] + return torch.Tensor._make_wrapper_subclass(cls, + elem.shape, + dtype=elem.dtype, + device=elem.device) + + def __init__(self, elems): + self.elems = elems + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs=None): + # print("func: ", func) + res = ops_impl(cls, func, types, args, kwargs) + # if isinstance(res, torch.Tensor): + # unwrapped_args_0 = tree_map(lambda x: unwrap_i(x, 0), args) + # unwrapped_kwargs_0 = tree_map(lambda x: unwrap_i(x, 0), kwargs) + # if func == torch.ops.aten.view.default: + # res_0 = torch.ops.aten.reshape.default(*unwrapped_args_0, **unwrapped_kwargs_0) + # else: + # res_0 = func(*unwrapped_args_0, **unwrapped_kwargs_0) + # if res.elems[0].size() != res_0.size(): + # import pdb; pdb.set_trace() + # print("02390") + # if not torch.allclose(res.elems[0], res_0, atol=1e-3, rtol=1e-3): + # import pdb; pdb.set_trace() + # print("SDJFKL") + # else: + # pass + # # print("res got type: ", type(res)) + return res + + __torch_function__ = torch._C._disabled_torch_function_impl + + # flatten/unflatten is needed for compile + def __tensor_flatten__(self): + ctx = {} + inner_tensors = ["elems"] + return inner_tensors, ctx + + @staticmethod + def __tensor_unflatten__(inner_tensors: Dict, meta, outer_size, outer_stride): + from torch._subclasses.fake_tensor import FakeTensor + + # inner tensors: _values, _offsets, [_lengths], [_min_seqlen], [_max_seqlen] + assert len(inner_tensors) == 1, f"{inner_tensors}" + elems = inner_tensors["elems"] + + return MapTensor(elems) + + def __repr__(self): + return f"MapTensor({self.elems.size()})" + +# ts is a higher dim Tensor +def to_map_tensor(ts: torch.Tensor): + return MapTensor(ts) diff --git a/sam2/modeling/position_encoding.py b/sam2/modeling/position_encoding.py 
index 52ac2267..c7b2e700 100644 --- a/sam2/modeling/position_encoding.py +++ b/sam2/modeling/position_encoding.py @@ -77,9 +77,9 @@ def encode_points(self, x, y, labels): @torch.no_grad() def forward(self, x: torch.Tensor): - cache_key = (x.shape[-2], x.shape[-1]) - if cache_key in self.cache: - return self.cache[cache_key][None].repeat(x.shape[0], 1, 1, 1) + # cache_key = (x.shape[-2], x.shape[-1]) + # if cache_key in self.cache: + # return self.cache[cache_key][None].repeat(x.shape[0], 1, 1, 1) y_embed = ( torch.arange(1, x.shape[-2] + 1, dtype=torch.float32, device=x.device) .view(1, -1, 1) @@ -108,7 +108,7 @@ def forward(self, x: torch.Tensor): (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 ).flatten(3) pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) - self.cache[cache_key] = pos[0] + # self.cache[cache_key] = pos[0] return pos @@ -130,7 +130,8 @@ def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: """Positionally encode points that are normalized to [0,1].""" # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape coords = 2 * coords - 1 - coords = coords @ self.positional_encoding_gaussian_matrix + # coords = coords @ self.positional_encoding_gaussian_matrix + coords = torch.matmul(coords, self.positional_encoding_gaussian_matrix) coords = 2 * np.pi * coords # outputs d_1 x ... 
x d_n x C shape return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) diff --git a/sam2/modeling/sam/mask_decoder.py b/sam2/modeling/sam/mask_decoder.py index b7c7dfdb..6bc65492 100644 --- a/sam2/modeling/sam/mask_decoder.py +++ b/sam2/modeling/sam/mask_decoder.py @@ -210,7 +210,8 @@ def predict_masks( b, c, h, w = src.shape # Run the transformer - hs, src = self.transformer(src, pos_src, tokens) + with torch.autograd.profiler.record_function("self.transformer"): + hs, src = self.transformer(src, pos_src, tokens) iou_token_out = hs[:, s, :] mask_tokens_out = hs[:, s + 1 : (s + 1 + self.num_mask_tokens), :] diff --git a/sam2/modeling/sam/prompt_encoder.py b/sam2/modeling/sam/prompt_encoder.py index 6b3bbb95..e9450449 100644 --- a/sam2/modeling/sam/prompt_encoder.py +++ b/sam2/modeling/sam/prompt_encoder.py @@ -92,12 +92,27 @@ def _embed_points( point_embedding = self.pe_layer.forward_with_coords( points, self.input_image_size ) - point_embedding[labels == -1] = 0.0 - point_embedding[labels == -1] += self.not_a_point_embed.weight - point_embedding[labels == 0] += self.point_embeddings[0].weight - point_embedding[labels == 1] += self.point_embeddings[1].weight - point_embedding[labels == 2] += self.point_embeddings[2].weight - point_embedding[labels == 3] += self.point_embeddings[3].weight + # point_embedding[labels == -1] = 0.0 + # point_embedding[labels == -1] += self.not_a_point_embed.weight + # point_embedding[labels == 0] += self.point_embeddings[0].weight + # point_embedding[labels == 1] += self.point_embeddings[1].weight + # point_embedding[labels == 2] += self.point_embeddings[2].weight + # point_embedding[labels == 3] += self.point_embeddings[3].weight + point_embedding = torch.where((labels == -1).unsqueeze(-1).expand_as(point_embedding), + torch.zeros_like(point_embedding) + self.not_a_point_embed.weight, + point_embedding) + point_embedding = torch.where((labels == 0).unsqueeze(-1).expand_as(point_embedding), + point_embedding + 
self.point_embeddings[0].weight, + point_embedding) + point_embedding = torch.where((labels == 1).unsqueeze(-1).expand_as(point_embedding), + point_embedding + self.point_embeddings[1].weight, + point_embedding) + point_embedding = torch.where((labels == 2).unsqueeze(-1).expand_as(point_embedding), + point_embedding + self.point_embeddings[2].weight, + point_embedding) + point_embedding = torch.where((labels == 3).unsqueeze(-1).expand_as(point_embedding), + point_embedding + self.point_embeddings[3].weight, + point_embedding) return point_embedding def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: diff --git a/sam2/modeling/sam/transformer.py b/sam2/modeling/sam/transformer.py index b5b6fa2f..03d8f613 100644 --- a/sam2/modeling/sam/transformer.py +++ b/sam2/modeling/sam/transformer.py @@ -265,20 +265,21 @@ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: dropout_p = self.dropout_p if self.training else 0.0 # Attention - try: - with sdp_kernel_context(dropout_p): - out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) - except Exception as e: - # Fall back to all kernels if the Flash attention kernel fails - warnings.warn( - f"Flash Attention kernel failed due to: {e}\nFalling back to all available " - f"kernels for scaled_dot_product_attention (which may have a slower speed).", - category=UserWarning, - stacklevel=2, - ) - global ALLOW_ALL_KERNELS - ALLOW_ALL_KERNELS = True - out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) + out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) + # try: + # with sdp_kernel_context(dropout_p): + # out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) + # except Exception as e: + # # Fall back to all kernels if the Flash attention kernel fails + # warnings.warn( + # f"Flash Attention kernel failed due to: {e}\nFalling back to all available " + # f"kernels for scaled_dot_product_attention (which may have a slower speed).", + # category=UserWarning, + # 
stacklevel=2, + # ) + # global ALLOW_ALL_KERNELS + # ALLOW_ALL_KERNELS = True + # out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) out = self._recombine_heads(out) out = self.out_proj(out) diff --git a/sam2/utils/amg.py b/sam2/utils/amg.py index 98684296..20a1b8e2 100644 --- a/sam2/utils/amg.py +++ b/sam2/utils/amg.py @@ -105,6 +105,42 @@ def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]: for b in range(n_batches): yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args] +def mask_to_rle_pytorch_2(tensor: torch.Tensor) -> List[Dict[str, Any]]: + """ + Encodes masks to an uncompressed RLE, in the format expected by + pycoco tools. + """ + # Put in fortran order and flatten h,w + b, h, w = tensor.shape + tensor = tensor.permute(0, 2, 1).flatten(1) + + # Compute change indices + diff = tensor[:, 1:] ^ tensor[:, :-1] + a = torch.tensor([[True]]) + if diff.is_cuda: + a = a.pin_memory().cuda() + a = a.expand_as(diff.narrow(1, 0, 1)) + diff = torch.cat([a, diff, a], dim=1) + change_indices = diff.nonzero() + + alt_lens = diff.sum(dim=1).tolist() + + all_cur_idx = change_indices[:, 1] + all_btw_idx = torch.cat([all_cur_idx[1:], all_cur_idx[:1]]) - all_cur_idx + all_btw_idx = all_btw_idx.detach().cpu().tolist() + + # Encode run length + out = [] + counts_init = (tensor[:, 0] == 0).tolist() + offset = 0 + for i, ci in zip(range(b), counts_init): + btw_idxs = all_btw_idx[offset:offset + alt_lens[i]][:-1] + offset += alt_lens[i] + counts = [] if ci else [0] + counts.extend(btw_idxs) + out.append({"size": [h, w], "counts": counts}) + + return out def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]: """ diff --git a/sav_sha256sum.chk b/sav_sha256sum.chk new file mode 100644 index 00000000..15ae1e44 --- /dev/null +++ b/sav_sha256sum.chk @@ -0,0 +1,58 @@ +3187c495f5d44fdb4daba9fc617e58fbf24ca95de99085245d0175ae8eace1e5 sav_000.tar +09ffae44c756ac3a309a2d0d9169cb53e4e42b17e032fb271b75012709917c56 sav_001.tar 
+534c32fa2fac6db60f90604f85dd3c65407eba4064eb9d2b32f0d4a0c6c34d3e sav_002.tar +76b0e245a1acd35a0c5004f03b221657a4a78ffdfa2eff0b17838333f4904795 sav_003.tar +b167c45844e8f125d36f80296fe828609206d9160af8940e1810635bdfab118e sav_004.tar +b771390a07a518a3162ee1b493eb0654f16a27914273c95206c6e62f84c4d9bb sav_005.tar +58463b5a557f57f373a4745bea6eaa51d03fdfccfa728cd12a96abc7e7cb45cb sav_006.tar +7db46b3399189bb1c1728d96662fc1e479760f894612606d2753355e29fcaa6f sav_007.tar +191442041b350a6481d9fb67d37e776ec49cfcf8e49075e9071538a9359c03ed sav_008.tar +0ac8d12b7915044391919f98c023144b172adc92a74521fc5ac838e1edae6ffe sav_009.tar +6896b35164e00b2da926fb1c6a3f508da1c4497a06c7c2548ae6bb80a1132525 sav_010.tar +7f00d770080e866bdc6840f53b6860c98c0b39e791405cb4fc65942076f129a1 sav_011.tar +68c4e84380cc86a7984d864aee696105f77a8845525056c89b6bd26adbd119fb sav_012.tar +573edfc3325699e5f72e157db3e034964c3be150383f6d8b50f0db0b40f76aa6 sav_013.tar +f37c56bcd68b8dfceac79a3a6fa7703a48664ed0cea86b26f8183c67f624df3e sav_014.tar +010ed5c61e45f9dd99db1e35dd8c5e9b5066a2f1cfff91683c3241f0f23582d9 sav_015.tar +016d37826bf4cfdff4ca0cc36248cf86a5a479995047998b9d96afb20c26e703 sav_016.tar +b2c44dbe97e3bb5cd6e4191e5f7ce3cb1a813524f9a7dd6c41b4951c3412030b sav_017.tar +ab827c612d2135c7b6119847d5173737fa7ca5ee9757acdb628e2880c10b8aec sav_018.tar +5a72f9f393e9da0986aaf4d0a02cda4460511e5b39990d495c61953bc7ab6694 sav_019.tar +0ff3b5130bc8043e1b256904c5d64ed5413563542fe5924f0547b989f8dc502a sav_020.tar +15517ab5fa53dfc424b629e3b3be5707554c623e8e340279c76dc0c77a07d699 sav_021.tar +06d6e77a9389d9779e85fd591157f511a91cfced0e089d1810ff8dd944b882ab sav_022.tar +21d8fb80a7a1a17027e4013245fcbc789982d93bab91a40c6c3cb842c4f042db sav_023.tar +7c26912b528edcda840fa39c5870597d1d2353578c36d9390010d99ff14b1c13 sav_024.tar +1dc7269be317a7785be48f59570994b1bbb5d289a66b4507af9d5c534efa2d8d sav_025.tar +53b416935d62030e3df4b6b7d622a0a1c308ad426f77401c4ad9ea7991a720fa sav_026.tar 
+e42332035dc0d937c92a2de841e670f9c5c5f7d420f65ddc57b2377ab89c3c27 sav_027.tar +4938595bef62073f3d6588f97d4b2dc38062ac738098d0b49c79515de51fbd14 sav_028.tar +1012cbe35cd3c58152b349c31bd8dcb888ce5359ff982558891ae4d988f63f3f sav_029.tar +78a953775d150f2bec6d4e8fe1fbeeb3565c383ee83e25b9f22805af00696787 sav_030.tar +b05c9138bca14a2a43a47dce25860328c300843ded5c84045a079bae29dd8621 sav_031.tar +ed7ad39b0d364755b677c92dc7f07fa6a739c01707f0e72d376b3d8278ccd7df sav_032.tar +a3c0dc2251f927383a55d7876638c9e0162d07edd8dc6556dbd8d8a5972ae406 sav_033.tar +b2e7c3d83a9d6750b82b4f62c6c82ddcba291cfa56b01ed35d62282a4538dcec sav_034.tar +cb15403dc21e408e84a3a69ab1e246ee0d2feee9b3c24ea92fb4b7d7859e2fdf sav_035.tar +d94463a83e605720e6a9f3447beb9484183ca6a5cb8acdc7091b5f0ccd873cd4 sav_036.tar +7fb67971dc8cbff79bb70581ca02fc32bd2ba500b03bd2c217764cc415584d76 sav_037.tar +1e08962f47e37eb61c23a7cafbbd6115ed64e098174d21ee591e09bc65915829 sav_038.tar +1e914968250251dd24b94ac5520789c11acaf3873c313e18a2285ebb9361ca75 sav_039.tar +c44d53ea991014b31c4710f3796d91c7ee7d9561c6bfa5278b6b7301f587f7b5 sav_040.tar +35ed0392f9520939a5585a628584e7d7ed03d11ec8ffd70de146752bf16f7e83 sav_041.tar +7da15eb094ae66b45f508a20e98dc3d7807382518d83b4fcfe7779ef01bacfa0 sav_042.tar +d1c00bb0b9606dab519919a93cfade99113e0b2bdb9130ee1ab72deda40b5d03 sav_043.tar +cc6458c65cfcaa643a4c2108b21c4652d99c29e43572660be435c8aec2e9f756 sav_044.tar +e764cf49536d9a979f1ba746b4fbc695638e188de369fc5c9a1f65d03c57268e sav_045.tar +a4225cebc6b7027820dd4b9de9ad9acdf76db41aa8f472560fae705c7e4afa46 sav_046.tar +ea5c26893a191105ce424972694929502d323403f2317e716b7f38dca6e79276 sav_047.tar +661280ba415b48ca57fdca50077261e65ff8e132778011feee30f1ff06c877b0 sav_048.tar +cee6855a8b837587ee095202121fd53d6a5dac63f63a0eb151ff46b7ca99c9af sav_049.tar +2b742314320748d2974002e9851ecd67849e0e79bc513be1dc7cb9f7ee4dd73a sav_050.tar +d8c9f8c54f383e314dede43297906dd0f1b5b25605284e42d1cd293135a64fc6 sav_051.tar 
+96897210188fbb7ed68826e46616956deb932e3c7feaca6c9226945f38d3a338 sav_052.tar +f674d34b87932ca1006fef93b6fe8e54630e7391225c46c57c10a74e46b1c429 sav_053.tar +435fe135c350332cad7d035a5ea7af3f92c50426d96e4c505c837166ac0e9b7b sav_054.tar +24c8d79f6b2a3ef6a45c68be817ea1eed9b9b08dba57d122e41d3cd83fcc4358 sav_055.tar +81b83d9a24aec921e39e532e933a86496d12c001ebb733591c9f29a3f205a83b sav_val.tar +74a48ef0448f6cf52f17ea8fded4930059d9a2137b7584bb0cff1a948fef8cd2 sav_test.tar diff --git a/setup.py b/setup.py index ebef97cd..c4dab1d7 100644 --- a/setup.py +++ b/setup.py @@ -22,13 +22,13 @@ # Required dependencies REQUIRED_PACKAGES = [ - "torch>=2.3.1", - "torchvision>=0.18.1", - "numpy>=1.24.4", - "tqdm>=4.66.1", - "hydra-core>=1.3.2", - "iopath>=0.1.10", - "pillow>=9.4.0", + # "torch>=2.3.1", + # "torchvision>=0.18.1", + # "numpy>=1.24.4", + # "tqdm>=4.66.1", + # "hydra-core>=1.3.2", + # "iopath>=0.1.10", + # "pillow>=9.4.0", ] EXTRA_PACKAGES = { @@ -139,8 +139,8 @@ def get_ext_filename(self, ext_name): packages=find_packages(exclude="notebooks"), package_data={"": ["*.yaml"]}, # SAM 2 configuration files include_package_data=True, - install_requires=REQUIRED_PACKAGES, - extras_require=EXTRA_PACKAGES, + # install_requires=REQUIRED_PACKAGES, + # extras_require=EXTRA_PACKAGES, python_requires=">=3.10.0", ext_modules=get_extensions(), cmdclass=cmdclass,