---
license: mit
language:
  - en
tags:
  - vidore
datasets:
  - Tevatron/docmatix-ir
  - HuggingFaceM4/Docmatix
library_name: Tevatron
---

# DSE-Phi3-Docmatix-V1.0

DSE-Phi3-Docmatix-V1.0 is a bi-encoder model that encodes document screenshots into dense vectors for document retrieval. The Document Screenshot Embedding (DSE) approach captures a document in its original visual format, preserving all of its information (text, images, and layout) and thereby avoiding tedious parsing pipelines and the information loss they can introduce.

The model, `Tevatron/dse-phi3-docmatix-v1.0`, is trained on the `Tevatron/docmatix-ir` dataset, a variant of `HuggingFaceM4/Docmatix` adapted specifically for training PDF retrievers with Vision Language Models in open-domain question answering. For details on dataset filtering and hard-negative mining, refer to the `Tevatron/docmatix-ir` dataset page.

## How to Use the Model

### Load the Model and Processor

```python
import torch
from transformers import AutoProcessor, AutoModelForCausalLM, AutoConfig

# The processor comes from the Phi-3-vision base model; trust_remote_code is
# needed because phi3_v ships custom processing and modeling code.
processor = AutoProcessor.from_pretrained('microsoft/Phi-3-vision-128k-instruct', trust_remote_code=True)
config = AutoConfig.from_pretrained('microsoft/Phi-3-vision-128k-instruct', trust_remote_code=True, use_cache=False)
model = AutoModelForCausalLM.from_pretrained(
    'Tevatron/dse-phi3-docmatix-v1.0',
    trust_remote_code=True,
    config=config,
    attn_implementation="flash_attention_2",  # requires the flash-attn package
    torch_dtype=torch.bfloat16,
).to('cuda:0')

def get_embedding(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # Pool by taking the hidden state of the last non-padding token in each sequence.
    sequence_lengths = attention_mask.sum(dim=1) - 1
    bs = last_hidden_state.shape[0]
    reps = last_hidden_state[torch.arange(bs, device=last_hidden_state.device), sequence_lengths]
    # L2-normalize so dot products between embeddings are cosine similarities.
    reps = torch.nn.functional.normalize(reps, p=2, dim=-1)
    return reps
```
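
All of the steps below are inference-only, so it can help to disable gradient tracking. A minimal sketch of two convenience wrappers around the calls used in the following sections (the helper names `encode_queries` and `encode_passages` are illustrative, not part of the model's API):

```python
@torch.no_grad()  # inference only: skip autograd bookkeeping to save memory
def encode_queries(queries: list[str]) -> torch.Tensor:
    inputs = processor(queries, return_tensors="pt", padding="longest",
                       max_length=128, truncation=True).to('cuda:0')
    out = model(**inputs, return_dict=True, output_hidden_states=True)
    return get_embedding(out.hidden_states[-1], inputs["attention_mask"])

@torch.no_grad()
def encode_passages(prompts: list[str], images=None) -> torch.Tensor:
    inputs = processor(prompts, images=images, return_tensors="pt", padding="longest",
                       max_length=4096, truncation=True).to('cuda:0')
    out = model(**inputs, return_dict=True, output_hidden_states=True)
    return get_embedding(out.hidden_states[-1], inputs["attention_mask"])
```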

### Encode Text Query

```python
queries = ["query: Where can we find Llama?</s>", "query: What is the LLaMA model?</s>"]
# Queries use the "query: " prefix and end with the EOS token, mirroring the
# passage prompts below, so that last-token pooling sees the same terminator.
query_inputs = processor(queries, return_tensors="pt", padding="longest", max_length=128, truncation=True).to('cuda:0')
output = model(**query_inputs, return_dict=True, output_hidden_states=True)
query_embeddings = get_embedding(output.hidden_states[-1], query_inputs["attention_mask"])
```
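
Each row of `query_embeddings` is a unit-length vector, one per query, since `get_embedding` L2-normalizes its output. A quick sanity check:

```python
print(query_embeddings.shape)         # (num_queries, hidden_size)
print(query_embeddings.norm(dim=-1))  # ~1.0 per row, because get_embedding L2-normalizes
```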

### Encode Document Screenshot

```python
from PIL import Image

passage_image1 = Image.open("path/to/your/image1.png")
passage_image2 = Image.open("path/to/your/image2.png")
passage_images = [passage_image1, passage_image2]
# Phi-3-vision expects an <|image_i|> placeholder in each prompt, pointing at
# the i-th entry of the `images` list, so the screenshots can be spliced in.
passage_prompts = ["<|image_1|>\nWhat is shown in this image?</s>", "<|image_2|>\nWhat is shown in this image?</s>"]

passage_inputs = processor(passage_prompts, images=passage_images, return_tensors="pt", padding="longest", max_length=4096, truncation=True).to('cuda:0')
output = model(**passage_inputs, return_dict=True, output_hidden_states=True)
doc_embeddings = get_embedding(output.hidden_states[-1], passage_inputs["attention_mask"])
```
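
A whole screenshot corpus will not fit in one forward pass. A minimal sketch of chunked encoding, assuming the illustrative `encode_passages` helper from the sketch above:

```python
def encode_corpus(prompts: list[str], images: list, batch_size: int = 8) -> torch.Tensor:
    # Encode the corpus in fixed-size chunks to keep peak GPU memory bounded.
    chunks = [
        encode_passages(prompts[i:i + batch_size], images[i:i + batch_size])
        for i in range(0, len(prompts), batch_size)
    ]
    return torch.cat(chunks, dim=0)
```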

### Compute Similarity

```python
from torch.nn.functional import cosine_similarity

# Row-wise similarity: query i is scored against document i.
similarities = cosine_similarity(query_embeddings, doc_embeddings)
print(similarities)
```
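
For retrieval you usually want every query scored against every document rather than the row-wise pairing above. Because `get_embedding` already L2-normalizes, a matrix product of the embeddings yields the full cosine-similarity matrix; a minimal ranking sketch:

```python
# (num_queries x num_docs) cosine-similarity matrix; a plain matmul suffices
# because the embeddings are already unit-normalized.
scores = query_embeddings @ doc_embeddings.T

# Rank documents for each query, best match first.
top_scores, top_indices = scores.topk(k=min(2, scores.shape[1]), dim=1)
for q, (s, idx) in enumerate(zip(top_scores, top_indices)):
    print(f"query {q}: docs {idx.tolist()} with scores {s.tolist()}")
```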

### Encode Document Text

This DSE checkpoint was warmed up with `Tevatron/msmarco-passage-aug`, so the model can also encode documents effectively when they are given as plain text.

```python
passage_prompts = ["Llama is in Africa</s>", "LLaMA is an LLM released by Meta.</s>"]

passage_inputs = processor(passage_prompts, images=None, return_tensors="pt", padding="longest", max_length=4096, truncation=True).to('cuda:0')
output = model(**passage_inputs, return_dict=True, output_hidden_states=True)
doc_embeddings = get_embedding(output.hidden_states[-1], passage_inputs["attention_mask"])

similarities = cosine_similarity(query_embeddings, doc_embeddings)
print(similarities)
```
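
In practice the document screenshots often come from rendering PDF pages. A sketch of that ingestion step, assuming the third-party `pdf2image` package (which needs poppler installed); the file path and DPI are illustrative:

```python
from pdf2image import convert_from_path  # pip install pdf2image

# Render each page of a PDF to a PIL image and encode the pages as screenshots.
pages = convert_from_path("path/to/your/document.pdf", dpi=144)
page_prompts = [f"<|image_{i + 1}|>\nWhat is shown in this image?</s>" for i in range(len(pages))]

page_inputs = processor(page_prompts, images=pages, return_tensors="pt", padding="longest",
                        max_length=4096, truncation=True).to('cuda:0')
output = model(**page_inputs, return_dict=True, output_hidden_states=True)
page_embeddings = get_embedding(output.hidden_states[-1], page_inputs["attention_mask"])
```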