OpenBMB/VisRAG

VisRAG: Vision-based Retrieval-augmented Generation on Multi-modality Documents

VisRAG is a novel vision-language model (VLM)-based RAG pipeline. Instead of first parsing a document to obtain text, VisRAG embeds the document directly as an image using a VLM and then retrieves it to augment generation by a VLM. Compared with traditional text-based RAG, VisRAG maximizes the retention and use of the information in the original documents and eliminates the information loss introduced by parsing.

VisRAG Pipeline
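
The retrieve-then-generate flow described above can be sketched in a few lines. This is a minimal illustration, not the repository's implementation: it assumes page and query embeddings are already computed and L2-normalized, and the names `top_k_pages`, `visrag_answer`, and `echo_generator`, along with the toy vectors and page placeholders, are hypothetical.

```python
from typing import Callable, List, Sequence, Tuple

Vector = Sequence[float]


def top_k_pages(query_vec: Vector, page_vecs: Sequence[Vector], k: int) -> List[Tuple[int, float]]:
    """Rank pages by dot product with the query and return the best k as (index, score)."""
    scores = [sum(q * p for q, p in zip(query_vec, vec)) for vec in page_vecs]
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return [(i, scores[i]) for i in ranked[:k]]


def visrag_answer(
    query: str,
    query_vec: Vector,
    page_vecs: Sequence[Vector],
    pages: Sequence[object],
    generator: Callable[[str, List[object]], str],
    k: int = 1,
) -> str:
    """Retrieve the top-k page images, then hand them to a generator with the query."""
    hits = top_k_pages(query_vec, page_vecs, k)
    retrieved = [pages[i] for i, _ in hits]
    return generator(query, retrieved)


# Toy data: strings stand in for page images, 2-d unit vectors for embeddings.
pages = ["page_0", "page_1", "page_2"]
page_vecs = [[0.0, 1.0], [1.0, 0.0], [0.6, 0.8]]
query_vec = [1.0, 0.0]


def echo_generator(query: str, retrieved_pages: List[object]) -> str:
    # Stand-in for a VLM call; a real generator would receive the page images.
    return f"{query} -> answered using {retrieved_pages}"


print(visrag_answer("What does a dog look like?", query_vec, page_vecs, pages, echo_generator, k=2))
```

Passing the generator in as a callable keeps the sketch independent of any particular VLM, which matches the point that the generator can be swapped freely.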

VisRAG-Ret

VisRAG-Ret is a document embedding model built on MiniCPM-V 2.0, a vision-language model that integrates SigLIP as the vision encoder and MiniCPM-2B as the language model.

VisRAG-Gen

In the paper, we use MiniCPM-V 2.0, MiniCPM-V 2.6, and GPT-4o as the generators. In practice, you can use any VLM you like.

Setup

conda create --name VisRAG python==3.10.8
conda activate VisRAG
conda install nvidia/label/cuda-11.8.0::cuda-toolkit
cd VisRAG
pip install -r requirements.txt
pip install -e .
cd timm_modified
pip install -e .
cd ..

Training

VisRAG-Ret

Our training dataset for VisRAG-Ret contains 362,110 query-document (Q-D) pairs. It comprises the train sets of openly available academic datasets (34%) and a synthetic dataset (66%) built from pages of web-crawled PDF documents, augmented with pseudo-queries generated by a VLM (GPT-4o).

Data coming soon

bash scripts/train_retriever/train.sh 2048 16 8 0.02 1 true false config/deepspeed.json 1e-5 false wmean causal 1 true 2 false <model_path> <repository_name>

<repository_name> can be 'openbmb/VisRAG-Ret-Train-In-domain-data' or 'openbmb/VisRAG-Ret-Train-Synthetic-data'.

VisRAG-Gen

The generation part does not use any fine-tuning; we directly use off-the-shelf LLMs/VLMs for generation.

Evaluation

VisRAG-Ret

Data coming soon

bash scripts/eval_retriever/eval.sh 512 2048 16 8 wmean causal ArxivQA,ChartQA,MP-DocVQA,InfoVQA,PlotQA,SlideVQA <ckpt_path>

The parameters above are the ones we use in our paper; you can use them to reproduce the reported results.
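
This README does not say which metrics `eval.sh` reports. As a general illustration of how a retriever's ranked output is commonly scored, the sketch below computes MRR@k (mean reciprocal rank, truncated at k). It assumes one relevant document per query; the function name `mrr_at_k` and the toy IDs are hypothetical and are not taken from the evaluation script.

```python
from typing import List, Sequence


def mrr_at_k(ranked_ids: Sequence[Sequence[str]], relevant_ids: Sequence[str], k: int = 10) -> float:
    """Average of 1/rank of the relevant document, counting 0 if it is not in the top k."""
    total = 0.0
    for ranked, relevant in zip(ranked_ids, relevant_ids):
        for rank, doc_id in enumerate(ranked[:k], start=1):
            if doc_id == relevant:
                total += 1.0 / rank
                break
    return total / len(relevant_ids)


# Three toy queries: relevant doc at rank 1, at rank 2, and not retrieved.
rankings: List[List[str]] = [["a", "b"], ["c", "d"], ["e", "f"]]
gold = ["a", "d", "z"]
print(mrr_at_k(rankings, gold, k=10))
```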

VisRAG-Gen

Coming soon

Usage

VisRAG-Ret

Model on Hugging Face: https://huggingface.co/openbmb/VisRAG-Ret

from transformers import AutoModel, AutoTokenizer
import torch
import torch.nn.functional as F
from PIL import Image
import os

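def weighted_mean_pooling(hidden, attention_mask):
    # Position weights 1, 2, 3, ... on real tokens and 0 on padding,
    # so later tokens contribute more to the pooled vector.
    attention_mask_ = attention_mask * attention_mask.cumsum(dim=1)
    # Weighted sum over the sequence dimension, divided by the total weight.
    s = torch.sum(hidden * attention_mask_.unsqueeze(-1).float(), dim=1)
    d = attention_mask_.sum(dim=1, keepdim=True).float()
    reps = s / d
    return reps

@torch.no_grad()
def encode(text_or_image_list):
    # Uses the module-level `tokenizer` and `model` defined below.
    if isinstance(text_or_image_list[0], str):
        # Text inputs (queries): pass the strings and no images.
        inputs = {
            "text": text_or_image_list,
            "image": [None] * len(text_or_image_list),
            "tokenizer": tokenizer
        }
    else:
        # Image inputs (document pages): pass the images with empty text.
        inputs = {
            "text": [""] * len(text_or_image_list),
            "image": text_or_image_list,
            "tokenizer": tokenizer
        }
    outputs = model(**inputs)
    attention_mask = outputs.attention_mask
    hidden = outputs.last_hidden_state

    # Pool token states into one vector per input, then L2-normalize so that
    # a dot product between embeddings equals cosine similarity.
    reps = weighted_mean_pooling(hidden, attention_mask)
    embeddings = F.normalize(reps, p=2, dim=1).detach().cpu().numpy()
    return embeddings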

tokenizer = AutoTokenizer.from_pretrained("openbmb/VisRAG-Ret", trust_remote_code=True)
model = AutoModel.from_pretrained("openbmb/VisRAG-Ret", torch_dtype=torch.bfloat16, trust_remote_code=True)
model.eval()

script_dir = os.path.dirname(os.path.realpath(__file__))
queries = ["What does a dog look like?"]
passages = [
    Image.open(os.path.join(script_dir, 'test_image/cat.jpeg')).convert('RGB'),
    Image.open(os.path.join(script_dir, 'test_image/dog.jpg')).convert('RGB'),
]

INSTRUCTION = "Represent this query for retrieving relevant documents: "
queries = [INSTRUCTION + query for query in queries]

embeddings_query = encode(queries)
embeddings_doc = encode(passages)

scores = (embeddings_query @ embeddings_doc.T)
print(scores.tolist())
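
To see what `weighted_mean_pooling` computes without loading the model, the arithmetic can be restated for a single sequence in plain Python. This is an illustrative sketch only: `weighted_mean_pooling_1d` is a hypothetical helper, and the hidden states and mask are made-up toy values.

```python
from itertools import accumulate
from typing import List, Sequence


def weighted_mean_pooling_1d(hidden: Sequence[Sequence[float]], attention_mask: Sequence[int]) -> List[float]:
    """Position-weighted mean of token vectors for one sequence (padding gets weight 0)."""
    # Running count of real tokens, zeroed on padding: [1, 1, 1, 0] -> [1, 2, 3, 0].
    weights = [m * c for m, c in zip(attention_mask, accumulate(attention_mask))]
    total = sum(weights)
    dim = len(hidden[0])
    return [sum(w * tok[j] for w, tok in zip(weights, hidden)) / total for j in range(dim)]


# Four token vectors of width 2; the last token is padding and must be ignored.
hidden = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [9.0, 9.0]]
mask = [1, 1, 1, 0]

# Weights are [1, 2, 3, 0], so the result is [(1 + 3) / 6, (2 + 3) / 6].
pooled = weighted_mean_pooling_1d(hidden, mask)
print(pooled)
```

Because the weights grow with position, the last real token counts most, and the padded token drops out entirely.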

License

  • The code in this repo is released under the Apache-2.0 License.
  • The usage of VisRAG-Ret model weights must strictly follow MiniCPM Model License.md.
  • The models and weights of VisRAG-Ret are completely free for academic research. After filling out a "questionnaire" for registration, VisRAG-Ret weights are also available for free commercial use.

Contact
