File size: 689 Bytes
1d3d5c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import matplotlib.pyplot as plt


# Show attention
def plot_attention(img, result, attention_plot, image_dir, att_shape=(14, 14)):
    """Overlay one attention map per result token on an image and save it.

    For every entry in ``result`` a subplot is created that shows ``img``
    with the matching attention map drawn semi-transparently on top, titled
    with that entry. The finished figure is written with ``savefig`` and
    then closed.

    Args:
        img: Image in any form accepted by ``Axes.imshow``.
        result: Sequence of titles (one per attention map); its length
            determines how many subplots are drawn.
        attention_plot: Indexable collection where ``attention_plot[i]`` is
            the attention vector for ``result[i]``. The first element of
            each vector is dropped and the remainder is reshaped to
            ``att_shape``, so each vector needs ``1 + rows * cols`` entries
            and must support slicing and ``.reshape``.
            NOTE(review): the dropped leading entry is presumably a
            class/start-token weight -- confirm against the caller.
        image_dir: Destination handed straight to ``savefig``. Despite the
            name it is used as the output file path, not a directory.
        att_shape: ``(rows, cols)`` of the 2-D attention map. Defaults to
            ``(14, 14)``, the size that was previously hard-coded.

    Returns:
        None. The only output is the saved figure.
    """
    token_count = len(result)

    # Near-square grid that always has room for every token. The previous
    # ``n // 2`` by ``n // 2`` grid had too few cells for 1, 2, 3 or 5 tokens
    # and made ``add_subplot`` raise.
    n_cols = int(token_count ** 0.5)
    if n_cols * n_cols < token_count:
        n_cols += 1
    n_rows = -(-token_count // n_cols) if n_cols else 0  # ceiling division

    fig = plt.figure(figsize=(15, 15))
    try:
        for idx, token in enumerate(result):
            # Skip the leading entry, then fold the rest into a 2-D map.
            att_map = attention_plot[idx][1:].reshape(*att_shape)

            ax = fig.add_subplot(n_rows, n_cols, idx + 1)
            ax.set_title(token, fontsize=18)

            # Draw the image first so the overlay can be stretched to the
            # exact same extent, whatever the map's resolution is.
            base_layer = ax.imshow(img)
            ax.imshow(
                att_map,
                alpha=0.6,
                cmap="jet",
                extent=base_layer.get_extent(),
            )

        fig.tight_layout()
        fig.savefig(image_dir)
    finally:
        # Release the figure so repeated calls do not accumulate open
        # figures inside pyplot's global state.
        plt.close(fig)