I want to print some objects inside the `call` function below using `print`, but nothing is printed even though the code runs successfully. I am reading (THIS) Keras debugging tutorial, but I still don't understand why it prints nothing.
```python
# Hyperparams
learning_rate = 0.001
weight_decay = 0.0001
batch_size = 100
num_epochs = 1
image_size = 72  # We'll resize input images to this size
patch_size = 6  # Size of the patches to be extracted from the input images
num_patches = (image_size // patch_size) ** 2
projection_dim = 64
num_heads = 4
transformer_units = [
    projection_dim * 2,
    projection_dim,
]  # Size of the transformer layers
transformer_layers = 8
mlp_head_units = [2048, 1024]
```
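As a sanity check on what the prints should report, the hyperparameters above already fix the shapes. A minimal sketch in plain Python, assuming 3-channel RGB input (not stated in the question); `channels`, `patches_per_side`, and `patch_dims` are hypothetical names added for illustration:

```python
# Hyperparameters copied from the question
image_size = 72
patch_size = 6
projection_dim = 64
batch_size = 100
channels = 3  # assumption: RGB input images

patches_per_side = image_size // patch_size      # 72 // 6 = 12
num_patches = patches_per_side ** 2               # 12 ** 2 = 144
patch_dims = patch_size * patch_size * channels   # 6 * 6 * 3 = 108 values per flattened patch

# Shapes that the prints in PatchEncoder.call should describe
positions_shape = (num_patches,)                          # one index per patch
encoded_shape = (batch_size, num_patches, projection_dim)  # one projected vector per patch

print("num_patches:", num_patches)
print("positions shape:", positions_shape)
print("encoded shape:", encoded_shape)
```

When the print runs while Keras is tracing `call`, `encoded.shape` may show `None` in place of the batch size, because it reports the static shape known at trace time.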
I want to print `positions` and `encoded` inside the `call` function below. I used `print`, but it is not working, whereas HERE they did it the same way.
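```python
import tensorflow as tf
from tensorflow.keras import layers


class PatchEncoder(layers.Layer):
    def __init__(self, num_patches, projection_dim):
        super().__init__()  # super must be called: super().__init__()
        self.num_patches = num_patches
        # Linear projection of each flattened patch to projection_dim
        self.projection = layers.Dense(units=projection_dim)
        # Learned position embedding: one vector per patch index
        self.position_embedding = layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )

    def call(self, patch):
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        # Add the position embedding to the projected patches (the "+" was missing)
        encoded = self.projection(patch) + self.position_embedding(positions)
        print("Encoded shape is:", encoded.shape)
        print("pos.shape is:", positions.shape)
        return encoded
```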
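A likely explanation, offered as an assumption since the training code isn't shown: when the layer runs inside `model.fit` (or any `tf.function`), `call` is traced into a graph, and a Python `print` executes only while tracing, not on every batch. `tf.print` becomes an op in the graph and runs each time the graph executes. The sketch below uses a hypothetical `PrintingPatchEncoder` (same logic as `PatchEncoder`) and a hypothetical wrapper `run`; the input width of 108 assumes 6*6 patches with 3 channels.

```python
import tensorflow as tf
from tensorflow.keras import layers


class PrintingPatchEncoder(layers.Layer):
    """Hypothetical variant of PatchEncoder that also logs shapes at run time."""

    def __init__(self, num_patches, projection_dim):
        super().__init__()
        self.num_patches = num_patches
        self.projection = layers.Dense(units=projection_dim)
        self.position_embedding = layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )

    def call(self, patch):
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        encoded = self.projection(patch) + self.position_embedding(positions)
        # Python print: inside a tf.function this runs only during tracing
        print("static encoded shape:", encoded.shape)
        # tf.print: part of the graph, so it runs on every execution
        tf.print("encoded shape:", tf.shape(encoded))
        tf.print("positions shape:", tf.shape(positions))
        return encoded


encoder = PrintingPatchEncoder(num_patches=144, projection_dim=64)
x = tf.zeros([2, 144, 108])  # batch of 2, 144 patches, 6*6*3 values each

encoder(x)  # eager call: creates the weights; print and tf.print both run


@tf.function
def run(inputs):
    return encoder(inputs)


out1 = run(x)  # first call traces `run`, so the Python print fires here
out2 = run(x)  # same input signature reuses the trace: only tf.print output appears
print(out2.shape)
```

For a full model, `model.compile(..., run_eagerly=True)` makes `call` execute eagerly, so a plain `print` runs on every batch, at the cost of slower training.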
Answer:
For the code shared in the comments, add this as the first code cell, before the imports:
%