
I am trying to train an LSTM model. Is this model suffering from overfitting?

Here is the training and validation loss graph:

[Figure: training and validation loss curves per epoch]

DukeLover

6 Answers


The model is overfitting right from epoch 10: the validation loss is increasing while the training loss keeps decreasing.

Ways to deal with such a model:

  1. Data Preprocessing: standardize and normalize the data.
  2. Model Complexity: check whether the model is too complex. Add dropout, reduce the number of layers, or reduce the number of neurons in each layer.
  3. Learning Rate and Decay Rate: reduce the learning rate; a good starting value is usually between 0.0005 and 0.001. Also consider a decay rate of 1e-6 (see the sketch below).

There are many other options to reduce overfitting as well; assuming you are using Keras, visit this link.
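
To make points 2 and 3 concrete, here is a minimal Keras sketch of an LSTM with dropout and a reduced learning rate. The layer sizes, input shape, and choice of the Adam optimizer are illustrative assumptions, not taken from the question:

    import tensorflow as tf

    timesteps, n_features = 30, 8  # assumed input shape

    model = tf.keras.Sequential([
        tf.keras.Input(shape=(timesteps, n_features)),
        tf.keras.layers.LSTM(64, dropout=0.2, recurrent_dropout=0.2),  # point 2: dropout
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(1),
    ])

    # Point 3: start with a small learning rate (0.0005-0.001); older Keras
    # versions also accept decay=1e-6 on the optimizer.
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=5e-4),
        loss="mse",
    )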

user5722540

Yes, this is an overfitting problem, since your curve shows a point of inflection. This is a sign of training for too many epochs. In this case, training could be stopped at the point of inflection (early stopping; see the sketch below), or the number of training examples could be increased.

Overfitting is also caused by a model that is too deep for the training data. In that case, you will observe the validation and training losses diverging very early.
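
If you go the early-stopping route, a small sketch using Keras's EarlyStopping callback could look like the following; the patience value is an assumption, and model, x_train, y_train, x_val, y_val are assumed to already exist:

    import tensorflow as tf

    early_stop = tf.keras.callbacks.EarlyStopping(
        monitor="val_loss",          # watch the validation loss
        patience=5,                  # assumed: tolerate 5 non-improving epochs
        restore_best_weights=True,   # roll back to the best epoch's weights
    )

    model.fit(
        x_train, y_train,
        validation_data=(x_val, y_val),
        epochs=100,
        callbacks=[early_stop],
    )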

Mohit Banerjee

Another possible cause of overfitting is improper data augmentation. If you're augmenting, make sure it's really doing what you expect.

I had a similar problem, and it turned out to be due to a bug in my TensorFlow data pipeline where I was augmenting before caching:

    import tensorflow as tf

    def get_dataset(inputfile, batchsize):
        # Load the data into a TensorFlow dataset.
        signals, labels = read_data_from_file(inputfile)
        dataset = tf.data.Dataset.from_tensor_slices((signals, labels))

        # Augment the data by dynamically tweaking each training sample on the fly.
        dataset = dataset.map(
            map_func=lambda signals, labels: (
                tuple(tf.py_function(func=augment, inp=[signals], Tout=[tf.float32])),
                labels,
            )
        )

        # Oops! Should have called cache() before augmenting.
        dataset = dataset.cache()
        dataset = ...  # Shuffle, repeat, batch, etc.
        return dataset

    training_data = get_dataset("training.txt", 32)
    val_data = ...

    model.fit(training_data, validation_data=val_data, ...)

As a result, the training data was only being augmented for the first epoch. This caused the model to quickly overfit on the training data. Moving the augment call after cache() solved the problem.
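
For reference, a corrected version of the pipeline (get_dataset_fixed is a hypothetical name, and the shuffle buffer and batching are placeholder choices; read_data_from_file and augment are the same helpers as above) caches the raw tensors first and only then maps the augmentation:

    def get_dataset_fixed(inputfile, batchsize):
        signals, labels = read_data_from_file(inputfile)
        dataset = tf.data.Dataset.from_tensor_slices((signals, labels))
        # Cache the raw tensors first...
        dataset = dataset.cache()
        # ...then augment, so the random tweaks are re-applied every epoch.
        dataset = dataset.map(
            lambda signals, labels: (
                tuple(tf.py_function(func=augment, inp=[signals], Tout=[tf.float32])),
                labels,
            )
        )
        return dataset.shuffle(1000).batch(batchsize)  # placeholder shuffle/batch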

Kevin D.

I had this issue: while the training loss was decreasing, the validation loss was not. I checked and found the following while I was using an LSTM:

  • I simplified the model: instead of 20 layers, I opted for 8 layers.
  • Instead of scaling within the range (-1, 1), I chose (0, 1); that alone reduced my validation loss by an order of magnitude (see the sketch after this list).
  • I reduced the batch size from 500 to 50 (just trial and error).
  • I added more features, which I intuitively thought would add some new useful information to the X -> y pair.
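
As a concrete (hedged) illustration of the (0, 1) scaling in the second bullet, here is a sketch using scikit-learn's MinMaxScaler; the array shape and values are placeholders:

    import numpy as np
    from sklearn.preprocessing import MinMaxScaler

    X_train = np.random.rand(100, 30, 8)         # assumed (samples, timesteps, features)
    n_samples, timesteps, n_features = X_train.shape

    scaler = MinMaxScaler(feature_range=(0, 1))  # instead of (-1, 1)
    X_scaled = scaler.fit_transform(
        X_train.reshape(-1, n_features)          # fit per feature across all timesteps
    ).reshape(n_samples, timesteps, n_features)
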
yogender

It may be that you need to feed in more data as well. If the model overfits, your dataset may be so small that the model's high capacity lets it fit this small dataset easily while not delivering out-of-sample performance. In other words, it does not learn a robust representation of the true underlying data distribution, just a representation that fits the training data very well.

Solutions as stated above:

  • reduce model complexity: if you feel your model is not really overly complex, you should first try running on a larger dataset.
  • regularization: using dropout and other regularization techniques may assist the model in generalizing better.

I propose extending your dataset (substantially), which will obviously be costly in several respects, but it will also serve as a form of "regularization" and give you a more confident answer. In case you cannot gather more data, think about clever ways to augment your dataset by applying transforms, adding noise, etc. to the input data (or to the network output).
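
For the noise idea, one lightweight option is Keras's GaussianNoise layer, which perturbs the inputs only during training; the noise level and input shape below are assumptions:

    import tensorflow as tf

    model = tf.keras.Sequential([
        tf.keras.Input(shape=(30, 8)),               # assumed (timesteps, features)
        tf.keras.layers.GaussianNoise(stddev=0.05),  # assumed noise level; active only in training
        tf.keras.layers.LSTM(32),
        tf.keras.layers.Dense(1),
    ])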

user101893

It's not severe overfitting. So here are my suggestions:

1- Simplify your network! Maybe your network is too complex for your data. If you have a small dataset or the features are easy to detect, you don't need a deep network.

2- Add Dropout layers.

3- Use weight regularization. Here is the link for further information: https://keras.io/api/layers/regularizers/
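
For points 2 and 3, here is a minimal sketch of dropout plus L2 weight regularization on an LSTM layer; the layer sizes, input shape, and the 1e-4 factor are assumptions, not recommendations:

    import tensorflow as tf

    model = tf.keras.Sequential([
        tf.keras.Input(shape=(30, 8)),  # assumed (timesteps, features)
        tf.keras.layers.LSTM(
            32,
            kernel_regularizer=tf.keras.regularizers.l2(1e-4),      # point 3
            recurrent_regularizer=tf.keras.regularizers.l2(1e-4),
        ),
        tf.keras.layers.Dropout(0.2),  # point 2
        tf.keras.layers.Dense(1),
    ])
    model.compile(optimizer="adam", loss="mse")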