# geoguessr-bot / app.py
# manu's picture
# Update app.py
# 18d470c
import io

import gradio as gr
import torch
import numpy as np
from PIL import Image
import torchvision.transforms as T
import pandas as pd
from sklearn.neighbors import NearestNeighbors
from sklearn.cluster import DBSCAN
from shapely.geometry import Point
import geopandas as gpd
from geopandas import GeoDataFrame
# Load the `dinov2_vits14` model from torch.hub and switch it to eval mode;
# it is used below to embed query images.
model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
model.eval()

# Per-image metadata. Only the last path component is kept so that rows can
# be matched against the entries of files.txt.
metadata = pd.read_csv("metadatav3.csv")
metadata.path = metadata.path.apply(lambda x: x.split("/")[-1])

# Stored embeddings and the file lists that accompany them. Neighbour indices
# returned by the kNN index below are used to index `files`.
embeddings = np.load("embeddings.npy")
test_embeddings = np.load("test_embeddings.npy")

# Read the lists inside `with` blocks so the file handles are closed instead
# of being left to the garbage collector.
with open("files.txt") as fh:
    files = fh.read().split("\n")
with open("test_files.txt") as fh:
    test_files = fh.read().split("\n")
print(embeddings.shape, test_embeddings.shape, len(files), len(test_files))

# Nearest-neighbour index over the stored embeddings (50 neighbours per query).
knn = NearestNeighbors(n_neighbors=50, algorithm='kd_tree', n_jobs=8)
knn.fit(embeddings)
# %%
# Per-channel mean / standard deviation handed to the normalisation step.
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)

# Preprocessing applied to every query image, in order: bicubic resize to 256,
# centre crop to 224, conversion to a tensor, channel-wise normalisation.
_normalize = T.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD)
transform = T.Compose(
    [
        T.Resize(256, interpolation=T.InterpolationMode.BICUBIC),
        T.CenterCrop(224),
        T.ToTensor(),
        _normalize,
    ]
)
def cluster(df, eps=0.1, min_samples=5, metric="cosine", n_jobs=8, show=False):
    """Estimate a single (longitude, latitude) pair from candidate points.

    The points are grouped with DBSCAN and the per-column median of the most
    populated cluster is returned. Points that DBSCAN labels as noise (-1)
    are not treated as a cluster; when every point is noise, the median of
    all points is used instead.

    Args:
        df: DataFrame with ``longitude`` and ``latitude`` columns. It is
            left unmodified.
        eps: DBSCAN neighbourhood radius.
        min_samples: DBSCAN ``min_samples``.
        metric: Distance metric handed to DBSCAN.
        n_jobs: Number of parallel jobs handed to DBSCAN.
        show: Unused; kept so existing callers keep working.

    Returns:
        A ``(longitude, latitude)`` tuple.

    Raises:
        ValueError: If ``df`` has no rows.
    """
    if len(df) == 0:
        raise ValueError("cluster() needs at least one point")

    coords = df[["longitude", "latitude"]]

    # A lone point is its own best guess. Return the same tuple shape as the
    # general case so callers can always unpack / build a Point from it.
    if len(df) == 1:
        return coords["longitude"].iloc[0], coords["latitude"].iloc[0]

    # NOTE(review): the default "cosine" metric compares the direction of the
    # (longitude, latitude) vectors rather than a geographic distance --
    # confirm this is intended before changing it.
    dbscan = DBSCAN(eps=eps, min_samples=min_samples, metric=metric, n_jobs=n_jobs)
    labels = dbscan.fit(coords).labels_

    # Labels are kept in a local array (not written back into `df`) so the
    # caller's frame is not mutated.
    clustered = labels[labels != -1]
    if clustered.size:
        values, counts = np.unique(clustered, return_counts=True)
        members = coords[labels == values[counts.argmax()]]
    else:
        # No dense cluster was found: fall back to every point.
        members = coords

    center = members.median()
    return center["longitude"], center["latitude"]
def guess_image(img):
    """Guess where an image was taken and draw the guess on a world map.

    The image is embedded with ``model``, its nearest stored embeddings are
    looked up with ``knn``, the metadata rows of those neighbours are
    collected, and ``cluster`` reduces their coordinates to a single guess.

    Args:
        img: A PIL image; it is converted to RGB before being embedded.

    Returns:
        ``(coords, figure)`` where ``coords`` is the value returned by
        ``cluster`` and ``figure`` is the figure holding the world map with
        the neighbours in red and the guess in blue.

    Raises:
        ValueError: If none of the neighbour files has a metadata row.
    """
    img = img.convert('RGB')

    # Inference only: no gradients are needed for the embedding.
    with torch.no_grad():
        features = model(transform(img).unsqueeze(0))[0].cpu()

    _, neighbors = knn.kneighbors(features.unsqueeze(0))

    # Collect the metadata rows of every neighbour, in neighbour order, and
    # concatenate once instead of growing a DataFrame inside the loop.
    matches = [metadata[metadata.path == files[n]] for n in neighbors[0]]
    df = pd.concat(matches)
    if df.empty:
        raise ValueError("no metadata rows match the nearest-neighbour files")

    coords = cluster(df, eps=0.005, min_samples=5)

    geometry = [Point(xy) for xy in zip(df['longitude'], df['latitude'])]
    gdf = GeoDataFrame(df, geometry=geometry)
    gdf_guess = GeoDataFrame(df[:1], geometry=[Point(coords)])

    # this is a simple map that goes with geopandas
    # NOTE(review): `gpd.datasets` is deprecated in recent geopandas releases
    # and removed in 1.0 -- confirm the pinned geopandas version still has it.
    world = gpd.read_file(gpd.datasets.get_path('naturalearth_lowres'))
    plot_ = world.plot(figsize=(10, 6))
    gdf.plot(ax=plot_, marker='o', color='red', markersize=15)
    gdf_guess.plot(ax=plot_, marker='o', color='blue', markersize=15)
    return coords, plot_.figure
# Gradio callback: image array in, (coordinates text, map image array) out.
def translate_image(input_image):
    """Run ``guess_image`` on an uploaded image and return displayable outputs.

    Args:
        input_image: Image as a numpy array; it is cast to ``uint8`` and
            interpreted as RGB.

    Returns:
        ``(text, image)`` where ``text`` is ``str()`` of the guessed
        coordinates and ``image`` is the rendered map as an RGB numpy array.
    """
    coords, fig = guess_image(Image.fromarray(input_image.astype('uint8'), 'RGB'))

    # Render into memory rather than a fixed "tmp.png" so that concurrent
    # requests cannot overwrite each other's file.
    buffer = io.BytesIO()
    fig.savefig(buffer, format="png")
    buffer.seek(0)
    # NOTE(review): the figure is never closed here; in a long-running server
    # it may need an explicit close -- confirm.
    return str(coords), np.array(Image.open(buffer).convert("RGB"))
# Gradio UI: one image input; the outputs are the guessed coordinates as text
# and the rendered map as an image.
demo = gr.Interface(
    fn=translate_image,
    inputs="image",
    outputs=["text", "image"],
    title="Street View Location",
    description=(
        "Helps you guess the location of a street view image ! Use it on square images with no google maps artefacts when possible !"
    ),
)

if __name__ == "__main__":
    demo.launch()