How to plot Graph-neural-network model-graph when using tensorflow Model Subclass API with spektral


I am unable to plot the model graph of a graph neural network. I have seen a few related questions (1, 2, 3) on this topic, but their answers do not apply to graph neural networks.

What makes it different is that the inputs include objects of different dimensions: e.g., the node-features matrix has shape [n_nodes, n_node_features], the adjacency matrix has shape [n_nodes, n_nodes], etc. Here is an example of my model:

from tensorflow.keras.layers import Dense, Dropout
from tensorflow.keras.models import Model
from spektral.layers import GINConv, GlobalAvgPool

class GIN0(Model):
    def __init__(self, channels, n_layers):
        super().__init__()
        self.conv1 = GINConv(channels, epsilon=0, mlp_hidden=[channels, channels])
        self.convs = []
        for _ in range(1, n_layers):
            self.convs.append(
                GINConv(channels, epsilon=0, mlp_hidden=[channels, channels])
            )
        self.pool = GlobalAvgPool()
        self.dense1 = Dense(channels, activation="relu")
        self.dropout = Dropout(0.5)
        self.dense2 = Dense(channels, activation="relu")

    def call(self, inputs):
        x, a, i = inputs
        x = self.conv1([x, a])
        for conv in self.convs:
            x = conv([x, a])
        x = self.pool([x, i])
        x = self.dense1(x)
        x = self.dropout(x)
        return self.dense2(x)
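
For context, here is a minimal sketch (my assumption, mirroring the spektral GIN example this model appears to be based on; TUDataset/PROTEINS is an illustrative choice) of where the three inputs come from. A DisjointLoader yields batches whose shapes change from batch to batch, which is exactly what makes a fixed Input layer hard to define:

from spektral.data import DisjointLoader
from spektral.datasets import TUDataset

dataset = TUDataset("PROTEINS")                     # illustrative dataset
loader_tr = DisjointLoader(dataset, batch_size=32)  # yields ((x, a, i), y)

(x, a, i), y = next(iter(loader_tr))  # shapes differ from batch to batch
out = GIN0(channels=64, n_layers=3)([x, a, i])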

One of the answers in (2) suggested adding a build_graph method, as follows:

import tensorflow.keras.layers as L
from tensorflow.keras import Input, Model
from tensorflow.keras.applications import VGG16

class my_model(Model):
    def __init__(self, dim):
        super(my_model, self).__init__()
        self.dim   = dim  # keep the input shape so build_graph can use it
        self.Base  = VGG16(input_shape=dim, include_top=False, weights='imagenet')
        self.GAP   = L.GlobalAveragePooling2D()
        self.BAT   = L.BatchNormalization()
        self.DROP  = L.Dropout(rate=0.1)
        self.DENS  = L.Dense(256, activation='relu', name = 'dense_A')
        self.OUT   = L.Dense(1, activation='sigmoid')
    
    def call(self, inputs):
        x  = self.Base(inputs)
        g  = self.GAP(x)
        b  = self.BAT(g)
        d  = self.DROP(b)
        d  = self.DENS(d)
        return self.OUT(d)
    
    # AFAIK: the most convenient way to get model.summary(),
    # similar to the Sequential or functional API.
    def build_graph(self):
        x = Input(shape=self.dim)
        return Model(inputs=[x], outputs=self.call(x))

dim = (124,124,3)
model = my_model(dim)
model.build((None, *dim))
model.build_graph().summary()
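
The same wrapper can then be passed to Keras' plotting utility; a short usage sketch (the file name is my choice):

import tensorflow as tf

tf.keras.utils.plot_model(
    model.build_graph(), show_shapes=True, to_file="my_model.png"
)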

However, I am not sure how to define dim or the Input layer using tf.keras.layers.Input for such a hybrid data structure as described above.

Any suggestions?

CodePudding user response:

Here is the minimal code to plot such a subclassed multi-input model. Note, as stated in the comment above, there are some issues with your GINConv (which comes from spektral), but they are not related to the main question, so I will give a general solution for such multi-input modeling scenarios. To make it work with spektral, please reach out to the package author for further discussion.

From the spektral repo (here), I got the shapes of the input tensors:

# Take one batch from the loader: x is the input tuple, y the labels
x, y = next(iter(loader_tr))

bs_x = list(x[0].shape)  # node features
bs_y = list(x[1].shape)  # adjacency matrix
bs_z = list(x[2].shape)  # graph index vector

bs_x, bs_y, bs_z
# ([1067, 4], [1067, 1067], [1067])
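
As an aside, spektral's own functional-API examples for this disjoint mode define the Input layers without fixing the node count, roughly like the sketch below (my assumption, adapted from those examples; the names are illustrative):

import tensorflow as tf

n_node_features = 4  # from bs_x above

x_in = tf.keras.Input(shape=(n_node_features,), name="x")    # node features
a_in = tf.keras.Input(shape=(None,), sparse=True, name="a")  # adjacency
i_in = tf.keras.Input(shape=(), dtype=tf.int64, name="i")    # graph index per node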

Here is a similar model that takes the same number of inputs with the same shapes, but without GINConv:

import tensorflow as tf
from tensorflow.keras.layers import Dense, Dropout
from tensorflow.keras.models import Model

n_out = 2  # number of output classes; set to match your dataset

class GIN0(Model):
    def __init__(self, channels, n_layers):
        super().__init__()
        self.conv1 = tf.keras.layers.Conv1D(channels, 3, activation='relu')
        self.conv2 = tf.keras.layers.Conv1D(channels, 3, activation='relu')

        self.dense1 = Dense(channels, activation="relu")
        self.dropout = Dropout(0.5)
        self.dense2 = Dense(n_out, activation="softmax")

    def call(self, inputs):
        x, a, i = inputs

        x = self.conv1(x)
        x = tf.keras.layers.GlobalAveragePooling1D()(x)
        a = self.conv2(a)
        a = tf.keras.layers.GlobalAveragePooling1D()(a)

        x = tf.keras.layers.Concatenate(axis=1)([a, x, i])
        x = self.dense1(x)
        x = self.dropout(x)
        return self.dense2(x)
    
    def build_graph(self):
        x = tf.keras.Input(shape=bs_x)
        y = tf.keras.Input(shape=bs_y)
        z = tf.keras.Input(shape=bs_z)
        return tf.keras.Model(
            inputs=[x, y, z], 
            outputs=self.call([x, y, z])
        )

channels, layers = 32, 2  # example hyperparameters
model = GIN0(channels, layers)
model.build(
    [
        (None, *bs_x), 
        (None, *bs_y), 
        (None, *bs_z)
    ]
)

# OK: prints a layer-by-layer summary
model.build_graph().summary()

# OK: renders and plots the model graph
tf.keras.utils.plot_model(
    model.build_graph(), show_shapes=True
)