import json
import os
import pickle

import numpy as np
import torch
import transformers
from PIL import Image
from open_clip import create_model_from_pretrained, create_model_and_transforms

# XLM-R multilingual CLIP text encoder (used by the M-CLIP baseline)
from multilingual_clip import pt_multilingual_clip

from model_loading import load_model


class CustomDataSet(torch.utils.data.Dataset):
    """Dataset of XTD10 images that applies the CLIP preprocessing transform on load."""

    def __init__(self, main_dir, compose, image_name_list):
        self.main_dir = main_dir
        self.transform = compose
        self.total_imgs = image_name_list

    def __len__(self):
        return len(self.total_imgs)

    def get_image_name(self, idx):
        return self.total_imgs[idx]

    def __getitem__(self, idx):
        img_loc = os.path.join(self.main_dir, self.total_imgs[idx])
        # Convert to RGB so grayscale or RGBA files match the model's expected input.
        image = Image.open(img_loc).convert("RGB")
        return self.transform(image)


def features_pickle(file_path=None):
    """Load pre-computed image or text features from a cached pickle file."""
    with open(file_path, 'rb') as handle:
        features = pickle.load(handle)

    return features


def dataset_loading():
    """Read the XTD10 caption file and return its entries sorted by id, plus the image names."""
    with open("photos/en_ar_XTD10_edited_v2.jsonl") as caption_file:
        data = [json.loads(line) for line in caption_file]

    sorted_data = sorted(data, key=lambda x: x['id'])
    image_name_list = [entry["image_name"] for entry in sorted_data]

    return sorted_data, image_name_list


def text_encoder(language_model, text):
    """Embed the text query and also return an L2-normalized copy of the embedding."""
    embedding = language_model(text)
    norm_embedding = embedding / np.linalg.norm(embedding)

    return embedding, norm_embedding


def compare_embeddings(logit_scale, img_embs, txt_embs):
    """Scaled cosine similarity between text and image embeddings (text x image)."""
    image_features = img_embs / img_embs.norm(dim=-1, keepdim=True)
    text_features = txt_embs / txt_embs.norm(dim=-1, keepdim=True)

    logits_per_text = logit_scale * text_features @ image_features.t()

    return logits_per_text


def compare_embeddings_text(full_text_embds, txt_embs):
    """Cosine similarity between the query embedding and the cached caption embeddings."""
    full_text_embds_features = full_text_embds / full_text_embds.norm(dim=-1, keepdim=True)
    text_features = txt_embs / txt_embs.norm(dim=-1, keepdim=True)

    logits_per_text_full = text_features @ full_text_embds_features.t()

    return logits_per_text_full
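

# Illustrative sketch (not part of the original pipeline): a tiny shape check for the two
# similarity helpers above, using random tensors in place of real CLIP features. The sizes
# below (8 cached items, 768-dimensional embeddings) are arbitrary assumptions.
def _demo_compare_embeddings():
    img_embs = torch.randn(8, 768)   # stands in for cached image features
    txt_embs = torch.randn(1, 768)   # stands in for a single encoded query
    logits = compare_embeddings(torch.tensor(100.0), img_embs, txt_embs)
    assert logits.shape == (1, 8)    # one row of similarities for the query
    text_logits = compare_embeddings_text(img_embs, txt_embs)
    assert text_logits.shape == (1, 8)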



def find_image(language_model, clip_model, text_query, dataset, image_features, text_features_new, sorted_data, num=1):
    """Retrieve the top-`num` images (and closest captions) for an Arabic text query."""
    embedding, _ = text_encoder(language_model, text_query)

    logit_scale = clip_model.logit_scale.exp().float().to('cpu')

    language_logits, text_logits = {}, {}

    # Text-to-image similarities against the cached image features.
    language_logits["Arabic"] = compare_embeddings(logit_scale, torch.from_numpy(image_features), torch.from_numpy(embedding))

    # Text-to-text similarities against the cached caption features.
    text_logits["Arabic_text"] = compare_embeddings_text(torch.from_numpy(text_features_new), torch.from_numpy(embedding))

    file_paths = []
    labels, json_data = {}, {}

    for _, txt_logits in language_logits.items():
        probs = txt_logits.softmax(dim=-1).cpu().detach().numpy().T

        for i in range(1, num + 1):
            # Index of the i-th most similar image.
            idx = np.argsort(probs, axis=0)[-i, 0]
            path = 'photos/XTD10_dataset/' + dataset.get_image_name(idx)
            path_l = (path, f"{sorted_data[idx]['caption_ar']}")

            labels[f" Image # {i}"] = probs[idx]
            json_data[f" Image # {i}"] = sorted_data[idx]
            file_paths.append(path_l)

    json_text = {}

    for _, txt_logits_full in text_logits.items():
        probs_text = txt_logits_full.softmax(dim=-1).cpu().detach().numpy().T

        for j in range(1, num + 1):
            # Index of the j-th most similar caption.
            idx = np.argsort(probs_text, axis=0)[-j, 0]
            json_text[f" Text # {j}"] = sorted_data[idx]

    return file_paths, labels, json_data, json_text



class AraClip():
    def __init__(self):
        # Arabic text encoder (AraBERT-based) aligned to the SigLIP ViT-B/16 image encoder.
        self.text_model = load_model('bert-base-arabertv2-ViT-B-16-SigLIP-512-epoch-155-trained-2M', in_features=768, out_features=768)
        self.language_model = lambda queries: np.asarray(self.text_model(queries).detach().to('cpu'))
        self.clip_model, self.compose = create_model_from_pretrained('hf-hub:timm/ViT-B-16-SigLIP-512')
        self.sorted_data, self.image_name_list = dataset_loading()

    def load_images(self):
        # Return the cached image features for the XTD10 images.
        image_features_new = features_pickle('cashed_pickles/image_features_XTD_1000_images_arabert_siglib_best_model.pickle')
        return image_features_new

    def load_text(self):
        # Return the cached caption (text) features.
        text_features_new = features_pickle('cashed_pickles/text_features_XTD_1000_images_arabert_siglib_best_model.pickle')
        return text_features_new

    def load_dataset(self):
        dataset = CustomDataSet("photos/XTD10_dataset", self.compose, self.image_name_list)
        return dataset


araclip = AraClip()


def predict(text, num):
    """Return the top-`num` AraCLIP retrieval results for an Arabic text query."""
    image_paths, labels, json_data, json_text = find_image(
        araclip.language_model, araclip.clip_model, text, araclip.load_dataset(),
        araclip.load_images(), araclip.load_text(), araclip.sorted_data, num=int(num))

    return image_paths, labels, json_data, json_text


class Mclip():
    def __init__(self) -> None:
        # Multilingual CLIP (M-CLIP) text encoder with its matching OpenCLIP image encoder.
        self.tokenizer_mclip = transformers.AutoTokenizer.from_pretrained('M-CLIP/XLM-Roberta-Large-Vit-B-16Plus')
        self.text_model_mclip = pt_multilingual_clip.MultilingualCLIP.from_pretrained('M-CLIP/XLM-Roberta-Large-Vit-B-16Plus')
        self.language_model_mclip = lambda queries: np.asarray(self.text_model_mclip.forward(queries, self.tokenizer_mclip).detach().to('cpu'))
        self.clip_model_mclip, _, self.compose_mclip = create_model_and_transforms('ViT-B-16-plus-240', pretrained="laion400m_e32")
        self.sorted_data, self.image_name_list = dataset_loading()

    def load_images(self):
        # Return the cached image features for the XTD10 images.
        image_features_mclip = features_pickle('cashed_pickles/image_features_XTD_1000_images_XLM_Roberta_Large_Vit_B_16Plus_ar.pickle')
        return image_features_mclip

    def load_text(self):
        # Return the cached caption (text) features.
        text_features_new_mclip = features_pickle('cashed_pickles/text_features_XTD_1000_images_XLM_Roberta_Large_Vit_B_16Plus_ar.pickle')
        return text_features_new_mclip

    def load_dataset(self):
        dataset_mclip = CustomDataSet("photos/XTD10_dataset", self.compose_mclip, self.image_name_list)
        return dataset_mclip


mclip = Mclip()


def predict_mclip(text, num):
    """Return the top-`num` M-CLIP retrieval results for an Arabic text query."""
    image_paths, labels, json_data, json_text = find_image(
        mclip.language_model_mclip, mclip.clip_model_mclip, text, mclip.load_dataset(),
        mclip.load_images(), mclip.load_text(), mclip.sorted_data, num=int(num))

    return image_paths, labels, json_data, json_text
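

# Illustrative usage sketch (not part of the original module). The Arabic query below is an
# arbitrary example; running it assumes the cached pickles under cashed_pickles/ and the
# photos/XTD10_dataset images are available locally.
if __name__ == "__main__":
    example_query = "كلب يجري على الشاطئ"  # "a dog running on the beach"

    # AraCLIP retrieval: top-3 images with their scores and Arabic captions.
    paths, scores, image_meta, caption_meta = predict(example_query, num=3)
    for (image_path, caption_ar), prob in zip(paths, scores.values()):
        print(image_path, prob, caption_ar)

    # The same query through the multilingual CLIP (M-CLIP) baseline for comparison.
    paths_mclip, scores_mclip, _, _ = predict_mclip(example_query, num=3)
    for (image_path, caption_ar), prob in zip(paths_mclip, scores_mclip.values()):
        print(image_path, prob, caption_ar)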