Problem with defining derived class in python

Time:06-05

I am learning how to use classes in Python to customize some Keras methods in order to build various forms of Generative Adversarial Networks (GANs). In this case, I am trying to implement the gradient penalty modification to the Wasserstein GAN architecture, based on an example from the Keras website: https://keras.io/examples/generative/wgan_gp/ Since I am new to classes and inheritance, I decided to experiment with some simple examples that mirror the Keras example. I am confused by this segment of code:

class WGAN(keras.Model):
    def __init__(
        self,
        discriminator,
        generator,
        latent_dim,
        discriminator_extra_steps=3,
        gp_weight=10.0,
    ):
        super(WGAN, self).__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim
        self.d_steps = discriminator_extra_steps
        self.gp_weight = gp_weight

    def compile(self, d_optimizer, g_optimizer, d_loss_fn, g_loss_fn):
        super(WGAN, self).compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.d_loss_fn = d_loss_fn
        self.g_loss_fn = g_loss_fn

I tried making my own simple class and a derived class following this pattern, similar to the example on the W3Schools website (https://www.w3schools.com/python/python_inheritance.asp):

class Person:
    def __init__(self, fname, lname):
        self.firstname = fname
        self.lastname = lname
    def printname(self):
        print(self.firstname, self.lastname)

class Student(Person):
    def __init__(self, age, height):
        super(Student, self).__init__()
        self.age = age
        self.height = height

I test it with:

s = Student(1,2)

I get the following error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
~\AppData\Local\Temp/ipykernel_2168/1913336311.py in <module>
----> 1 s = Student(1,2)

~\AppData\Local\Temp/ipykernel_2168/376196858.py in __init__(self, age, height)
      8 class Student(Person):
      9     def __init__(self, age, height):
---> 10         super(Student, self).__init__()
     11         self.age = age
     12         self.height = height

TypeError: __init__() missing 2 required positional arguments: 'fname' and 'lname'

How is it possible to call super(WGAN, self).__init__() with empty parentheses in the Keras code when the same call does not work in mine? I feel like I am taking the same approach. Thanks

Answer:

The reason is that keras.Model's __init__ does not require any arguments, whereas your Person class's __init__ requires two (fname and lname). That is why you can call the keras.Model constructor without any arguments, but you cannot call the Person constructor without supplying fname and lname.
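To see the difference without Keras, here is a minimal sketch. The class names are hypothetical: FlexibleBase stands in for keras.Model only in the one respect that matters here (its constructor has no required parameters), and StrictBase behaves like Person. A subclass of the first can call super().__init__() with no arguments, just as WGAN does; a subclass of the second cannot.

```python
class FlexibleBase:
    # Hypothetical stand-in for keras.Model: every parameter is optional.
    def __init__(self, name=None, **kwargs):
        self.name = name
        self.extra = kwargs


class StrictBase:
    # Behaves like Person: two required positional parameters.
    def __init__(self, fname, lname):
        self.firstname = fname
        self.lastname = lname


class FromFlexible(FlexibleBase):
    def __init__(self, age):
        super().__init__()  # works: the parent requires nothing
        self.age = age


class FromStrict(StrictBase):
    def __init__(self, age):
        super().__init__()  # raises TypeError: fname and lname are missing
        self.age = age


ok = FromFlexible(30)
print(ok.age, ok.name)  # 30 None

try:
    FromStrict(30)
except TypeError as exc:
    print("TypeError:", exc)
```

The exact wording of the TypeError message varies slightly between Python versions, but it names the same two missing arguments as the traceback in the question.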

Answer:

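super gives you access to the parent class (more precisely, the next class in the method resolution order, which for single inheritance is the parent). The parent class of Student is Person, and Person.__init__ requires 2 arguments, so calling __init__() with none fails. If you call the Person constructor, you have to provide both inputs.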
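One way to fix the Student example is to have Student accept the parent's arguments as well and pass them up to Person.__init__. This is a sketch of that approach; the parameter order (fname, lname, age, height) is just one reasonable choice, not the only one.

```python
class Person:
    def __init__(self, fname, lname):
        self.firstname = fname
        self.lastname = lname

    def printname(self):
        print(self.firstname, self.lastname)


class Student(Person):
    def __init__(self, fname, lname, age, height):
        # Forward the two arguments Person.__init__ requires to the parent.
        super().__init__(fname, lname)
        self.age = age
        self.height = height


s = Student("Mike", "Olsen", 1, 2)
s.printname()           # Mike Olsen
print(s.age, s.height)  # 1 2
```

In Python 3, super().__init__(...) with no arguments to super is equivalent to the super(Student, self).__init__(...) form used in the question.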

The parent class of WGAN is keras.Model, whose __init__ does not require any arguments, so calling __init__() with none works.
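If you are unsure whether a parent constructor can be called with no arguments, you can inspect its signature with the standard inspect module. The sketch below defines a hypothetical helper, required_init_params, that lists the parameters of cls.__init__ that have no default value. ModelLike is a hypothetical class whose constructor takes only *args and **kwargs; the real keras.Model.__init__ signature depends on your Keras version, so check it in your own environment rather than relying on this stand-in.

```python
import inspect


class Person:
    def __init__(self, fname, lname):
        self.firstname = fname
        self.lastname = lname


class ModelLike:
    # Hypothetical base whose constructor accepts anything but requires nothing.
    def __init__(self, *args, **kwargs):
        self.args = args
        self.kwargs = kwargs


def required_init_params(cls):
    """Return the names of cls.__init__ parameters that have no default value."""
    sig = inspect.signature(cls.__init__)
    named_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )
    return [
        name
        for name, param in sig.parameters.items()
        if name != "self"  # skip the instance parameter
        and param.kind in named_kinds  # *args / **kwargs are never required
        and param.default is inspect.Parameter.empty
    ]


print(required_init_params(Person))     # ['fname', 'lname']
print(required_init_params(ModelLike))  # []
```

An empty list means super().__init__() can be called with no arguments, which is the situation WGAN is in; a non-empty list, as for Person, means the subclass must pass those values along.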
