File size: 2,167 Bytes
1cc0005
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import csv
import matplotlib.pyplot as plt
import numpy as np

def _decode_train_csv(csv_path):
    """Read a training-metrics CSV file into NumPy arrays.

    The file must start with a header row that contains the columns
    ``step``, ``train_loss``, ``val_loss`` and ``dice_score``; any other
    columns are ignored.

    Args:
        csv_path: Path of the CSV file to read.

    Returns:
        A 4-tuple ``(steps, train_loss, val_loss, dice)`` in file order:
        ``steps`` as an ``np.uint`` array and the three metric columns as
        ``np.float32`` arrays.

    Raises:
        KeyError: if one of the required columns is missing from the header.
        ValueError: if a cell is not a valid number for its dtype
            (for example an empty string).
    """
    columns = ('step', 'train_loss', 'val_loss', 'dice_score')
    values = {name: [] for name in columns}

    # newline='' lets the csv module handle line endings (and quoted embedded
    # newlines) itself, as the csv documentation requires.
    with open(csv_path, newline='') as csv_file:
        for row in csv.DictReader(csv_file):
            for name in columns:
                values[name].append(row[name])

    # NumPy parses the string cells directly into the target dtypes.
    return (np.array(values['step'], dtype=np.uint),
            np.array(values['train_loss'], dtype=np.float32),
            np.array(values['val_loss'], dtype=np.float32),
            np.array(values['dice_score'], dtype=np.float32))

def plot_train_data(csv_path, store=None, show=True, steps_in_epoch=-1):
    """Plot training loss, validation loss and Dice score against train step.

    All three curves are drawn onto the current matplotlib figure.

    Args:
        csv_path: Metrics CSV file in the format read by ``_decode_train_csv``.
        store: If truthy, the figure is saved to this path via ``plt.savefig``.
        show: If true, ``plt.show()`` is called after drawing.
        steps_in_epoch: If positive, short vertical marks are drawn at every
            step from 0 up to and including the last logged step that is a
            multiple of this value, to indicate epoch boundaries.
    """
    steps, train_loss, val_loss, dice = _decode_train_csv(csv_path)

    plt.plot(steps, train_loss, label='Training Loss')
    plt.plot(steps, val_loss, label='Validation Loss')
    plt.plot(steps, dice, label='Dice Score')

    # Skip the epoch marks for an empty log: there is no last step to read.
    if steps_in_epoch > 0 and steps.size > 0:
        last_step = int(steps[-1])
        # +1 so a boundary that falls exactly on the last logged step is drawn.
        boundaries = [x for x in range(0, last_step + 1)
                      if x % steps_in_epoch == 0]
        # The marks span -0.2..-0.05; with the y-limit of -0.1 set below only
        # the -0.1..-0.05 part is visible, i.e. short ticks at the bottom.
        plt.vlines(boundaries, ymin=-0.2, ymax=-0.05)

    plt.ylim(-0.1, 1.1)
    # The axis carries both losses and the Dice score, not only training loss.
    plt.ylabel('Loss / Dice Score')
    plt.xlabel('Train Step')
    plt.legend(loc="upper left")
    if store:
        plt.savefig(store)
    if show:
        plt.show()

def plot_multiple_val_losses(names, csvs, xlim=(0, 7000)):
    """Overlay the validation-loss curves of several runs and show the plot.

    Args:
        names: Legend label for each run.
        csvs: Metrics CSV path for each run, in the same order as ``names``.
        xlim: ``(left, right)`` x-axis limits in train steps, passed to
            ``plt.xlim``. ``None`` leaves the x-axis autoscaled.

    Raises:
        ValueError: if ``names`` and ``csvs`` have different lengths. Without
            this check, ``zip`` would silently drop the unmatched runs.
    """
    # Materialise the inputs so generators are accepted and len() works.
    names = list(names)
    csvs = list(csvs)
    if len(names) != len(csvs):
        raise ValueError(
            f"got {len(names)} names but {len(csvs)} CSV paths")

    # The loop variable is csv_path, not csv, to avoid shadowing the csv module.
    for name, csv_path in zip(names, csvs):
        steps, _, val_loss, _ = _decode_train_csv(csv_path)
        plt.plot(steps, val_loss, label=name)

    plt.ylim(-0.1, 1.1)
    if xlim is not None:
        plt.xlim(*xlim)
    plt.ylabel('Validation Loss')
    plt.xlabel('Train Step')
    plt.legend(loc="upper left")
    plt.show()

if __name__ == "__main__":
    # Compare the validation-loss curves of several model configurations.
    # Each entry pairs a legend label with the metrics CSV of that run.
    # To plot all curves of a single run instead, call
    # plot_train_data(<path to metrics.csv>).
    runs = [
        ('Base 16: Multiplier: 2x', "C:\\Users\\vemun\\Desktop\\Plots\\16_2.csv"),
        ('Base 64: Multiplier: 2x', "C:\\Users\\vemun\\Desktop\\Plots\\64_2.csv"),
        ('Base 128: Multiplier: 2x', "C:\\Users\\vemun\\Desktop\\Plots\\128_2.csv"),
        ('Base 64: Multiplier: 3.5x', "C:\\Users\\vemun\\Desktop\\Plots\\64_3_5.csv"),
        ('Base 192: Multiplier: 1.5x', "C:\\Users\\vemun\\Desktop\\Plots\\192_1_5.csv"),
    ]
    labels = [label for label, _ in runs]
    paths = [path for _, path in runs]
    plot_multiple_val_losses(labels, paths)