manu committed on
Commit
3df141a
1 Parent(s): 0348065

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +93 -0
app.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import numpy as np
4
+ from PIL import Image
5
+ import torchvision.transforms as T
6
+ import pandas as pd
7
+ from sklearn.neighbors import NearestNeighbors
8
+ from sklearn.cluster import DBSCAN
9
+ from shapely.geometry import Point
10
+ import geopandas as gpd
11
+ from geopandas import GeoDataFrame
12
+
13
# Run on the GPU when one is available; otherwise fall back to the CPU so the
# app can still start on CPU-only hosts.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# DINOv2 ViT-S/14 loaded through torch.hub and switched to inference mode.
model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14').to(DEVICE)
model.eval()

metadata = pd.read_csv("data/streetview_v3/metadatav3.csv")
# Keep only the last path component so rows can be compared with `files`.
metadata["path"] = metadata["path"].map(lambda p: p.split("/")[-1])

PATH = "data/streetview_v3/images/"
PATH_TEST = "data/test-competition/images/images/"

embeddings = np.load("data/embeddings.npy")
test_embeddings = np.load("data/test_embeddings.npy")

# Read the file lists with context managers so the handles are closed, and use
# splitlines() so a trailing newline does not add a phantom empty entry.
with open("data/files.txt", encoding="utf-8") as fh:
    files = fh.read().splitlines()
with open("data/test_files.txt", encoding="utf-8") as fh:
    test_files = fh.read().splitlines()
print(embeddings.shape, test_embeddings.shape, len(files), len(test_files))

# Nearest-neighbour index over the stored embeddings; row i of `embeddings`
# is looked up as files[i] by the code that queries this index.
knn = NearestNeighbors(n_neighbors=50, algorithm='kd_tree', n_jobs=8)
knn.fit(embeddings)
30
+
31
+ # %%
32
# Per-channel mean / standard deviation passed to T.Normalize below.
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)

# Preprocessing applied to every query image before it is fed to the model:
# bicubic resize to 256, centre crop to 224, tensor conversion, normalisation.
transform = T.Compose(
    [
        T.Resize(256, interpolation=T.InterpolationMode.BICUBIC),
        T.CenterCrop(224),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    ]
)
41
+
42
+
43
def cluster(df, eps=0.1, min_samples=5, metric="cosine", n_jobs=8, show=False):
    """Reduce candidate locations to one representative coordinate.

    Runs DBSCAN over the ``longitude``/``latitude`` columns of ``df`` and
    returns the per-column median of the largest cluster.

    Args:
        df: DataFrame with ``longitude`` and ``latitude`` columns. It is not
            modified.
        eps, min_samples, metric, n_jobs: Forwarded unchanged to ``DBSCAN``.
        show: Unused; kept so existing callers keep working.

    Returns:
        ``(longitude, latitude)`` as a tuple of plain Python floats.

    Raises:
        ValueError: If ``df`` has no rows.

    NOTE(review): the default ``metric="cosine"`` compares the direction of
    the (longitude, latitude) vectors rather than a geographic distance;
    confirm that this is intended.
    """
    if len(df) == 0:
        raise ValueError("cluster() needs at least one candidate location")

    points = df[["longitude", "latitude"]]
    if len(points) == 1:
        # Nothing to cluster: return the only point, in the same tuple shape
        # as the general case so callers can always unpack the result.
        only = points.iloc[0]
        return float(only["longitude"]), float(only["latitude"])

    dbscan = DBSCAN(eps=eps, min_samples=min_samples, metric=metric, n_jobs=n_jobs)
    labels = dbscan.fit(points).labels_

    # DBSCAN labels noise as -1; noise must not win as the "largest cluster".
    clustered = labels[labels != -1]
    if clustered.size:
        ids, counts = np.unique(clustered, return_counts=True)
        members = points[labels == ids[counts.argmax()]]
    else:
        # No dense cluster was found: fall back to the median of all points.
        members = points

    centre = members.median()
    return float(centre["longitude"]), float(centre["latitude"])
54
+
55
+
56
def guess_image(img):
    """Estimate where a photo was taken and draw the evidence on a world map.

    The image is embedded with ``model``, its nearest stored embeddings are
    looked up through ``knn``, the matching ``metadata`` rows are collected,
    and ``cluster`` reduces their coordinates to a single guess.

    Args:
        img: A PIL image; it is converted to RGB before preprocessing.

    Returns:
        ``(coords, figure)``: the value returned by ``cluster`` and the
        matplotlib figure holding the world map, with the neighbours plotted
        in red and the guess in blue.

    Raises:
        ValueError: If none of the neighbour file names appear in ``metadata``.
    """
    img = img.convert('RGB')
    print(img)

    # Follow whatever device the model lives on instead of assuming CUDA.
    device = next(model.parameters()).device
    with torch.no_grad():
        batch = transform(img).unsqueeze(0).to(device)
        features = model(batch)[0].cpu()
    _, neighbors = knn.kneighbors(features.unsqueeze(0).numpy())

    # Collect the metadata rows in neighbour order and concatenate once,
    # rather than growing a DataFrame inside the loop.
    matches = [metadata[metadata.path == files[n]] for n in neighbors[0]]
    df = pd.concat(matches)
    if df.empty:
        raise ValueError("no metadata rows matched the nearest-neighbour files")
    coords = cluster(df, eps=0.005, min_samples=5)

    geometry = [Point(xy) for xy in zip(df['longitude'], df['latitude'])]
    gdf = GeoDataFrame(df, geometry=geometry)
    # The first neighbour row only carries the attribute columns for the
    # single-point frame that marks the guess.
    gdf_guess = GeoDataFrame(df[:1], geometry=[Point(coords)])

    # NOTE(review): `gpd.datasets` was deprecated in geopandas 0.14 and is
    # believed to be removed in 1.0 - confirm against the pinned version.
    world = gpd.read_file(gpd.datasets.get_path('naturalearth_lowres'))
    plot_ = world.plot(figsize=(10, 6))
    gdf.plot(ax=plot_, marker='o', color='red', markersize=15)
    gdf_guess.plot(ax=plot_, marker='o', color='blue', markersize=15)
    return coords, plot_.figure
81
+
82
+
83
# Gradio callback: photo in, predicted coordinates and map preview out.
def translate_image(input_image):
    """Run the location guess for an uploaded image.

    Args:
        input_image: Image as a numpy array, as delivered by the Gradio
            ``"image"`` input.

    Returns:
        ``(text, image)``: ``str()`` of the coordinates returned by
        ``guess_image`` and the rendered map as an RGB numpy array.

    Raises:
        ValueError: If no image was supplied.
    """
    import io  # local import: only this callback needs an in-memory buffer

    if input_image is None:
        raise ValueError("no image provided")

    # Let Pillow infer the mode from the array shape; guess_image converts to
    # RGB itself, so grayscale or RGBA uploads are handled as well.
    coords, fig = guess_image(Image.fromarray(input_image.astype('uint8')))

    # Render into memory rather than a shared "tmp.png", so concurrent
    # requests cannot overwrite each other's output.
    # NOTE(review): the figure is never closed here; if it is pyplot-managed,
    # consider releasing it after rendering.
    buffer = io.BytesIO()
    fig.savefig(buffer, format="png")
    buffer.seek(0)
    return str(coords), np.array(Image.open(buffer).convert("RGB"))
88
+
89
+
90
# Gradio UI: one image input, wired to translate_image, with a text output
# and an image output.
demo = gr.Interface(
    fn=translate_image,
    inputs="image",
    outputs=["text", "image"],
    title="Street View Location",
)

# Launch the interface only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()