Vemund Fredriksen
Add training pipeline (#21)
1cc0005 unverified
raw
history blame
No virus
2.17 kB
import csv
import matplotlib.pyplot as plt
import numpy as np
def _decode_train_csv(csv_path):
epochs = []
train_loss = []
val_loss = []
dice = []
with open(csv_path) as csv_file:
csv_reader = csv.DictReader(csv_file)
for row in csv_reader:
epochs.append(row['step'])
train_loss.append(row['train_loss'])
val_loss.append(row['val_loss'])
dice.append(row['dice_score'])
return (np.array(epochs, dtype=np.uint), np.array(train_loss, dtype=np.float32),
np.array(val_loss, dtype=np.float32), np.array(dice, dtype=np.float32))
def plot_train_data(csv_path, store = None, show=True, steps_in_epoch = -1):
data = _decode_train_csv(csv_path)
plt.plot(data[0], data[1], label = 'Training Loss')
plt.plot(data[0], data[2], label = 'Validation loss')
plt.plot(data[0], data[3], label = 'Dice Score')
if(steps_in_epoch > 0):
vlines = [x for x in range(0, data[0][-1]) if x % steps_in_epoch == 0]
plt.vlines(vlines, ymin = -0.2, ymax = -0.05)
plt.ylim(-0.1, 1.1)
plt.ylabel('Training loss')
plt.xlabel('Train Step')
plt.legend(loc="upper left")
if(store):
plt.savefig(store)
if(show):
plt.show()
def plot_multiple_val_losses(names, csvs):
for name, csv in zip(names, csvs):
data = _decode_train_csv(csv)
plt.plot(data[0], data[2], label = name)
plt.ylim(-0.1, 1.1)
plt.xlim(0, 7000)
plt.ylabel('Validation Loss')
plt.xlabel('Train Step')
plt.legend(loc="upper left")
plt.show()
if __name__ == "__main__":
#path = "D:\\Repos\\LungTumorSegmentation\\models\\metrics.csv"
#plot_train_data(path)
names = ['Base 16: Multiplier: 2x', 'Base 64: Multiplier: 2x', 'Base 128: Multiplier: 2x', 'Base 64: Multiplier: 3.5x', 'Base 192: Multiplier: 1.5x']
csvs = ["C:\\Users\\vemun\\Desktop\\Plots\\16_2.csv", "C:\\Users\\vemun\\Desktop\\Plots\\64_2.csv", "C:\\Users\\vemun\\Desktop\\Plots\\128_2.csv", "C:\\Users\\vemun\\Desktop\\Plots\\64_3_5.csv", "C:\\Users\\vemun\\Desktop\\Plots\\192_1_5.csv"]
plot_multiple_val_losses(names, csvs)