"""Gradio demo: detects AI-generated images with an ensemble of CLIP-based
classifiers, using a user-supplied or BLIP-generated caption as the text input."""

import os

# Install OpenAI's CLIP at startup (a common pattern for Hugging Face Spaces).
os.system('pip install git+https://github.com/openai/CLIP.git')

import datetime

import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms
from transformers import BlipProcessor, BlipForConditionalGeneration
# Heavier alternative, used only by the commented-out lines in get_blip_model():
from transformers import Blip2Processor, Blip2ForConditionalGeneration

import clip

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class OpenAIClip(nn.Module):
    """CLIP RN50 backbone with an optional text branch and a classification head."""

    def __init__(self, arch="resnet50", modality="image"):
        super().__init__()
        self.model = None
        self.modality = modality
        if arch == "resnet50":
            # Load fp32 weights on CPU; the full model is moved to `device` later.
            self.model, _ = clip.load("RN50", device="cpu")
        if self.modality == "image":
            # Fine-tune only the visual encoder; freeze the text tower.
            for name, param in self.model.named_parameters():
                if "visual" in name:
                    param.requires_grad = True
                else:
                    param.requires_grad = False
        self.fc = nn.Identity()

    def forward(self, image, text=None):
        image_features = self.model.encode_image(image)
        if self.modality != "image+text":
            return self.fc(image_features)
        text = clip.tokenize(text, truncate=True).to(device)
        text_features = self.model.encode_text(text)
        combined_features = torch.cat((image_features, text_features), dim=1)
        return self.fc(combined_features)


def preprocessing(img, size):
    """Resize to (size, size), convert to a tensor, and apply ImageNet normalization."""
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    data_transforms = transforms.Compose([
        transforms.Resize((size, size)),
        transforms.ToTensor(),
        normalize,
    ])
    return data_transforms(img)


def get_model(model_path, modality):
    """Build a classifier head on top of CLIP and load fine-tuned weights."""
    if modality == "Image":
        model = OpenAIClip(arch="resnet50", modality="image")
        dim_mlp = 1024   # RN50 image embedding size
        fc_units = [512]
    elif modality == "Image+Text":
        model = OpenAIClip(arch="resnet50", modality="image+text")
        dim_mlp = 2048   # concatenated image (1024) + text (1024) embeddings
        fc_units = [1024]
    model.fc = nn.Sequential(nn.Linear(dim_mlp, fc_units[0]),
                             nn.ReLU(),
                             nn.Linear(fc_units[0], 1),
                             nn.Sigmoid())
    checkpoint_dict = torch.load(model_path, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint_dict['state_dict'])
    model.to(device)
    model.eval()
    return model


def get_blip_model():
    processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
    model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")
    # processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
    # model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b")
    model.to(device)  # keep the model on the same device as its inputs
    model.eval()
    return processor, model


def get_caption(image):
    # Note: the BLIP model is reloaded on every call; caching it at module
    # level would make repeated requests much faster.
    processor, model = get_blip_model()
    inputs = processor(images=image, return_tensors="pt")
    outputs = model.generate(**inputs.to(device))
    return processor.decode(outputs[0], skip_special_tokens=True)


def predict(img, caption):
    now = datetime.datetime.now()
    print(now)
    if img is not None:
        print(caption)
        if caption is None or caption == "":
            caption = get_caption(img)
            print("Generated caption -->", caption)
        else:
            print("User input caption -->", caption)
        # Log the submitted image under a timestamped filename.
        img.save("models/" + str(now) + ".png")
        prediction = []
        # One detector checkpoint per generator (sd = Stable Diffusion,
        # glide = GLIDE, ld = latent diffusion).
        models_list = ['models/clip-sd.pth', 'models/clip-glide.pth', 'models/clip-ld.pth']
        modality = "Image+Text"
        for model_path in models_list:
            model = get_model(model_path, modality)
            tensor = preprocessing(img, 224)
            input_tensor = tensor.view(1, 3, 224, 224).to(device)
            with torch.no_grad():
                out = model(input_tensor, caption)
            print(model_path, ' ----> ', out)
            prediction.append(out.item())
        # Count the number of predictions that are greater than or equal to 0.5.
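        # Majority vote: each head ends in a Sigmoid, so `out` is a score in
        # [0, 1]; a score >= 0.5 is treated as a "fake" vote, and the final
        # label follows the majority of the three detectors.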
        count_ones = sum(1 for p in prediction if p >= 0.5)
        if count_ones > len(prediction) / 2:
            return "Fake Image"
        else:
            return "Real Image"
    else:
        print("Alert: Input image missing")
        return "Alert: Input image missing"


# Create the Gradio interface.
image_input = gr.Image(type="pil", label="Input Image")
text_input = gr.Textbox(label="Caption for image (Optional)")
iface = gr.Interface(
    fn=predict,
    inputs=[image_input, text_input],
    outputs=gr.Label(),
    examples=[
        ["examples/trump-fake.jpeg", "Donald Trump being arrested by authorities."],
        ["examples/astronaut_space.png", "An astronaut playing basketball with a cat in space, digital art"],
    ],
)
iface.launch()
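# Running this script (e.g. `python app.py`) starts a local Gradio server and
# prints its URL; passing share=True to launch() would also create a
# temporary public link.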