File size: 743 Bytes
e8ac98a
 
 
 
 
 
 
ea238c4
e8ac98a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

from warnings import filterwarnings
# NOTE(review): with no message/module restriction and the catch-all
# ``Warning`` category, this filter matches every warning and is installed in
# the process-global filter list as soon as this module is imported, so
# deprecation notices from TensorFlow/Keras are hidden as well.
filterwarnings(action='ignore', category=Warning)


""" Trainer """
def train(dict configuration, X_train, y_train, X_test, y_test):
    """Fit ``configuration['model']`` with early stopping and checkpointing.

    Both callbacks track ``'val_loss'`` in ``'min'`` mode. That loss is
    computed on ``(X_test, y_test)``, which is handed to ``fit`` as its
    validation data. Training halts after 5 epochs without improvement, and
    only the best-scoring model is saved to ``configuration['model_file']``.

    Args:
        configuration: dict read for the keys ``'model_file'``, ``'model'``,
            ``'epochs'`` and ``'batch_size'``. ``'model'`` must provide a
            ``fit`` method (presumably a Keras model -- TODO confirm).
        X_train: training inputs, passed positionally to ``fit``.
        y_train: training targets, passed positionally to ``fit``.
        X_test: validation inputs (first element of ``validation_data``).
        y_test: validation targets (second element of ``validation_data``).

    Returns:
        Whatever ``model.fit`` returns (a ``History`` object for Keras).

    Raises:
        KeyError: if one of the keys listed above is missing.
    """
    # NOTE(review): the "test" split doubles as validation data, so it also
    # drives checkpoint selection and early stopping; consider a separate
    # validation split if X_test/y_test are meant for final evaluation.

    # Stop once val_loss has failed to improve for 5 consecutive epochs.
    # restore_best_weights is left at its default, so the in-memory model
    # keeps the last epoch's weights; the best ones live only on disk.
    stopper = EarlyStopping(monitor='val_loss', patience=5, mode='min')

    # Write the model only when val_loss reaches a new minimum.
    checkpointer = ModelCheckpoint(
        filepath=configuration['model_file'],
        save_best_only=True,
        monitor='val_loss',
        mode='min',
    )

    fit = configuration['model'].fit
    return fit(
        X_train,
        y_train,
        epochs=configuration['epochs'],
        batch_size=configuration['batch_size'],
        validation_data=(X_test, y_test),
        callbacks=[stopper, checkpointer],
    )