I'm pretty new to working with Trax, a framework built by the Google Brain team to work with deep learning models as an alternative to TensorFlow. As a TensorFlow developer, I'm used to calling the model.summary() method (documented here) to display a full model summary, for example:
model.summary()
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         [(None, 16, 303)]         0
_________________________________________________________________
bidirectional (Bidirectional (None, 16, 256)           442368
_________________________________________________________________
time_distributed (TimeDistri (None, 16, 22)            5654
=================================================================
Total params: 448,022
Trainable params: 448,022
Non-trainable params: 0
Is there something equivalent in Trax?
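The Param # column in that summary is simply the number of weight and bias entries in each layer, which is also what any Trax replacement would have to compute. As a sketch, the figures above can be reproduced by hand under one assumption the question does not state: that the Bidirectional wrapper holds a Keras LSTM with 128 units per direction (this matches 442,368 exactly) and that TimeDistributed wraps a Dense(22) layer. The helper names dense_params and lstm_params are hypothetical, introduced only for this illustration.

```python
def dense_params(n_in, n_out):
    """Parameters of a fully connected layer: an n_in x n_out kernel plus n_out biases."""
    return n_in * n_out + n_out


def lstm_params(n_in, units):
    """Parameters of one Keras-style LSTM direction.

    Each of the 4 gates has an input kernel (n_in x units), a recurrent
    kernel (units x units) and a bias vector (units).
    """
    return 4 * (n_in * units + units * units + units)


# Assumption (not stated in the question): Bidirectional wraps an LSTM with
# 128 units per direction (2 x 128 = 256 output features), and
# TimeDistributed wraps Dense(22) applied at each of the 16 time steps.
bidirectional = 2 * lstm_params(303, 128)   # forward + backward LSTM
time_distributed = dense_params(256, 22)    # weights are shared across time steps
total = bidirectional + time_distributed

print(bidirectional, time_distributed, total)  # 442368 5654 448022
```

Note that TimeDistributed does not multiply the count by the number of time steps: the same Dense weights are reused at every step.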
Currently, there does not appear to be a method similar to .summary() in Trax; the closest equivalent is to print the model. Adapting the example from the documentation:
from trax import layers as tl

model = tl.Serial(
    tl.Embedding(vocab_size=8192, d_feature=256),
    tl.Mean(axis=1),  # Average on axis 1 (length of sentence).
    tl.Dense(2),      # Classify 2 classes.
)
print(model)  # Prints the nested layer structure.
Result:
Serial[
  Embedding_8192_256
  Mean
  Dense_2
]
Although nowhere near as detailed as TensorFlow's model.summary() (there are no output shapes or parameter counts), the printout still carries useful information: the embedding layer's constructor arguments (vocabulary size 8192 and feature dimension 256) are encoded in its name, and if you change the model's last layer to tl.Dense(3), the corresponding line of the output changes to Dense_3.