Cheat Sheet: Hilfreiche Tensorflow-Code-Snippets

Early-Stopping bei Verschlechterung in der Validierung

Mit Hilfer einer vor-definierte Callback-Funktion kann das Training gestoppt werden, sobald sich der Loss bei der Validierung verschlechter. Im Beispiel ist mit patience=5 außerdem eingestellt, dass noch 5 Epochen nach einer Verschlechterung weitertrainiert wird um zu prüfen, ob sich der Validierungsloss vielleicht doch wieder verbessert. Wenn das nicht der Fall ist, wird das Training abgebrochen, aber die Gewichte beibehalten, die den geringsten Validierungsloss hatten.

model.fit(train_batches, epochs=1000, validation_data=validation_batches, 
callbacks=[EarlyStopping(patience=5, min_delta=0.05, baseline=0.8, mode='min', monitor='val_loss',restore_best_weights=True, verbose=1)
])

Visualisierung von loss und accuracy nach dem Training

Damit auch die Accuracy angezeigt werden kann, muss das neuronale Netz dies als Metrik erfassen, d.h. die compile()-Methode muss den entsprechenden metrics-Parameter enthalten. Außerdem muss die Trainingshistorie in einer Variablen übergeben werden.

import matplotlib.pyplot as plt

model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy']) #compile() braucht  metrics=['accuracy'] - Parameter 

history = model.fit(x, y, epochs=50)

def plot_loss_acc(history):
  '''Plots the training and validation loss and accuracy from a history object'''
  acc = history.history['accuracy']
  val_acc = history.history['val_accuracy']
  loss = history.history['loss']
  val_loss = history.history['val_loss']

  epochs = range(len(acc))

  fig, ax = plt.subplots(1,2, figsize=(12, 6))
  ax[0].plot(epochs, acc, 'bo', label='Training accuracy')
  ax[0].plot(epochs, val_acc, 'b', label='Validation accuracy')
  ax[0].set_title('Training and validation accuracy')
  ax[0].set_xlabel('epochs')
  ax[0].set_ylabel('accuracy')
  ax[0].legend()

  ax[1].plot(epochs, loss, 'bo', label='Training Loss')
  ax[1].plot(epochs, val_loss, 'b', label='Validation Loss')
  ax[1].set_title('Training and validation loss')
  ax[1].set_xlabel('epochs')
  ax[1].set_ylabel('loss')
  ax[1].legend()

  plt.show()

plot_loss_acc(history)
Nach oben scrollen