PyTorch equivalent for Keras sequential model

Time:01-10

How can I write an exact PyTorch equivalent of this Keras Sequential network?

import tensorflow as tf

model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10)
])

Answer:

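The following nn.Sequential reproduces the same layer structure. Unlike Keras, PyTorch does not infer input sizes, so the first nn.Linear needs in_features = 28*28 = 784 spelled out. ReLU is also a separate module, because nn.Linear has no activation argument. Note that the two frameworks use different default weight initializations, so the models will not start from the same weights: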

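    import torch.nn as nn

    model_torch = nn.Sequential(
        nn.Flatten(),             # (N, 28, 28) -> (N, 784), like Keras Flatten
        nn.Linear(28 * 28, 128),  # Dense(128): in_features must be given explicitly
        nn.ReLU(),                # Dense's activation='relu' becomes its own layer
        nn.Linear(128, 10),       # Dense(10): no activation, outputs raw logits
    )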
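To check that the two models line up, the sketch below (an illustrative addition, assuming PyTorch is installed) counts trainable parameters and runs a dummy batch through the network. It also re-initializes the Linear layers to mimic the Keras Dense defaults (glorot_uniform kernel, zero bias) using nn.init.xavier_uniform_ and nn.init.zeros_. The helper name init_like_keras and the batch size of 32 are hypothetical choices, not part of the original answer.

```python
import torch
from torch import nn

# Same architecture as the answer: 784 -> 128 -> ReLU -> 10 (logits).
model_torch = nn.Sequential(
    nn.Flatten(),             # (N, 28, 28) -> (N, 784)
    nn.Linear(28 * 28, 128),
    nn.ReLU(),
    nn.Linear(128, 10),       # no activation: raw logits, like Dense(10)
)


def init_like_keras(module: nn.Module) -> None:
    """Hypothetical helper: give Linear layers the Keras Dense defaults
    (glorot_uniform kernel, zero bias)."""
    if isinstance(module, nn.Linear):
        # xavier_uniform_ with gain=1 samples U(-a, a), a = sqrt(6 / (fan_in + fan_out)),
        # which is the same distribution as Keras glorot_uniform.
        nn.init.xavier_uniform_(module.weight)
        nn.init.zeros_(module.bias)


# apply() visits every submodule; only the Linear layers are changed.
model_torch.apply(init_like_keras)

# Expected: 784*128 + 128 + 128*10 + 10 = 101770, the same count as the Keras model.
num_params = sum(p.numel() for p in model_torch.parameters() if p.requires_grad)
print(num_params)  # 101770

x = torch.randn(32, 28, 28)  # dummy batch of 32 "images"
logits = model_torch(x)
print(logits.shape)  # torch.Size([32, 10])
```

The initializer swap only matches the distribution of the starting weights, not their exact values. For identical weights you would copy them over from the Keras model instead, transposing each kernel, since Keras stores a Dense kernel as (in, out) while nn.Linear stores its weight as (out, in).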