# -*- coding: utf-8 -*-

import torch
import numpy as np
from PIL import Image
from dataclasses import dataclass
from torchvision.transforms import Normalize
from transformers import CLIPModel, CLIPTokenizer
from transformers.utils import ModelOutput
from typing import Iterable, Optional, Union, List


ImageType = Union[np.ndarray, torch.Tensor, Image.Image]


@dataclass
class CLIPEmbedOutput(ModelOutput):
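    """Bundle of encoder outputs: token-level hidden states, the pooled
    sequence representation, and the projected embedding in CLIP's joint space."""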
    last_hidden_state: Optional[torch.FloatTensor] = None
    pooler_output: Optional[torch.FloatTensor] = None
    embeds: Optional[torch.FloatTensor] = None


class CLIPEncoder(torch.nn.Module):

    def __init__(self, model_path="openai/clip-vit-base-patch32"):

        super().__init__()

        # Load the pretrained CLIP model and its tokenizer
        self.model: CLIPModel = CLIPModel.from_pretrained(model_path)
        self.tokenizer = CLIPTokenizer.from_pretrained(model_path)
        # Normalization statistics the pretrained CLIP weights expect
        # (CLIP's own values, not the ImageNet mean/std)
        self.image_preprocess = Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                          std=[0.26862954, 0.26130258, 0.27577711])

        # Switch to eval mode and freeze all parameters; assigning
        # `self.model.training` directly would not propagate to submodules.
        self.model.eval()
        for p in self.model.parameters():
            p.requires_grad = False

    @torch.no_grad()
    def encode_image(self, images: Iterable[Optional[ImageType]]):
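        # `Normalize` handles neither PIL images nor resizing, so `images` is
        # assumed to already be a float tensor of shape (B, 3, 224, 224) scaled
        # to [0, 1]; use transformers' CLIPImageProcessor for raw PIL inputs.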
        pixel_values = self.image_preprocess(images)

        vision_outputs = self.model.vision_model(pixel_values=pixel_values)

        pooler_output = vision_outputs.pooler_output
        image_features = self.model.visual_projection(pooler_output)

        visual_embeds = CLIPEmbedOutput(
            last_hidden_state=vision_outputs.last_hidden_state,
            pooler_output=pooler_output,
            embeds=image_features
        )

        return visual_embeds

    @torch.no_grad()
    def encode_text(self, texts: List[str]):
        # truncation keeps batched texts within the model's max sequence length
        text_inputs = self.tokenizer(texts, padding=True, truncation=True, return_tensors="pt")

        # The tokenizer returns a BatchEncoding, so pass its tensors explicitly.
        text_outputs = self.model.text_model(input_ids=text_inputs.input_ids,
                                             attention_mask=text_inputs.attention_mask)

        pooler_output = text_outputs.pooler_output
        text_features = self.model.text_projection(pooler_output)

        text_embeds = CLIPEmbedOutput(
            last_hidden_state=text_outputs.last_hidden_state,
            pooler_output=pooler_output,
            embeds=text_features
        )

        return text_embeds

    def forward(self,
                images: Iterable[Optional[ImageType]],
                texts: List[str]):

        visual_embeds = self.encode_image(images)
        text_embeds = self.encode_text(texts)

        return visual_embeds, text_embeds
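

# Minimal usage sketch (assumptions: images are already resized and converted
# to a float tensor in [0, 1]; the dummy tensor and captions are illustrative).
if __name__ == "__main__":
    encoder = CLIPEncoder()

    dummy_images = torch.rand(2, 3, 224, 224)            # stand-in for a real image batch
    captions = ["a photo of a cat", "a photo of a dog"]

    visual_embeds, text_embeds = encoder(dummy_images, captions)
    print(visual_embeds.embeds.shape)   # expected: torch.Size([2, 512]) for ViT-B/32
    print(text_embeds.embeds.shape)     # expected: torch.Size([2, 512])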