
[NeurIPS'24] This repository is the implementation of "SpatialRGPT: Grounded Spatial Reasoning in Vision Language Models"



AnjieCheng/SpatialRGPT


SpatialRGPT: Grounded Spatial Reasoning in Vision Language Models (NeurIPS'24)

Code License Model License Python 3.10+

arxiv / Huggingface


💡 Introduction

SpatialRGPT: Grounded Spatial Reasoning in Vision-Language Models
An-Chieh Cheng, Hongxu (Danny) Yin, Yang Fu, Qiushan Guo, Ruihan Yang, Jan Kautz, Xiaolong Wang, Sifei Liu

SpatialRGPT is a vision-language model designed to understand both 2D and 3D spatial arrangements. It accepts arbitrary region proposals, such as boxes or masks, and answers complex spatial reasoning questions about them.


📢 News

  • Oct-07-24- SpatialRGPT code/dataset/benchmark released! 🔥
  • Sep-25-24- We're thrilled to share that SpatialRGPT has been accepted to NeurIPS 2024! 🎊

Installation

To build the environment for training SpatialRGPT, run the following:

./environment_setup.sh srgpt
conda activate srgpt

Gradio Demo

To run the Gradio demo for SpatialRGPT, follow these steps. Because of pydantic version conflicts, the demo environment is not compatible with the training environment, so the demo needs its own separate environment.

  1. Build the environment.

    ./environment_setup.sh srgpt-demo
    conda activate srgpt-demo
    pip install gradio==4.27 deepspeed==0.13.0 gradio_box_promptable_image segment_anything_hq
    pip install -U 'git+https://github.com/facebookresearch/detectron2.git@ff53992b1985b63bd3262b5a36167098e3dada02'

    If you run into an error with the detectron2 installation, it could be because CUDA_HOME is not set. To fix this, export CUDA_HOME to your local CUDA path. See details in this issue.

  2. Clone the Depth-Anything repository and download the necessary checkpoint:

    git clone https://github.com/LiheYoung/Depth-Anything.git
    wget https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitl14.pth

    Place depth_anything_vitl14.pth under Depth-Anything/checkpoints, and set the DEPTH_ANYTHING_PATH environment variable to the Depth-Anything repository path. For example:

    export DEPTH_ANYTHING_PATH=/YOUR_OWN_PATH/Depth-Anything
  3. Download the SAM-HQ checkpoint from here and set the SAM_CKPT_PATH environment variable to its path. For example:

    export SAM_CKPT_PATH=/YOUR_OWN_PATH/sam_hq_vit_h.pth
  4. Launch the Gradio server. You can use your own checkpoint or a8cheng/SpatialRGPT-VILA1.5-8B:

    cd demo
    python gradio_web_server_multi.py --model-path PATH_TO_CHECKPOINT
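Because the demo depends on its own conda environment and two environment variables, a quick pre-flight check can catch a misconfiguration before the server starts. The following is a minimal sketch, not part of the repository: the helper name `check_demo_env` is hypothetical, and it assumes the checkpoint layout from step 2 (`$DEPTH_ANYTHING_PATH/checkpoints/depth_anything_vitl14.pth`) and that conda exposes the active environment name through `CONDA_DEFAULT_ENV`.

```python
import os
from pathlib import Path
from typing import List, Mapping, Optional


def check_demo_env(env: Optional[Mapping[str, str]] = None) -> List[str]:
    """Return a list of problems; an empty list means the setup looks complete.

    Hypothetical helper that checks the layout described in the demo steps.
    """
    env = os.environ if env is None else env
    problems: List[str] = []

    # The demo needs its own conda env (pydantic conflicts with the training env).
    active = env.get("CONDA_DEFAULT_ENV")
    if active != "srgpt-demo":
        problems.append(f"active conda env is {active!r}, expected 'srgpt-demo'")

    # Step 2: the Depth-Anything checkpoint is expected under <repo>/checkpoints/.
    depth_root = env.get("DEPTH_ANYTHING_PATH")
    if not depth_root:
        problems.append("DEPTH_ANYTHING_PATH is not set")
    else:
        ckpt = Path(depth_root) / "checkpoints" / "depth_anything_vitl14.pth"
        if not ckpt.is_file():
            problems.append(f"missing Depth-Anything checkpoint: {ckpt}")

    # Step 3: SAM_CKPT_PATH points directly at the SAM-HQ checkpoint file.
    sam_ckpt = env.get("SAM_CKPT_PATH")
    if not sam_ckpt:
        problems.append("SAM_CKPT_PATH is not set")
    elif not Path(sam_ckpt).is_file():
        problems.append(f"missing SAM-HQ checkpoint: {sam_ckpt}")

    return problems


if __name__ == "__main__":
    for problem in check_demo_env():
        print("WARNING:", problem)
```

The function only reports problems rather than exiting, so it can be called from a launcher script or run on its own before `python gradio_web_server_multi.py`.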

Training

SpatialRGPT follows the VILA training recipe, which consists of three stages. We provide training scripts for three LLM backbones: sheared_3b, llama2_7b, and llama3_8b. The scripts for each stage are in the scripts/srgpt folder.


Open Spatial Dataset

Download the Open Spatial Dataset from Hugging Face, and update the dataset path in llava/data/dataset_mixture.py.

For the raw images, download OpenImages from OpenImagesV7. To convert the RGB images into depth maps, we use DepthAnythingV2 and save each depth map with the following function:

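import numpy as np
import torch.nn.functional as F
from PIL import Image

def save_raw_16bit(depth, fpath, height, width):
    # Resize the 2D depth tensor to the original image resolution.
    depth = F.interpolate(depth[None, None], (height, width), mode='bilinear', align_corners=False)[0, 0]
    # Min-max normalize to [0, 255]; note the output is 8-bit despite the function name.
    depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
    depth = depth.cpu().numpy().astype(np.uint8)
    # Replicate the single channel into a 3-channel grayscale image.
    colorized_depth = np.stack([depth, depth, depth], axis=-1)

    depth_image = Image.fromarray(colorized_depth)
    depth_image.save(fpath)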
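`save_raw_16bit` min-max normalizes each depth map independently, so the saved pixels encode relative depth within a single image, and the original values cannot be recovered unless the per-image minimum and maximum are stored separately. The NumPy sketch below reproduces the same arithmetic (without the resize) so it can be checked in isolation. The helper names are hypothetical, and the zero-range guard is an addition that the original function does not have.

```python
import numpy as np


def normalize_to_uint8(depth: np.ndarray) -> np.ndarray:
    """Min-max normalize a float depth map to [0, 255] and cast to uint8,
    mirroring the arithmetic in save_raw_16bit (resize omitted)."""
    d_min, d_max = float(depth.min()), float(depth.max())
    span = d_max - d_min
    if span == 0.0:
        # Constant map: added guard, since the original would divide by zero here.
        return np.zeros(depth.shape, dtype=np.uint8)
    scaled = (depth - d_min) / span * 255.0
    return scaled.astype(np.uint8)  # astype truncates fractional values


def to_three_channel(gray: np.ndarray) -> np.ndarray:
    """Replicate a single-channel (H, W) image into an (H, W, 3) array."""
    return np.stack([gray, gray, gray], axis=-1)


def approx_original(pixels: np.ndarray, d_min: float, d_max: float) -> np.ndarray:
    """Map uint8 values back to the original range, given the per-image
    min/max (which save_raw_16bit does not store). Limited to 256 levels."""
    return d_min + pixels.astype(np.float64) / 255.0 * (d_max - d_min)
```

For example, a 2×2 map with values 1, 2, 3 and 5 becomes pixels 0, 63, 127 and 255: the scaled values 63.75 and 127.5 are truncated, not rounded, by `astype`.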

Dataset Synthesis Pipeline

We've also made the dataset synthesis pipeline available. You can find the code and instructions in the dataset_pipeline folder. Please note that some of the packages we use have had version updates, and we've migrated to their latest versions. This may result in some bugs. Feel free to report any issues or unexpected results you encounter.

Wis3D Demo


Evaluations

Our evaluation scripts take three arguments: PATH_TO_CKPT, CKPT_NAME, and CONV_TYPE.

  • PATH_TO_CKPT refers to the location of the checkpoint you want to evaluate.
  • CKPT_NAME specifies the folder that will be created in the eval_out directory, where the evaluation results will be stored.
  • CONV_TYPE must match the conversation type used by the checkpoint. For llama3_8b, use llama_3.
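Since the three arguments are positional, a small wrapper can guard against mistakes such as a mismatched CONV_TYPE. The sketch below is hypothetical and not part of the repository: it assumes the scripts are invoked with `bash` from the repository root, that `eval_out` sits at the repository root, and it encodes only the one backbone-to-conversation-type pairing stated above (llama3_8b with llama_3); other backbones are passed through unchecked.

```python
import shlex
from pathlib import Path
from typing import List, Optional

# Only this pairing is documented; other backbones are not validated.
KNOWN_CONV_TYPES = {"llama3_8b": "llama_3"}


def build_eval_command(script: str, ckpt_path: str, ckpt_name: str,
                       conv_type: str, backbone: Optional[str] = None) -> List[str]:
    """Assemble argv for: SCRIPT PATH_TO_CKPT CKPT_NAME CONV_TYPE (assumes bash)."""
    expected = KNOWN_CONV_TYPES.get(backbone) if backbone else None
    if expected is not None and conv_type != expected:
        raise ValueError(f"{backbone} checkpoints need CONV_TYPE={expected!r}, got {conv_type!r}")
    if not ckpt_name or "/" in ckpt_name:
        # CKPT_NAME becomes a folder under eval_out/, so keep it a single path component.
        raise ValueError("CKPT_NAME should be a plain folder name")
    return ["bash", script, ckpt_path, ckpt_name, conv_type]


def results_dir(ckpt_name: str, repo_root: str = ".") -> Path:
    """Expected results folder for CKPT_NAME (assumes eval_out/ at the repo root)."""
    return Path(repo_root) / "eval_out" / ckpt_name


if __name__ == "__main__":
    cmd = build_eval_command("scripts/srgpt/eval/coco_cls.sh", "/path/to/checkpoint",
                             "srgpt_8b", "llama_3", backbone="llama3_8b")
    print(shlex.join(cmd))
    # To launch the evaluation: subprocess.run(cmd, check=True)
```

The wrapper only validates and assembles arguments; launching is left to the caller so the same command can be printed, logged, or submitted to a cluster scheduler.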

Region Classification

First, prepare the evaluation annotations following RegionCLIP. Then run scripts/srgpt/eval/coco_cls.sh PATH_TO_CKPT CKPT_NAME CONV_TYPE.

SpatialRGPT-Bench Evaluation

First, download the images from omni3d, following their instructions. Then download the annotations from https://huggingface.co/datasets/a8cheng/SpatialRGPT-Bench. Update the paths in scripts/srgpt/eval/srgpt_bench.sh to point to these locations.

Note that for SpatialRGPT-Bench, you need to clone the Depth-Anything repository and download the necessary checkpoint:

git clone https://github.com/LiheYoung/Depth-Anything.git
wget https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitl14.pth

Place depth_anything_vitl14.pth under Depth-Anything/checkpoints, and set the DEPTH_ANYTHING_PATH environment variable to the Depth-Anything repository path:

export DEPTH_ANYTHING_PATH="PATH_TO_DEPTHANYTHING"

Then run scripts/srgpt/eval/srgpt_bench.sh PATH_TO_CKPT CKPT_NAME CONV_TYPE.

General VLM Benchmarks

Our code is compatible with VILA's evaluation scripts. See VILA/evaluations for details.


📜 Citation

  @inproceedings{cheng2024spatialrgpt,
          title={SpatialRGPT: Grounded Spatial Reasoning in Vision-Language Models},
          author={Cheng, An-Chieh and Yin, Hongxu and Fu, Yang and Guo, Qiushan and Yang, Ruihan and Kautz, Jan and Wang, Xiaolong and Liu, Sifei},
          booktitle={NeurIPS},
          year={2024}
  }

🙏 Acknowledgement

We have used code snippets from different repositories, especially from: VILA, Omni3D, GLaMM, VQASynth, and ConceptGraphs. We would like to acknowledge and thank the authors of these repositories for their excellent work.
