"""Gradio demo that guesses where a street-view image was taken.

The image is embedded with DINOv2 (ViT-S/14), its nearest neighbours are looked
up among precomputed reference embeddings, and the geographic coordinates of
those neighbours are clustered to produce a single (longitude, latitude) guess.
"""
import functools
import io

import gradio as gr
import torch
import numpy as np
from PIL import Image
import torchvision.transforms as T
import pandas as pd
from sklearn.neighbors import NearestNeighbors
from sklearn.cluster import DBSCAN
from shapely.geometry import Point
import geopandas as gpd
from geopandas import GeoDataFrame


def _read_lines(path):
    """Return the content of *path* split on newlines, closing the file afterwards."""
    with open(path) as handle:
        return handle.read().split("\n")


model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
model.eval()

metadata = pd.read_csv("metadatav3.csv")
# Keep only the file name so rows can be matched against the entries of `files`.
metadata.path = metadata.path.apply(lambda x: x.split("/")[-1])

embeddings = np.load("embeddings.npy")
test_embeddings = np.load("test_embeddings.npy")
# files[i] is looked up with neighbour indices returned by the kNN index fitted
# on `embeddings`, so the two are expected to be aligned row by row.
files = _read_lines("files.txt")
test_files = _read_lines("test_files.txt")
print(embeddings.shape, test_embeddings.shape, len(files), len(test_files))

knn = NearestNeighbors(n_neighbors=50, algorithm='kd_tree', n_jobs=8)
knn.fit(embeddings)

# %%
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)

transform = T.Compose([
    T.Resize(256, interpolation=T.InterpolationMode.BICUBIC),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
])


def cluster(df, eps=0.1, min_samples=5, metric="cosine", n_jobs=8, show=False):
    """Return the median (longitude, latitude) of the most populated DBSCAN cluster.

    Args:
        df: DataFrame with ``longitude`` and ``latitude`` columns. It is not
            modified.
        eps: DBSCAN neighbourhood radius.
        min_samples: DBSCAN minimum number of samples per core point.
        metric: Distance metric handed to DBSCAN.
        n_jobs: Number of parallel jobs handed to DBSCAN.
        show: Unused; kept so existing callers keep working.

    Returns:
        A ``(longitude, latitude)`` tuple of floats. When DBSCAN's noise label
        (-1) is the most frequent label, the median of the noise points is
        returned.
    """
    points = df[["longitude", "latitude"]]

    # A single row cannot be clustered: it is the answer. Return it in the same
    # (longitude, latitude) shape as the general case so callers can always
    # build a Point from the result.
    if len(df) == 1:
        only = points.iloc[0]
        return float(only["longitude"]), float(only["latitude"])

    # NOTE(review): the default cosine metric compares (lon, lat) as directions
    # from the origin rather than as geographic distances — confirm intended.
    dbscan = DBSCAN(eps=eps, min_samples=min_samples, metric=metric, n_jobs=n_jobs)
    labels = dbscan.fit(points).labels_

    # value_counts() sorts by frequency, so the first index is the largest cluster.
    largest_label = pd.Series(labels).value_counts().index[0]
    # Positional boolean mask: robust to duplicate index values in `df`.
    center = points[labels == largest_label].median()
    return float(center["longitude"]), float(center["latitude"])


@functools.lru_cache(maxsize=1)
def _load_world():
    """Load the world basemap once and reuse it for every request."""
    # NOTE(review): `gpd.datasets` is deprecated/removed in recent geopandas
    # releases — confirm the pinned version or ship the shapefile locally.
    return gpd.read_file(gpd.datasets.get_path('naturalearth_lowres'))


def guess_image(img):
    """Guess the location of *img* and plot it with its nearest neighbours.

    Args:
        img: A PIL image; it is converted to RGB before being embedded.

    Returns:
        ``(coords, figure)`` where ``coords`` is the ``(longitude, latitude)``
        guess and ``figure`` is the matplotlib figure showing the neighbours in
        red and the guess in blue on a world map.

    Raises:
        ValueError: If none of the neighbour files has a row in ``metadata``.
    """
    img = img.convert('RGB')
    print(img)

    with torch.no_grad():
        features = model(transform(img).unsqueeze(0))[0].cpu()
    _distances, neighbors = knn.kneighbors(features.unsqueeze(0).numpy())
    neighbors = neighbors[0]

    # Collect the metadata rows of every neighbour, in neighbour order, with a
    # single concat instead of growing a DataFrame inside the loop.
    df = pd.concat([metadata[metadata.path == files[n]] for n in neighbors])
    if df.empty:
        raise ValueError("No metadata rows match the nearest-neighbour files.")

    coords = cluster(df, eps=0.005, min_samples=5)

    geometry = [Point(xy) for xy in zip(df['longitude'], df['latitude'])]
    gdf = GeoDataFrame(df, geometry=geometry)
    gdf_guess = GeoDataFrame(df[:1], geometry=[Point(coords)])

    plot_ = _load_world().plot(figsize=(10, 6))
    gdf.plot(ax=plot_, marker='o', color='red', markersize=15)
    gdf_guess.plot(ax=plot_, marker='o', color='blue', markersize=15)
    return coords, plot_.figure


def translate_image(input_image):
    """Gradio callback: turn an uploaded image into a location guess and a map.

    Args:
        input_image: Image as a numpy array, as delivered by the Gradio
            ``"image"`` input.

    Returns:
        ``(text, image)``: the guessed coordinates as a string and the rendered
        map as an RGB numpy array.

    Raises:
        gr.Error: If no image was provided.
    """
    if input_image is None:
        raise gr.Error("Please provide an image.")

    # Let PIL infer the mode from the array shape; guess_image converts to RGB,
    # so grayscale or RGBA inputs are handled as well.
    coords, fig = guess_image(Image.fromarray(input_image.astype('uint8')))

    # Render in memory rather than through a shared "tmp.png", which concurrent
    # requests would overwrite.
    # NOTE(review): the figure is never closed; a long-running server may want
    # matplotlib.pyplot.close(fig) here.
    buffer = io.BytesIO()
    fig.savefig(buffer, format="png")
    buffer.seek(0)
    with Image.open(buffer) as rendered:
        map_pixels = np.array(rendered.convert("RGB"))
    return str(coords), map_pixels


demo = gr.Interface(
    fn=translate_image,
    inputs="image",
    outputs=["text", "image"],
    title="Street View Location",
    description="Helps you guess the location of a street view image ! Use it on square images with no goole maps artefacts when possible !",
)

if __name__ == "__main__":
    demo.launch()