Using an LSTM to predict a category

Time:01-24

My dataset contains a curve of data points with a column labeled "BUY". "BUY" indicates that the curve reached a local maximum or minimum at this point (not fully accurate, because the points were found after smoothing, but mostly within 1 row). 0 indicates it is not a maximum or minimum, 1 indicates it is a maximum, and 2 indicates it is a minimum.

<DATE>,<CLOSE>,<VOL>,<BUY>
01/04/21;09:35:00,728.25,37290,0
01/04/21;09:40:00,728.0,31059,0
01/04/21;09:45:00,742.4,44956,0
01/04/21;09:50:00,740.03,27251,2
01/04/21;09:55:00,737.69,22765,0
01/04/21;10:00:00,737.0,9703,0
01/04/21;10:05:00,738.3,16369,1
01/04/21;10:10:00,735.41,17772,0
...

The issue is that this leaves a very large number of 0s in my dataset, which I suspect is hurting the classification of whether the curve is at a maximum, a minimum, or neither. I suspect this because the model always predicts 0.
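Counting the labels makes the imbalance concrete. This is a minimal sketch using only the standard library, run on the eight sample rows shown above (the helper name `label_counts` is just for illustration). On the full dataset the share of 0s is presumably even larger.

```python
import csv
import io
from collections import Counter

# The sample rows from the question (DATE, CLOSE, VOL, BUY). The date field
# contains ';' but no ',', so a plain comma-separated reader handles it.
SAMPLE = """<DATE>,<CLOSE>,<VOL>,<BUY>
01/04/21;09:35:00,728.25,37290,0
01/04/21;09:40:00,728.0,31059,0
01/04/21;09:45:00,742.4,44956,0
01/04/21;09:50:00,740.03,27251,2
01/04/21;09:55:00,737.69,22765,0
01/04/21;10:00:00,737.0,9703,0
01/04/21;10:05:00,738.3,16369,1
01/04/21;10:10:00,735.41,17772,0
"""


def label_counts(text):
    """Return a Counter of BUY labels (0 = neither, 1 = maximum, 2 = minimum)."""
    reader = csv.DictReader(io.StringIO(text))
    return Counter(int(row["<BUY>"]) for row in reader)


counts = label_counts(SAMPLE)
total = sum(counts.values())
for label in sorted(counts):
    # Print each class with its absolute count and its share of all rows.
    print(label, counts[label], f"{counts[label] / total:.1%}")
```

Even in this short excerpt, class 0 makes up 75% of the rows, so a model that always answers 0 already looks "75% accurate".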

This is my current model.

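from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense

# Two stacked LSTMs, then a softmax over the 3 classes (0 = neither, 1 = maximum, 2 = minimum)
model = Sequential()
model.add(LSTM(units=100, return_sequences=True, dropout=0.2))
model.add(LSTM(units=50, return_sequences=False, dropout=0.2))
model.add(Dense(3, activation='softmax'))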
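Keras LSTM layers expect input shaped (samples, timesteps, features), so the flat rows have to be cut into windows before they reach this model. Below is a minimal sketch in plain Python, assuming CLOSE and VOL are the features and that each window is labeled with the BUY value of its last row. The function name `make_windows` and the window length of 3 are hypothetical choices, not something from the question.

```python
def make_windows(features, labels, window):
    """Split a time series into overlapping windows for an LSTM.

    features: list of per-row feature lists, e.g. [close, vol]
    labels:   list of BUY labels aligned with `features`
    window:   number of time steps per sample (hypothetical choice)

    Returns (X, y) where X[i] holds `window` consecutive rows and y[i]
    is the label of the last row in that window.
    """
    if len(features) != len(labels):
        raise ValueError("features and labels must have the same length")
    X, y = [], []
    for end in range(window, len(features) + 1):
        X.append(features[end - window:end])
        y.append(labels[end - 1])
    return X, y


# The sample rows from the question.
close = [728.25, 728.0, 742.4, 740.03, 737.69, 737.0, 738.3, 735.41]
vol = [37290, 31059, 44956, 27251, 22765, 9703, 16369, 17772]
buy = [0, 0, 0, 2, 0, 0, 1, 0]

X, y = make_windows([list(row) for row in zip(close, vol)], buy, window=3)
print(len(X), y)  # 6 windows; y == [0, 2, 0, 0, 1, 0]
```

Converting `X` to a NumPy array gives the 3-D shape the LSTM needs. Labeling the last row means the model only sees past bars; since a peak is usually only visible after it happens, labeling a row inside the window instead may be worth trying.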

Please let me know if there are any changes I should make.

Additionally, I tried to build a double hurdle model (first classify whether the input is a 0, and if not, classify which nonzero class it is), but I was unable to figure it out.
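One way to structure the double hurdle is to train two binary classifiers: stage 1 on every row ("is this an extremum at all?") and stage 2 only on the extremum rows ("maximum or minimum?"). The sketch below shows just the target construction and how the two outputs combine back into 0/1/2. The models themselves are not shown: each could be the same LSTM stack ending in `Dense(1, activation='sigmoid')` with binary cross-entropy. The function names and the `threshold` parameter are hypothetical.

```python
def hurdle_targets(labels):
    """Build training targets for a two-stage ("double hurdle") classifier.

    Stage 1 sees every row and learns: is this an extremum at all?
    Stage 2 sees only the extremum rows and learns: maximum (1) or minimum (0)?
    Returns (stage1_targets, stage2_row_indices, stage2_targets).
    """
    stage1 = [1 if lab != 0 else 0 for lab in labels]
    stage2_idx = [i for i, lab in enumerate(labels) if lab != 0]
    stage2 = [1 if labels[i] == 1 else 0 for i in stage2_idx]
    return stage1, stage2_idx, stage2


def hurdle_predict(p_extremum, p_maximum, threshold=0.5):
    """Combine the two stage outputs into the original 0/1/2 labels.

    p_extremum: stage-1 probability that the row is a maximum or minimum
    p_maximum:  stage-2 probability that it is a maximum, given it is one
    threshold:  hypothetical stage-1 cut-off; lowering it catches more
                extrema at the cost of more false alarms
    """
    if p_extremum < threshold:
        return 0
    return 1 if p_maximum >= 0.5 else 2


stage1, stage2_idx, stage2 = hurdle_targets([0, 0, 0, 2, 0, 0, 1, 0])
print(stage1)              # [0, 0, 0, 1, 0, 0, 1, 0]
print(stage2_idx, stage2)  # [3, 6] [0, 1]
print(hurdle_predict(0.7, 0.2))  # 2 (a minimum)
```

Stage 2 never has to deal with the flood of 0s, and the stage-1 threshold gives an explicit knob for the imbalance instead of relying on the softmax alone.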

Answer:

You need to predict the zeros too, and predict them accurately; otherwise, at test time you cannot infer that information.

Obviously the network will be biased toward predicting 0 because it is the most frequent class, but that does not mean it will always predict 0. Unless the dataset is something like a stock prediction with covariates that carry no information about the target, the network should learn when to predict 1 or 2 without overfitting.
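To tell whether a model is merely biased toward 0 or has collapsed to always predicting 0, overall accuracy is not enough; per-class recall shows it directly. A minimal sketch in plain Python, using the sample BUY labels and a hypothetical "always 0" predictor (the helper name `per_class_recall` is illustrative):

```python
from collections import Counter


def per_class_recall(y_true, y_pred, classes=(0, 1, 2)):
    """Fraction of each true class that the model labeled correctly."""
    totals = Counter(y_true)
    hits = Counter(t for t, p in zip(y_true, y_pred) if t == p)
    # Classes absent from y_true get NaN rather than a misleading 0 or 1.
    return {c: (hits[c] / totals[c] if totals[c] else float("nan")) for c in classes}


y_true = [0, 0, 0, 2, 0, 0, 1, 0]  # BUY labels from the sample rows
always_zero = [0] * len(y_true)    # a degenerate model that only predicts 0

accuracy = sum(t == p for t, p in zip(y_true, always_zero)) / len(y_true)
recall = per_class_recall(y_true, always_zero)
print(accuracy)  # 0.75
print(recall)    # {0: 1.0, 1: 0.0, 2: 0.0}
```

A healthy model may still have lower recall on 1 and 2 than on 0, but it should be clearly above zero; recall of exactly 0.0 for both rare classes is the "always predicting 0" symptom from the question.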

My suggestion is to use class_weight when fitting:

# Errors on a maximum (1) or minimum (2) count twice as much in the loss as errors on 0
class_weight = {0: 1.,
                1: 2.,
                2: 2.}
model.fit(X_train, Y_train, ..., class_weight=class_weight)
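Instead of hand-picking 2.0, the weights can be derived from the label frequencies. This sketch uses the heuristic n_samples / (n_classes * count[c]), the same formula scikit-learn uses for `class_weight="balanced"`; the helper name is hypothetical, and it is shown on the eight sample rows only.

```python
from collections import Counter


def balanced_class_weights(labels):
    """Weight each class inversely to its frequency.

    Uses n_samples / (n_classes * count[c]); a perfectly balanced dataset
    gets weight 1.0 for every class.
    """
    counts = Counter(labels)
    n_samples, n_classes = len(labels), len(counts)
    return {c: n_samples / (n_classes * counts[c]) for c in sorted(counts)}


buy = [0, 0, 0, 2, 0, 0, 1, 0]  # the sample BUY labels from the question
class_weight = balanced_class_weights(buy)
print(class_weight)  # roughly {0: 0.44, 1: 2.67, 2: 2.67}
```

With these weights each class contributes the same total weight to the loss (here 8 / 3 per class). On the full dataset, where 0s dominate even more, the weights for 1 and 2 would be larger than the 2.0 suggested above, so it may be worth trying values in between if the balanced weights cause too many false extrema.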