# VQA-datamining: src/image_visualization.py
# Author: truong-xuan-linh
# Commit 1d3d5c8: "update visualize"
import math

import matplotlib.pyplot as plt
def plot_attention(img, result, attention_plot, image_dir):
    """Overlay per-label attention maps on an image and save the figure.

    One subplot is drawn per entry of ``result``. Each subplot shows ``img``
    with the matching attention map blended on top (``jet`` colormap,
    alpha 0.6, stretched over the image extent) and uses the entry as title.
    The figure is closed after saving so repeated calls do not accumulate
    open pyplot figures.

    Args:
        img: Image accepted by ``Axes.imshow`` (e.g. an HxW[xC] array or a
            PIL image).
        result: Sequence of labels (e.g. decoded answer tokens); one subplot
            is drawn per label.
        attention_plot: Indexable where ``attention_plot[i]`` is the attention
            vector for ``result[i]``. Each vector must support slicing and
            ``.reshape`` (NumPy array / tensor). Its first element is dropped
            and the remaining 196 values are reshaped to a 14x14 map.
        image_dir: Despite the name, the output *file* path (or file-like
            object) handed to ``Figure.savefig``.
    """
    n = len(result)
    # Near-square grid that always has at least ``n`` cells. The former
    # ``n // 2`` x ``n // 2`` grid had too few cells for n in {1, 2, 3, 5}
    # (a 0x0 grid for a single label), so ``add_subplot`` raised.
    n_cols = max(1, math.ceil(math.sqrt(n)))
    n_rows = max(1, math.ceil(n / n_cols))

    fig = plt.figure(figsize=(15, 15))
    try:
        for i, label in enumerate(result):
            # Drop the first element (presumably a CLS/global token -- TODO
            # confirm against the model) and view the rest as a 14x14 grid.
            att = attention_plot[i][1:].reshape(14, 14)
            ax = fig.add_subplot(n_rows, n_cols, i + 1)
            ax.set_title(label, fontsize=18)
            base = ax.imshow(img)
            # Stretch the low-resolution map over the full image extent.
            ax.imshow(att, alpha=0.6, cmap="jet", extent=base.get_extent())
        fig.tight_layout()
        fig.savefig(image_dir)
    finally:
        # pyplot keeps every figure alive until closed; release it even if
        # plotting or saving failed.
        plt.close(fig)