File size: 3,832 Bytes
0d7318b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4ccaaaa
 
 
0d7318b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import os
from pathlib import Path
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import torch
from torchvision import transforms
from model.model import GLPDepth
import open3d as o3d

# Run inference on GPU when available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Expected input image size in pixels; generate_mesh() assumes the DTM and
# the grayscale image are both H x W arrays of this size.
W, H = 512, 512


def load_model(path):
    """Build a GLPDepth model, load checkpoint weights from `path`,
    and return it in eval mode on the module-level `device`.

    NOTE(review): torch.load without weights_only=True unpickles arbitrary
    objects — only load checkpoints from a trusted source.
    """
    net = GLPDepth(max_depth=700.0).to(device)
    # Load the state dict on CPU; the module itself already lives on `device`.
    state = torch.load(path, map_location=torch.device('cpu'))
    net.load_state_dict(state)
    net.eval()
    return net


def generate_mesh(dtm, image, image_path):
    """Reconstruct a 3D surface mesh from a predicted DTM and save it as OBJ.

    Args:
        dtm: (H, W) array of terrain heights (used as the z coordinate).
        image: (H, W) grayscale intensities in [0, 1], used as vertex colors.
        image_path: Path of the source image; output is written to ./<stem>.obj.

    Returns:
        str path of the written .obj file.
    """
    # Build the (W*H, 3) point/color arrays with vectorized numpy ops instead
    # of a Python double loop. The original loop indexed rows as `i * H + j`,
    # which only worked because W == H; row-major raveling uses the correct
    # `i * W + j` ordering (x = column j, y = row i, z = dtm[i, j]).
    ys, xs = np.mgrid[0:H, 0:W]
    points = np.stack(
        (xs.ravel(), ys.ravel(), np.asarray(dtm, dtype='float32').ravel()),
        axis=1,
    ).astype('float32')
    # Replicate the grayscale intensity across the three RGB channels.
    gray = np.asarray(image, dtype='float32').ravel()
    colors = np.repeat(gray[:, None], 3, axis=1)
    # point cloud
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(points)
    pcd.colors = o3d.utility.Vector3dVector(colors)
    # normals (required by Poisson reconstruction)
    pcd.estimate_normals()
    pcd.orient_normals_to_align_with_direction()
    # surface reconstruction
    mesh = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(pcd, depth=7, n_threads=1)[0]
    # Flip around the x axis so the mesh is upright in the viewer.
    rotation = mesh.get_rotation_matrix_from_xyz((np.pi, 0, 0))
    mesh.rotate(rotation, center=(0, 0, 0))
    mesh.compute_vertex_normals()
    # remove weird artifacts
    mesh.remove_degenerate_triangles()
    mesh.remove_duplicated_triangles()
    mesh.remove_duplicated_vertices()
    mesh.remove_non_manifold_edges()
    # save mesh
    out_path = f'./{image_path.stem}.obj'
    o3d.io.write_triangle_mesh(out_path, mesh)
    return out_path


def predict(image_path):
    """Run depth estimation on one image and build its 3D mesh.

    Args:
        image_path: filesystem path (str) of the uploaded image,
            as provided by the gradio Image(type='filepath') input.

    Returns:
        [rgb_array, obj_path]: a (h, w, 3) uint8 rendering of the colormapped
        DTM, and the path of the generated .obj mesh.
    """
    image_path = Path(image_path)
    # Model works on single-channel input, hence convert('L').
    pil_image = Image.open(image_path).convert('L')
    # transform image to a (1, 1, H, W) float tensor on the inference device
    to_tensor = transforms.ToTensor()
    torch_image = to_tensor(pil_image).to(device).unsqueeze(0)
    # model predict (.detach() is unnecessary inside no_grad)
    with torch.no_grad():
        pred_dtm = model(torch_image)
    # transform torch to numpy
    pred_dtm = pred_dtm.squeeze().cpu().numpy()
    # Invert so larger values mean higher terrain rather than deeper.
    pred_dtm = pred_dtm.max() - pred_dtm
    # create 3d model from the DTM colored by the normalized input image
    image_scaled = np.asarray(pil_image) / 255.0
    obj_path = generate_mesh(pred_dtm, image_scaled, image_path)

    # Render the DTM with a colorbar and grab the canvas pixels.
    fig, ax = plt.subplots()
    im = ax.imshow(pred_dtm, cmap='jet', vmin=0, vmax=np.max(pred_dtm))
    plt.colorbar(im, ax=ax)

    fig.canvas.draw()
    # buffer_rgba() replaces tostring_rgb(), which was deprecated in
    # Matplotlib 3.8 and removed in 3.10; drop the alpha channel to keep RGB.
    data = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
    # Close the figure — the original leaked one figure per request.
    plt.close(fig)

    return [data, obj_path]


# Load the trained depth-estimation weights once at startup; `predict`
# reads this module-level model on every request.
model = load_model('best_model.pth')

title = 'Mars DTM Estimation'
description = 'Demo based on my <a href="https://medium.com/towards-data-science/monocular-depth-estimation-to-predict-surface-reliefs-of-mars' \
              '-1b50aed3361a">article</a>. It predicts a DTM from an image of the martian surface. Then, by using a surface reconstruction algorithm, ' \
              'the 3D model is generated and it can also be downloaded. Uploaded images must have a ' \
              '<i>1m/px</i> resolution to get predictions in the correct range. You can download the mesh by clicking ' \
              'the top-right button on the 3D viewer. '
# Every file in ./examples becomes a clickable example in the UI.
examples = [f'examples/{name}' for name in sorted(os.listdir('examples'))]

# Gradio UI: one image input -> colormapped DTM image + downloadable 3D mesh.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type='filepath', label='Input Image', sources=['upload', 'clipboard']),
    outputs=[
        gr.Image(label='DTM'),
        gr.Model3D(label='3D Model', clear_color=[0.0, 0.0, 0.0, 0.0])
    ],
    examples=examples,
    allow_flagging='never',
    cache_examples=False,
    title=title,
    description=description
)
iface.launch()