Spaces:
Runtime error
Runtime error
import gradio as gr | |
from transformers import CLIPTokenizerFast, CLIPProcessor, CLIPModel | |
import torch | |
import io | |
from PIL import Image | |
import os | |
from cryptography.fernet import Fernet | |
from google.cloud import storage | |
import pinecone | |
import json | |
# --- Cloud Storage -----------------------------------------------------------
# The service-account credentials ship encrypted alongside the app; the Fernet
# key used to decrypt them is read from the DECRYPTION_KEY environment variable.
fernet = Fernet(os.environ['DECRYPTION_KEY'])
with open('cloud-storage.encrypted', 'rb') as fp:
    encrypted = fp.read()
creds = json.loads(fernet.decrypt(encrypted).decode())

# Write the decrypted credentials out as a JSON key file and point the Google
# client library at it through GOOGLE_APPLICATION_CREDENTIALS.
# NOTE(review): this leaves plaintext credentials in the working directory;
# storage.Client.from_service_account_info(creds) would avoid the file.
with open('cloud-storage.json', 'w', encoding='utf-8') as fp:
    json.dump(creds, fp, indent=4)
os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = 'cloud-storage.json'

# Bucket holding the generated images (fetched as images/<id>.png below).
storage_client = storage.Client()
bucket = storage_client.get_bucket('diffusion-search')
# --- Pinecone ----------------------------------------------------------------
# Authenticate with the API key from the environment, then open a handle to the
# vector index. Fail fast at startup if the index does not exist.
PINECONE_KEY = os.environ['PINECONE_KEY']
index_id = "diffusion-search"

pinecone.init(api_key=PINECONE_KEY, environment="us-west1-gcp")

if index_id not in pinecone.list_indexes():
    raise ValueError(f"Index '{index_id}' not found")
index = pinecone.Index(index_id)
# --- Model -------------------------------------------------------------------
# Choose the compute device once: CUDA when available, otherwise CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using '{device}' device...")

# CLIP tokenizer + model, used to embed text search prompts. No image processor
# is loaded; only text features are computed in this app.
model_id = "openai/clip-vit-base-patch32"
tokenizer = CLIPTokenizerFast.from_pretrained(model_id)
model = CLIPModel.from_pretrained(model_id).to(device)

# Placeholder image returned in place of stored images that fail to decode.
missing_im = Image.open('missing.png')

# NOTE(review): not referenced by the code visible here; presumably meant as a
# similarity-score cutoff for results — confirm its purpose or remove it.
threshold = 0.85
def encode_text(text: str):
    """Embed a search prompt with the CLIP text encoder.

    The prompt is tokenized and truncated to the tokenizer's maximum length,
    so overly long prompts are shortened rather than making the model fail.
    Inference runs under ``torch.no_grad()`` because no gradients are needed.

    Args:
        text: The search prompt.

    Returns:
        A nested Python list holding a single embedding (a batch of one, since
        one string is tokenized), ready to pass to ``index.query``.
    """
    inputs = tokenizer(text, truncation=True, return_tensors="pt").to(device)
    with torch.no_grad():
        text_emb = model.get_text_features(**inputs)
    return text_emb.cpu().tolist()
def prompt_query(text: str):
    """Return the ids of the 30 vectors nearest to *text* in the Pinecone index.

    If the query fails, the Pinecone connection is re-initialised once and the
    query retried. The fresh handle replaces the module-level ``index`` so that
    later calls (and other users of ``index``) do not keep hitting the stale
    connection.

    Args:
        text: The search prompt.

    Returns:
        List of match ids, in the order Pinecone returned them.

    Raises:
        ValueError: if the retried query also fails (the cause is chained).
    """
    global index
    print(f"Running prompt_query('{text}')")
    embeds = encode_text(text)
    try:
        xc = index.query(embeds, top_k=30, include_metadata=True)
    except Exception as e:
        print(f"Error during query: {e}")
        # Reinitialize the connection and keep the new handle for later calls.
        pinecone.init(api_key=PINECONE_KEY, environment='us-west1-gcp')
        index = pinecone.Index(index_id)
        try:
            xc = index.query(embeds, top_k=30, include_metadata=True)
        except Exception as retry_err:
            raise ValueError(retry_err) from retry_err
        print("Reinitialized query successful")
    return [match['id'] for match in xc['matches']]
def get_image(url: str):
    """Download a blob from ``bucket`` and open it as a PIL image.

    Args:
        url: Blob name inside the bucket (e.g. ``images/<id>.png``), not an
            HTTP URL.

    Returns:
        The opened PIL image, backed by an in-memory buffer.
    """
    # download_as_bytes replaces the deprecated download_as_string.
    data = bucket.blob(url).download_as_bytes()
    return Image.open(io.BytesIO(data))
def test_image(_id, image):
    """Check that *image* can be fully encoded; purge it from storage if not.

    Pillow's ``Image.open`` is lazy, so re-encoding the image forces a full
    decode and surfaces corrupt or truncated data as ``OSError``. The encode
    goes to an in-memory buffer rather than a shared ``tmp.png`` on disk, so
    requests handled concurrently (if any) never share a file.

    Args:
        _id: Id of the image's vector in the Pinecone index; also names the
            blob ``images/{_id}.png``.
        image: PIL image to check.

    Returns:
        True if the image encodes cleanly. Otherwise deletes the vector from the
        index and the blob from the bucket, then returns False.
    """
    try:
        # An explicit format is required because a buffer has no file extension.
        image.save(io.BytesIO(), format='PNG')
    except OSError:
        # Delete the corrupted entry from Pinecone and Cloud Storage.
        index.delete(ids=[_id])
        bucket.blob(f"images/{_id}.png").delete()
        return False
    return True
def prompt_image(text: str):
    """Search for images matching *text* and download them from Cloud Storage.

    Queries the Pinecone index for the 9 nearest matches whose ``image_nsfw``
    metadata is below 0.5. For each match it fetches ``images/{id}.png``.
    Images that fail ``test_image`` are replaced by the ``missing_im``
    placeholder (``test_image`` also purges them from storage).

    If the query fails, the Pinecone connection is re-initialised once and the
    query retried. The fresh handle replaces the module-level ``index``.

    Args:
        text: The search prompt.

    Returns:
        ``(images, scores)`` as parallel lists. A match whose fetch raises
        ``ValueError`` is dropped from both lists, so they stay aligned.

    Raises:
        ValueError: if the retried query also fails (the cause is chained).
    """
    global index
    embeds = encode_text(text)
    query_kwargs = dict(
        top_k=9, include_metadata=True,
        filter={"image_nsfw": {"$lt": 0.5}},
    )
    try:
        xc = index.query(embeds, **query_kwargs)
    except Exception as e:
        print(f"Error during query: {e}")
        # Reinitialize the connection and keep the new handle for later calls.
        pinecone.init(api_key=PINECONE_KEY, environment='us-west1-gcp')
        index = pinecone.Index(index_id)
        try:
            xc = index.query(embeds, **query_kwargs)
        except Exception as retry_err:
            raise ValueError(retry_err) from retry_err
        print("Reinitialized query successful")

    images, scores = [], []
    for match in xc['matches']:
        _id = match['id']
        image_url = f"images/{_id}.png"
        try:
            im = get_image(image_url)
            if not test_image(_id, im):
                im = missing_im
        except ValueError:
            print(f"ValueError: '{image_url}'")
            continue  # skip both image and score so the lists stay paired
        images.append(im)
        scores.append(match['score'])
    return images, scores
# __APP FUNCTIONS__ | |
def set_suggestion(text: str):
    """Gradio callback: put ``text[0]`` into the prompt text area.

    NOTE(review): despite the ``str`` annotation this indexes ``text[0]``; for
    a plain string that is only the first character. Presumably the caller
    passes a one-element sample list (e.g. from a ``gr.Dataset``) — confirm.
    This callback is not wired to any event in the code visible here.
    """
    suggestion = text[0]
    return gr.TextArea.update(value=suggestion)
def set_images(text: str):
    """Gradio callback: search for *text* and show the results in the gallery.

    Only the images from ``prompt_image`` are displayed; the scores are
    discarded.
    """
    images, _scores = prompt_image(text)
    return gr.Gallery.update(value=images)
# __CREATE APP__
# Layout: a banner at the top, then a row with the prompt/search controls on
# the left and the results gallery on the right.
demo = gr.Blocks()
with demo:
    gr.HTML(
        """
    <img src="https://huggingface.co/spaces/pinecone/diffusion-image-search/resolve/main/pine-trees-collage.png" />
    <style>
    .parallax {
    /* The image used */
    background-image: url("https://huggingface.co/spaces/pinecone/diffusion-image-search/resolve/main/pine-trees-collage.png");
    /* Create the parallax scrolling effect */
    background-attachment: fixed;
    background-position: center;
    background-repeat: no-repeat;
    background-size: cover;
    }
    </style>
    <!-- Container element -->
    <div class="parallax"></div>
    """
    )
    with gr.Row():
        # search controls column
        with gr.Column():
            prompt = gr.TextArea(
                value="space dogs",
                placeholder="Something cool to search for...",
                interactive=True
            )
            search = gr.Button(value="Search!")
            gr.Markdown(
                """
                #### Search through 10K images generated by AI
                This app demonstrates the idea of text-to-image search. The search process
                uses an AI model that understands the *meaning* of text and images to identify
                images that best align to a search prompt.
                [*Built with the OP Stack*](https://gkogan.notion.site/gkogan/The-OP-Stack-aafcab0005e3445a8ad8491aac80446c)
                """
            )
        # results column
        with gr.Column():
            pics = gr.Gallery()
            # NOTE(review): Component.style() is Gradio 3.x API (removed in 4.x).
            pics.style(grid=3)
    # Search event listening. Registering the listener does no I/O, so it is
    # not wrapped in a try/except; errors raised inside set_images surface at
    # request time in Gradio, not here.
    search.click(set_images, prompt, pics)
demo.launch()