import tensorflow as tf
import numpy as np
import os
from datetime import datetime

from AffectNetClass import AffectNet
from RafdbClass import RafDB
from FerPlusClass import FerPlus
from config import DatasetName, AffectnetConf, InputDataSize, LearningConfig, DatasetType, RafDBConf, FerPlusConf
from cnn_model import CNNModel
from custom_loss import CustomLosses
from data_helper import DataHelper
from dataset_class import CustomDataset


class TrainModel:
    def __init__(self, dataset_name, ds_type, weights='imagenet', lr=1e-3, aug=True):
        self.dataset_name = dataset_name
        self.ds_type = ds_type
        self.weights = weights
        self.lr = lr
        self.base_lr = 1e-5
        self.max_lr = 5e-4
        if dataset_name == DatasetName.fer2013:
            self.drop = 0.1
            self.epochs_drop = 5
            if aug:
                self.img_path = FerPlusConf.aug_train_img_path
                self.annotation_path = FerPlusConf.aug_train_annotation_path
                self.masked_img_path = FerPlusConf.aug_train_masked_img_path
            else:
                self.img_path = FerPlusConf.no_aug_train_img_path
                self.annotation_path = FerPlusConf.no_aug_train_annotation_path
            self.val_img_path = FerPlusConf.test_img_path
            self.val_annotation_path = FerPlusConf.test_annotation_path
            self.eval_masked_img_path = FerPlusConf.test_masked_img_path
            self.num_of_classes = 7
            self.num_of_samples = None
        elif dataset_name == DatasetName.rafdb:
            self.drop = 0.1
            self.epochs_drop = 5
            if aug:
                self.img_path = RafDBConf.aug_train_img_path
                self.annotation_path = RafDBConf.aug_train_annotation_path
                self.masked_img_path = RafDBConf.aug_train_masked_img_path
            else:
                self.img_path = RafDBConf.no_aug_train_img_path
                self.annotation_path = RafDBConf.no_aug_train_annotation_path
            self.val_img_path = RafDBConf.test_img_path
            self.val_annotation_path = RafDBConf.test_annotation_path
            self.eval_masked_img_path = RafDBConf.test_masked_img_path
            self.num_of_classes = 7
            self.num_of_samples = None
        elif dataset_name == DatasetName.affectnet:
            self.drop = 0.1
            self.epochs_drop = 5
            if ds_type == DatasetType.train:
                self.img_path = AffectnetConf.aug_train_img_path
                self.annotation_path = AffectnetConf.aug_train_annotation_path
                self.masked_img_path = AffectnetConf.aug_train_masked_img_path
                self.val_img_path = AffectnetConf.eval_img_path
                self.val_annotation_path = AffectnetConf.eval_annotation_path
                self.eval_masked_img_path = AffectnetConf.eval_masked_img_path
                self.num_of_classes = 8
                self.num_of_samples = AffectnetConf.num_of_samples_train
            elif ds_type == DatasetType.train_7:
                if aug:
                    self.img_path = AffectnetConf.aug_train_img_path_7
                    self.annotation_path = AffectnetConf.aug_train_annotation_path_7
                    self.masked_img_path = AffectnetConf.aug_train_masked_img_path_7
                else:
                    self.img_path = AffectnetConf.no_aug_train_img_path_7
                    self.annotation_path = AffectnetConf.no_aug_train_annotation_path_7
                self.val_img_path = AffectnetConf.eval_img_path_7
                self.val_annotation_path = AffectnetConf.eval_annotation_path_7
                self.eval_masked_img_path = AffectnetConf.eval_masked_img_path_7
                self.num_of_classes = 7
                self.num_of_samples = AffectnetConf.num_of_samples_train_7

    def train(self, arch, weight_path):
        """Train `arch` on the configured dataset, evaluating and saving weights after every epoch."""
        '''create loss'''
        c_loss = CustomLosses()

        '''create summary writer'''
        summary_writer = tf.summary.create_file_writer(
            "./train_logs/fit/" + datetime.now().strftime("%Y%m%d-%H%M%S"))
        start_train_date = datetime.now().strftime("%Y%m%d-%H%M%S")

        '''make the model'''
        model = self.make_model(arch=arch, w_path=weight_path)

        '''create save path'''
        if self.dataset_name == DatasetName.affectnet:
            save_path = AffectnetConf.weight_save_path + start_train_date + '/'
        elif self.dataset_name == DatasetName.rafdb:
            save_path = RafDBConf.weight_save_path + start_train_date + '/'
        elif self.dataset_name == DatasetName.fer2013:
            save_path = FerPlusConf.weight_save_path + start_train_date + '/'
        if not os.path.exists(save_path):
            os.makedirs(save_path)

        '''create sample generator'''
        dhp = DataHelper()

        '''train generator'''
        img_filenames, exp_filenames = dhp.create_generator_full_path(img_path=self.img_path,
                                                                      annotation_path=self.annotation_path)
        '''create dataset'''
        cds = CustomDataset()
        ds = cds.create_dataset(img_filenames=img_filenames, anno_names=exp_filenames, is_validation=False)

        '''create train configuration'''
        step_per_epoch = len(img_filenames) // LearningConfig.batch_size
        gradients = None
        virtual_step_per_epoch = LearningConfig.virtual_batch_size // LearningConfig.batch_size

        '''create optimizer'''
        optimizer = tf.keras.optimizers.Adam(self.lr, decay=1e-5)

        '''start training'''
        all_gt_exp = []
        all_pr_exp = []
        for epoch in range(LearningConfig.epochs):
            ce_weight = 2
            batch_index = 0
            for img_batch, exp_batch in ds:
                '''Since calculating the confusion matrix is time-consuming, we only keep the most
                   recent 1000 labels each time; this also makes the updates quicker.'''
                all_gt_exp, all_pr_exp = self._update_all_labels_arrays(all_gt_exp, all_pr_exp)

                '''squeeze the extra dimension out of the loaded annotations and images'''
                exp_batch = exp_batch[:, -1]
                img_batch = img_batch[:, -1, :, :]

                '''train step'''
                step_gradients, all_gt_exp, all_pr_exp = self.train_step(
                    epoch=epoch, step=batch_index, total_steps=step_per_epoch,
                    img_batch=img_batch, anno_exp=exp_batch, model=model,
                    optimizer=optimizer, c_loss=c_loss, ce_weight=ce_weight,
                    summary_writer=summary_writer, all_gt_exp=all_gt_exp, all_pr_exp=all_pr_exp)
                batch_index += 1

            '''evaluation'''
            global_accuracy, conf_mat, avg_acc = self._eval_model(model=model)

            '''save weights'''
            save_name = save_path + '_' + str(epoch) + '_' + self.dataset_name + '_AC_' + str(global_accuracy)
            model.save(save_name + '.h5')
            self._save_confusion_matrix(conf_mat, save_name + '.txt')

    def train_step(self, epoch, step, total_steps, model, ce_weight, img_batch, anno_exp, optimizer,
                   summary_writer, c_loss, all_gt_exp, all_pr_exp):
        with tf.GradientTape() as tape:
            '''forward pass: the first output is the expression vector, the rest are embeddings'''
            pr_data = model([img_batch], training=True)
            exp_pr_vec = pr_data[0]
            embeddings = pr_data[1:]

            bs_size = tf.shape(exp_pr_vec, out_type=tf.dtypes.int64)[0]

            '''cross-entropy loss'''
            loss_exp, accuracy = c_loss.cross_entropy_loss(y_pr=exp_pr_vec, y_gt=anno_exp,
                                                           num_classes=self.num_of_classes,
                                                           ds_name=self.dataset_name)
            '''feature-difference loss'''
            # embedding_similarity_loss = 0
            embedding_similarity_loss = c_loss.embedding_loss_distance(embeddings=embeddings)

            '''update the confusion matrix'''
            exp_pr = tf.constant([np.argmax(exp_pr_vec[i]) for i in range(bs_size)], dtype=tf.dtypes.int64)
            tr_conf_matrix, all_gt_exp, all_pr_exp = c_loss.update_confusion_matrix(anno_exp,  # real labels
                                                                                    exp_pr,  # predicted labels
                                                                                    all_gt_exp,
                                                                                    all_pr_exp)
            '''correlation between the embeddings'''
            correlation_loss = c_loss.correlation_loss_multi(embeddings=embeddings,
                                                             exp_gt_vec=anno_exp,
                                                             exp_pr_vec=exp_pr_vec,
                                                             tr_conf_matrix=tr_conf_matrix)
            '''mean loss'''
            mean_correlation_loss = c_loss.mean_embedding_loss_distance(embeddings=embeddings,
                                                                        exp_gt_vec=anno_exp,
                                                                        exp_pr_vec=exp_pr_vec,
                                                                        num_of_classes=self.num_of_classes)
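
            '''total objective: the cross-entropy term is up-weighted (by the factor of 50 below)
               relative to the three embedding regularizers, so classification drives the update
               while the embedding terms act as auxiliary constraints'''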
            lambda_param = 50
            loss_total = lambda_param * loss_exp + \
                         embedding_similarity_loss + \
                         correlation_loss + \
                         mean_correlation_loss

        '''calculate gradients'''
        gradients_of_model = tape.gradient(loss_total, model.trainable_variables)
        '''apply gradients'''
        optimizer.apply_gradients(zip(gradients_of_model, model.trainable_variables))

        '''print loss values'''
        tf.print("->EPOCH: ", str(epoch), "->STEP: ", str(step) + '/' + str(total_steps),
                 ' -> : accuracy: ', accuracy,
                 ' -> : loss_total: ', loss_total,
                 ' -> : loss_exp: ', loss_exp,
                 ' -> : embedding_similarity_loss: ', embedding_similarity_loss,
                 ' -> : correlation_loss: ', correlation_loss,
                 ' -> : mean_correlation_loss: ', mean_correlation_loss)
        with summary_writer.as_default():
            tf.summary.scalar('loss_total', loss_total, step=epoch)
            tf.summary.scalar('loss_exp', loss_exp, step=epoch)
            tf.summary.scalar('correlation_loss', correlation_loss, step=epoch)
            tf.summary.scalar('mean_correlation_loss', mean_correlation_loss, step=epoch)
            tf.summary.scalar('embedding_similarity_loss', embedding_similarity_loss, step=epoch)
        return gradients_of_model, all_gt_exp, all_pr_exp

    def train_step_old(self, epoch, step, total_steps, model, ce_weight, img_batch, anno_exp, optimizer,
                       summary_writer, c_loss, all_gt_exp, all_pr_exp):
        with tf.GradientTape() as tape:
            '''create annotation_predicted'''
            # exp_pr, embedding = model([img_batch], training=True)
            exp_pr_vec, embedding_class, embedding_mean, embedding_var = model([img_batch], training=True)
            bs_size = tf.shape(exp_pr_vec, out_type=tf.dtypes.int64)[0]

            '''CE loss'''
            loss_exp, accuracy = c_loss.cross_entropy_loss(y_pr=exp_pr_vec, y_gt=anno_exp,
                                                           num_classes=self.num_of_classes,
                                                           ds_name=self.dataset_name)

            '''feature-difference loss'''
            loss_cls_mean, loss_cls_var, loss_mean_var = c_loss.embedding_loss_distance(
                embedding_class=embedding_class, embedding_mean=embedding_mean,
                embedding_var=embedding_var, bs_size=bs_size)
            feature_diff_loss = loss_cls_mean + loss_cls_var + loss_mean_var

            '''correlation between the class embeddings'''
            cor_loss, all_gt_exp, all_pr_exp = c_loss.correlation_loss(embedding=embedding_class,  # distribution
                                                                       exp_gt_vec=anno_exp,
                                                                       exp_pr_vec=exp_pr_vec,
                                                                       num_of_classes=self.num_of_classes,
                                                                       all_gt_exp=all_gt_exp,
                                                                       all_pr_exp=all_pr_exp)
            '''correlation between the mean embeddings'''
            mean_emb_cor_loss, mean_emb_kl_loss = c_loss.mean_embedding_loss(embedding=embedding_mean,
                                                                             exp_gt_vec=anno_exp,
                                                                             exp_pr_vec=exp_pr_vec,
                                                                             num_of_classes=self.num_of_classes)
            mean_loss = mean_emb_cor_loss + 10 * mean_emb_kl_loss

            '''correlation between the variance embeddings'''
            var_emb_cor_loss, var_emb_kl_loss = c_loss.variance_embedding_loss(embedding=embedding_var,
                                                                               exp_gt_vec=anno_exp,
                                                                               exp_pr_vec=exp_pr_vec,
                                                                               num_of_classes=self.num_of_classes)
            var_loss = var_emb_cor_loss + 10 * var_emb_kl_loss

            '''total'''
            loss_total = 100 * loss_exp + cor_loss + 10 * feature_diff_loss + mean_loss + var_loss

        '''calculate gradients'''
        gradients_of_model = tape.gradient(loss_total, model.trainable_variables)
        '''apply gradients'''
        optimizer.apply_gradients(zip(gradients_of_model, model.trainable_variables))

        '''print loss values'''
        tf.print("->EPOCH: ", str(epoch), "->STEP: ", str(step) + '/' + str(total_steps),
                 ' -> : accuracy: ', accuracy,
                 ' -> : loss_total: ', loss_total,
                 ' -> : loss_exp: ', loss_exp,
                 ' -> : cor_loss: ', cor_loss,
                 ' -> : feature_loss: ', feature_diff_loss,
                 ' -> : mean_loss: ', mean_loss,
                 ' -> : var_loss: ', var_loss)
        with summary_writer.as_default():
            tf.summary.scalar('loss_total', loss_total, step=epoch)
            tf.summary.scalar('loss_exp', loss_exp, step=epoch)
            tf.summary.scalar('loss_correlation', cor_loss, step=epoch)
        return gradients_of_model, all_gt_exp, all_pr_exp

    def _eval_model(self, model):
        """Evaluate the model on the evaluation/test split of the current dataset."""
        '''for AffectNet, we need to calculate the accuracy of each label and then the total average accuracy'''
        global_accuracy = 0
        avg_acc = 0
        conf_mat = []
        if self.dataset_name == DatasetName.affectnet:
            if self.ds_type == DatasetType.train:
                affn = AffectNet(ds_type=DatasetType.eval)
            else:
                affn = AffectNet(ds_type=DatasetType.eval_7)
            global_accuracy, conf_mat, avg_acc, precision, recall, fscore, support = \
                affn.test_accuracy(model=model)
        elif self.dataset_name == DatasetName.rafdb:
            rafdb = RafDB(ds_type=DatasetType.test)
            global_accuracy, conf_mat, avg_acc, precision, recall, fscore, support = rafdb.test_accuracy(model=model)
        elif self.dataset_name == DatasetName.fer2013:
            ferplus = FerPlus(ds_type=DatasetType.test)
            global_accuracy, conf_mat, avg_acc, precision, recall, fscore, support = ferplus.test_accuracy(model=model)

        print("================== global_accuracy =====================")
        print(global_accuracy)
        print("================== Average Accuracy =====================")
        print(avg_acc)
        print("================== Confusion Matrix =====================")
        print(conf_mat)
        return global_accuracy, conf_mat, avg_acc

    def make_model(self, arch, w_path):
        cnn = CNNModel()
        model = cnn.get_model(arch=arch, num_of_classes=LearningConfig.num_classes, weights=self.weights)
        if w_path is not None:
            model.load_weights(w_path)
        return model

    def _save_confusion_matrix(self, conf_mat, save_name):
        print(save_name)
        with open(save_name, "a") as f:
            f.write(np.array_str(conf_mat))

    def _update_all_labels_arrays(self, all_gt_exp, all_pr_exp):
        if len(all_gt_exp) < LearningConfig.labels_history_frame:
            return all_gt_exp, all_pr_exp
        # drop the oldest batch so the label history stays bounded
        return all_gt_exp[LearningConfig.batch_size:], all_pr_exp[LearningConfig.batch_size:]
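
    # The fields base_lr, max_lr, drop, and epochs_drop are set in __init__ but never
    # used by the training loop above. The helper below is a hypothetical sketch (not
    # part of the original pipeline) of a step-decay schedule built from them: the rate
    # starts at self.lr and is multiplied by `drop` every `epochs_drop` epochs, e.g. by
    # assigning the result to optimizer.learning_rate once per epoch.
    def _step_decay_lr(self, epoch):
        return self.lr * (self.drop ** (epoch // self.epochs_drop))


# A minimal usage sketch, assuming DatasetName/DatasetType are configured as in
# config.py; the architecture name 'resnet50' is an assumption here, since the
# accepted names are defined in cnn_model.py.
if __name__ == '__main__':
    trainer = TrainModel(dataset_name=DatasetName.rafdb, ds_type=DatasetType.train,
                         weights='imagenet', lr=1e-3, aug=True)
    trainer.train(arch='resnet50', weight_path=None)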