Early stopping model training in Keras

What is early stopping?

So, what exactly is early stopping? Early stopping is a method that can be used to prevent overfitting in machine learning. By stopping model training when the performance on a validation set stops improving, we can select the number of epochs needed for optimal training.

How can you use it?

The basic implementation of early stopping works by splitting your training data into 2 new sets:

  • Training set
  • Validation set

Usually 70% to 80% of the original training data is used for the training set, and the remainder is used for the validation set. The model is then trained only on the training set and evaluated on the validation set after every epoch. Training is stopped when the performance on the validation set has not improved for a certain number of epochs. In Keras this number of epochs is specified with the patience parameter. You can then use the stored model with the best performance on the validation set. Alternatively, you can take the number of epochs that gave the best performance and train the model again, this time on the complete original training data.
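To make the patience rule concrete, here is a minimal, framework-free sketch in plain Python. It assumes the monitored metric is a loss (lower is better) and that any decrease counts as an improvement, which as far as I can tell matches Keras's EarlyStopping with min_delta=0. The function name `early_stopping_epochs` and the example loss values are made up for illustration. This is not Keras's own code.

```python
def early_stopping_epochs(val_losses, patience):
    """Return (best_epoch, stop_epoch) as 0-indexed epoch numbers.

    stop_epoch is None if the patience never runs out, i.e. training
    would simply run through every epoch.
    """
    best_loss = float("inf")
    best_epoch = None
    wait = 0  # epochs since the last improvement
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss:
            # New best validation loss: remember it and reset the counter
            best_loss = loss
            best_epoch = epoch
            wait = 0
        else:
            # No improvement in this epoch
            wait += 1
            if wait >= patience:
                return best_epoch, epoch
    return best_epoch, None


# Hypothetical validation losses: they improve until epoch 2, then get worse
losses = [0.50, 0.40, 0.35, 0.36, 0.37, 0.38, 0.39]
print(early_stopping_epochs(losses, patience=3))  # → (2, 5)
```

The model you would want to keep is the one from `best_epoch`, not the one from `stop_epoch`.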

The same applies when using cross-validation. We can either use the best-performing model trained in each fold, or determine the optimal number of epochs for each fold, average those numbers, and then train the model again on the complete original training data for the averaged number of epochs.
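The second cross-validation option (averaging the per-fold optimal epoch counts) could look like the sketch below. It assumes you already have the per-epoch validation losses of each fold, for example from each fold's `history.history['val_loss']`. The helper name `averaged_epoch_count` and the loss values are hypothetical.

```python
from statistics import mean


def averaged_epoch_count(fold_val_losses):
    """Average the optimal number of epochs over cross-validation folds.

    fold_val_losses holds one list of per-epoch validation losses per fold.
    The best epoch index is 0-based, so 1 is added to turn it into an
    epoch count that could be passed to fit(epochs=...).
    """
    epoch_counts = []
    for losses in fold_val_losses:
        # Index of the lowest validation loss (first one in case of ties)
        best_index = min(range(len(losses)), key=losses.__getitem__)
        epoch_counts.append(best_index + 1)
    return round(mean(epoch_counts))


# Hypothetical per-fold validation losses (best epoch counts: 2, 3 and 2)
folds = [[0.5, 0.4, 0.45], [0.6, 0.5, 0.4, 0.42], [0.5, 0.3]]
print(averaged_epoch_count(folds))  # → 2
```

Note that Python's `round` rounds exact halves to the nearest even number. Use `math.ceil` instead if you prefer to round up.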

Implementing Early Stopping

Let’s take a look at how we can implement early stopping with the Keras framework in Python. The Python script file EarlyStoppingModelTraining.py used in this blog can be found on GitHub. Again, TensorFlow is used as the back-end for Keras.

First let’s import all required packages.

import numpy as np
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from sklearn import preprocessing
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import Adam
from keras.callbacks import EarlyStopping, ModelCheckpoint
import matplotlib.pyplot as plt

I’ll create a placeholder for the model metrics. It is replaced later by the History object that model.fit() returns.

# Store Model metrics (replaced by the History object returned by model.fit)
history = {}

Next I define a function to plot a chart and save it to file. The chart contains the metrics for both the training and validation set.

# Plot Chart
def plot_chart_to_file(best_epoch, best_value):

    # Plot Chart
    fig = plt.figure(dpi=300)

    # Keras 2.3+ reports the metric as 'accuracy'; older versions used 'acc'
    acc_key = 'accuracy' if 'accuracy' in history.history else 'acc'

    # Subplot for Accuracy
    ax1 = fig.add_subplot(111)
    ax1.plot(history.history[acc_key], color='b', label='Train Accuracy')
    ax1.plot(history.history['val_' + acc_key], color='g',
             label='Validation Accuracy')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Accuracy')
    ax1.legend(loc='lower left', bbox_to_anchor=(0, -0.30))

    # Subplot for Loss (second y-axis sharing the same epoch axis)
    ax2 = ax1.twinx()
    ax2.plot(history.history['loss'], color='y', label='Train Loss')
    ax2.plot(history.history['val_loss'], color='c', label='Validation Loss')
    ax2.plot(best_epoch, best_value, 'r+', label='Best Model')
    ax2.set_ylabel('Loss')
    ax2.legend(loc='lower right', bbox_to_anchor=(1, -0.30))

    # Set Title
    plt.title('Model - Best Epoch (' + str(best_epoch) + ')')

    # .. and save..
    plt.savefig('Blog3_Model_Chart.png', bbox_inches="tight")
    plt.close(fig)

I define a function to create the neural network model.

# Create Model
def create_model():
    model = Sequential()
    # Hidden layer with one unit per input pixel (28 x 28 = 784)
    model.add(Dense(784, input_dim=784, kernel_initializer='normal',
                    activation='relu'))
    # Output layer with one softmax unit per digit class
    model.add(Dense(10, kernel_initializer='normal', activation='softmax'))
    # Compile model ('learning_rate' replaces the older 'lr' argument)
    model.compile(loss='categorical_crossentropy',
                  optimizer=Adam(learning_rate=0.0005), metrics=['accuracy'])
    return model

Next is the necessary code for downloading the MNIST dataset, one-hot encoding the labels, splitting the data into training and validation sets, and creating the neural network model.

# Download MNIST Data (fetch_mldata was removed from scikit-learn; OpenML's
# 'mnist_784' is the same 70,000-image dataset, with the labels stored as strings)
mnist = fetch_openml('mnist_784', version=1, as_frame=False, data_home='~')

# Rescale
X = mnist.data.astype(np.float32) / 255

# One Hot
labels = range(10)
lb = preprocessing.LabelBinarizer()
lb.fit(labels)
Y = lb.transform(mnist.target.astype(int))

# Split in Training and Validation Sets
x_train, x_val, y_train, y_val = train_test_split(X, Y, test_size=0.15,
                                                  random_state=42, stratify=Y)
    
# Create Model
model = create_model()
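If you want to see what `LabelBinarizer` returns here: with the ten classes 0 to 9, each label becomes a row of ten values with a single 1. The NumPy-only sketch below builds an equivalent matrix by indexing an identity matrix. The three example labels are arbitrary, and the snippet does not download anything.

```python
import numpy as np

# Example integer class labels (0-9), as in MNIST
y = np.array([0, 3, 9])

# Row i of the 10x10 identity matrix is the one-hot vector for class i
one_hot = np.eye(10, dtype=int)[y]
print(one_hot)

# Each row contains a single 1 at the position of its label
assert (one_hot.argmax(axis=1) == y).all()
```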

Now we can implement the code for early stopping. In Keras, early stopping is implemented as a callback, and this callback is applied at the end of each epoch.

The early stopping callback monitors the validation loss, and its mode is set to min. As a result, it stops training when the monitored metric has stopped decreasing for the number of epochs specified by patience. With min_delta=0, any decrease counts as an improvement. The patience is set to 25. I chose a large number so that the charts show more epochs. In other scenarios a smaller number is likely enough. Setting verbose=1 provides basic logging.

# Configure Early Stopping
patience = 25
earlystop = EarlyStopping(monitor="val_loss",
                          min_delta=0,
                          mode="min",
                          verbose=1, 
                          patience=patience)
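To see how mode and min_delta in the configuration above decide what counts as an improvement, the rule can be written as a small helper. This is a plain-Python sketch of the behaviour described in the Keras EarlyStopping documentation: an absolute change smaller than min_delta does not count as an improvement. `is_improvement` is a hypothetical name and is not part of Keras.

```python
def is_improvement(current, best, mode="min", min_delta=0.0):
    """Decide whether a monitored value improves on the best value so far."""
    if mode == "min":
        # Loss-like metrics must drop by more than min_delta
        return current < best - min_delta
    if mode == "max":
        # Accuracy-like metrics must rise by more than min_delta
        return current > best + min_delta
    raise ValueError("mode must be 'min' or 'max'")


print(is_improvement(0.30, 0.31))                  # True: loss decreased
print(is_improvement(0.30, 0.31, min_delta=0.05))  # False: decrease too small
print(is_improvement(0.97, 0.96, mode="max"))      # True: accuracy increased
```

Every epoch for which this returns False adds one to the patience counter.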

If we use the code like this, training will stop early, but the trained model will be the one from the last epoch, when training stopped. That is not the best-performing model, because the validation loss has not improved for the last 25 epochs (as specified by the patience).

Let’s implement another callback to make sure that the best-performing model is saved to file. ModelCheckpoint is called at the end of every epoch and monitors the same metric as early stopping. With save_best_only=True, the file is only written when the monitored metric improves, so it always holds the best model seen so far.

# Configure Checkpoint
checkpoint = ModelCheckpoint('modelweights.hdf5', 
                              monitor='val_loss', 
                              verbose=1, 
                              save_best_only=True, 
                              mode='min')
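The effect of save_best_only=True in the configuration above can be mimicked in a few lines of plain Python. The sketch lists the epochs in which the checkpoint file would be written for a made-up sequence of validation losses. It assumes monitor='val_loss' and mode='min' as above, and it does not call Keras. `checkpoint_epochs` is a hypothetical helper.

```python
def checkpoint_epochs(val_losses):
    """Return the 0-indexed epochs in which a save_best_only checkpoint
    (monitor='val_loss', mode='min') would write the model file."""
    saved = []
    best = float("inf")
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            best = loss
            saved.append(epoch)  # the file is overwritten with this epoch's model
    return saved


losses = [0.50, 0.40, 0.42, 0.35, 0.36, 0.38]
print(checkpoint_epochs(losses))  # → [0, 1, 3]
```

The last entry is the epoch whose model is left in the file after training, even though training itself continued past it.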

Finally, we put the callbacks in a list and pass that list to fit on the model.

# Callback List
callbacks = [checkpoint, earlystop]

# Fit Model
history = model.fit(x_train, y_train, validation_data=(x_val, y_val),
                    epochs=100, batch_size=192, verbose=2, shuffle=False,
                    callbacks=callbacks)

If you want to reload the best model after training for further use, you can do so with the code below. It loads the weights of the best model from the file that ModelCheckpoint wrote during training.

# Reload Best Model for further usage....
model.load_weights('modelweights.hdf5')

As a last step the function to plot and save the chart is called.

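# Plot Charts (best epoch = 0-based index of the lowest validation loss;
# unlike stopped_epoch - patience, this also works if training never stopped early)
best_epoch = int(np.argmin(history.history['val_loss']))
plot_chart_to_file(best_epoch, history.history['val_loss'][best_epoch])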

The results

The code was first run with early stopping configured to monitor the validation loss for minimization, with the patience set to 25. The best validation loss was found at epoch 14. Below you can see the chart with the model metrics.

[Figure: Early Stopping based on Loss]

In the verbose logging you can see that the validation loss still improved at epoch 15. In epoch 16 the validation loss increased slightly, so it no longer improved. 25 epochs later, at epoch 40, training was stopped early because of the configured patience.

Note that the verbose logging starts counting at epoch 1, whereas Keras counts epochs internally from 0. This is why the chart shows epoch 14 as the best one.
[Figure: Training Verbose Logging]

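The relation between the logged (1-based) epoch numbers and Keras's internal (0-based) indices can be checked with the numbers from this run. Training was stopped at logged epoch 40 with a patience of 25. The variable names are only for illustration. `stopped_epoch` mirrors the 0-based attribute of the same name on the EarlyStopping callback.

```python
patience = 25
logged_stop_epoch = 40                  # "early stopping" message in the log (1-based)

stopped_epoch = logged_stop_epoch - 1   # internal 0-based index, as in earlystop.stopped_epoch
best_index = stopped_epoch - patience   # 0-based best epoch, as shown in the chart
best_logged = best_index + 1            # the same epoch as numbered in the log

print(stopped_epoch, best_index, best_logged)  # → 39 14 15
```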
A second run of the code was performed with early stopping and the model checkpoint configured to use the validation accuracy as the metric. The mode was set to max, because for accuracy we want to detect when it has stopped increasing. The patience was set to 20. The best validation accuracy was found at epoch 15. Below you can see the chart with the model metrics.
[Figure: Early Stopping based on Accuracy]

Summary

In this blog I gave a brief description of what early stopping is and how you can use it to your advantage. We looked at what needs to be implemented in the Python script to configure early stopping and how to save the best trained model to file.
