File size: 2,158 Bytes
30099ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import uuid
from dataclasses import dataclass
from typing import Any

import japanese_clip as ja_clip
import torch
from PIL import Image

from db_session import get_db
from s3_session import Bucket


@dataclass
class MLModel:
    """Japanese CLIP wrapper for storing image embeddings and text-to-image search.

    ``save`` embeds an image with ``rinna/japanese-clip-vit-b-16``, uploads it
    to ``bucket`` under a fresh UUID key, and inserts the embedding into the
    ``embedding`` collection of the database returned by ``get_db()``.
    ``search`` embeds a text prompt with the same model and runs a
    ``$vectorSearch`` aggregation against that collection.

    Any field left as ``None`` is populated in ``__post_init__``; fields passed
    explicitly by the caller are kept as-is (e.g. to reuse an already-loaded
    model or share a ``Bucket``).
    """

    tokenizer: Any = None
    model: Any = None
    preprocess: Any = None
    bucket: Any = None

    # Un-annotated, so these are plain class attributes, not dataclass fields.
    _MODEL_NAME = "rinna/japanese-clip-vit-b-16"
    _CACHE_DIR = "/tmp/japanese_clip"
    _DEVICE = "cpu"

    def __post_init__(self):
        """Load whichever of tokenizer, model/preprocess and bucket were not supplied."""
        if self.tokenizer is None:
            self.tokenizer = ja_clip.load_tokenizer()
        if self.model is None or self.preprocess is None:
            # ja_clip.load returns the model and its image transform together.
            model, preprocess = ja_clip.load(
                self._MODEL_NAME, cache_dir=self._CACHE_DIR, device=self._DEVICE
            )
            if self.model is None:
                self.model = model
            if self.preprocess is None:
                self.preprocess = preprocess
        if self.bucket is None:
            self.bucket = Bucket()

    def save(self, image_path: str):
        """Embed the image at *image_path*, upload it, and store its embedding.

        Args:
            image_path: Path of an image file that Pillow can open.

        Returns:
            The ``inserted_id`` of the new document in the ``embedding``
            collection. The document holds the generated ``uuid`` (also used as
            the upload key) and the embedding as ``vectorField``.
        """
        # The context manager releases the file handle Image.open holds;
        # everything that touches the image stays inside it.
        with Image.open(image_path) as pillow_image:
            image = self.preprocess(pillow_image).unsqueeze(0).to(self._DEVICE)
            # Pure inference: no autograd graph needed.
            with torch.no_grad():
                image_features = self.model.get_image_features(image)
            image_uuid = str(uuid.uuid4())

            # media upload
            # NOTE(review): the blob is uploaded before the DB insert; if the
            # insert fails the uploaded object is left without an embedding.
            self.bucket.upload_file(pillow_image, image_uuid)

        # db insert
        db = get_db()
        result = db["embedding"].insert_one(
            {"uuid": image_uuid, "vectorField": image_features[0].tolist()}
        )
        return result.inserted_id

    def search(self, prompt: str, limit: int = 10, num_candidates: int = 150):
        """Return presigned URLs for the stored images closest to *prompt*.

        Args:
            prompt: Query text, tokenized to at most 77 tokens.
            limit: Maximum number of results (``$vectorSearch`` ``limit``).
            num_candidates: ``$vectorSearch`` ``numCandidates``. NOTE: per the
                Atlas docs this should be >= ``limit``.

        Returns:
            A list of presigned URLs, one per matched document, in the order
            the aggregation returned them.
        """
        db = get_db()
        encodings = ja_clip.tokenize(
            texts=[prompt],
            max_seq_len=77,
            device=self._DEVICE,
            tokenizer=self.tokenizer,
        )
        # Pure inference: no autograd graph needed.
        with torch.no_grad():
            text_features = self.model.get_text_features(**encodings)
        pipeline = [
            {
                "$vectorSearch": {
                    "index": "vector_index",
                    "path": "vectorField",
                    "queryVector": text_features[0].tolist(),
                    "numCandidates": num_candidates,
                    "limit": limit,
                }
            },
            {
                "$project": {
                    "_id": {"$toString": "$_id"},
                    "uuid": 1,
                    "score": {"$meta": "vectorSearchScore"},
                }
            },
        ]
        result = db["embedding"].aggregate(pipeline)
        return [self.bucket.get_presigned_url(doc["uuid"]) for doc in result]