Early-Stopping bei Verschlechterung in der Validierung
Mit Hilfer einer vor-definierte Callback-Funktion kann das Training gestoppt werden, sobald sich der Loss bei der Validierung verschlechter. Im Beispiel ist mit patience=5 außerdem eingestellt, dass noch 5 Epochen nach einer Verschlechterung weitertrainiert wird um zu prüfen, ob sich der Validierungsloss vielleicht doch wieder verbessert. Wenn das nicht der Fall ist, wird das Training abgebrochen, aber die Gewichte beibehalten, die den geringsten Validierungsloss hatten.
model.fit(train_batches, epochs=1000, validation_data=validation_batches,
callbacks=[EarlyStopping(patience=5, min_delta=0.05, baseline=0.8, mode='min', monitor='val_loss',restore_best_weights=True, verbose=1)
])
Visualisierung von loss und accuracy nach dem Training
Damit auch die Accuracy angezeigt werden kann, muss das neuronale Netz dies als Metrik erfassen, d.h. die compile()-Methode muss den entsprechenden metrics-Parameter enthalten. Außerdem muss die Trainingshistorie in einer Variablen übergeben werden.
import matplotlib.pyplot as plt
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy']) #compile() braucht metrics=['accuracy'] - Parameter
history = model.fit(x, y, epochs=50)
def plot_loss_acc(history):
'''Plots the training and validation loss and accuracy from a history object'''
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
epochs = range(len(acc))
fig, ax = plt.subplots(1,2, figsize=(12, 6))
ax[0].plot(epochs, acc, 'bo', label='Training accuracy')
ax[0].plot(epochs, val_acc, 'b', label='Validation accuracy')
ax[0].set_title('Training and validation accuracy')
ax[0].set_xlabel('epochs')
ax[0].set_ylabel('accuracy')
ax[0].legend()
ax[1].plot(epochs, loss, 'bo', label='Training Loss')
ax[1].plot(epochs, val_loss, 'b', label='Validation Loss')
ax[1].set_title('Training and validation loss')
ax[1].set_xlabel('epochs')
ax[1].set_ylabel('loss')
ax[1].legend()
plt.show()
plot_loss_acc(history)