TensorFlow: concatenate an unknown number of inputs

Time:10-07

I want to use the same neural network for different problems, each with a different number of input parameters. For now I use this method:

import tensorflow as tf

# main function to be called
def call(self,x0,x1=None,x2=None,x3=None,x4=None,x5=None,x6=None,x7=None,x8=None,x9=None):
    # define input vector as time-space pairs
    # (use "is None": it is an identity check, unlike the overloaded == of tensors)
    if x1 is None:
        X = x0
    elif x2 is None:
        X = tf.concat([x0,x1],1)
    elif x3 is None:
        X = tf.concat([x0,x1,x2],1)
    elif x4 is None:
        X = tf.concat([x0,x1,x2,x3],1)
    elif x5 is None:
        X = tf.concat([x0,x1,x2,x3,x4],1)
    elif x6 is None:
        X = tf.concat([x0,x1,x2,x3,x4,x5],1)
    elif x7 is None:
        X = tf.concat([x0,x1,x2,x3,x4,x5,x6],1)
    elif x8 is None:
        X = tf.concat([x0,x1,x2,x3,x4,x5,x6,x7],1)
    elif x9 is None:
        X = tf.concat([x0,x1,x2,x3,x4,x5,x6,x7,x8],1)
    else:
        X = tf.concat([x0,x1,x2,x3,x4,x5,x6,x7,x8,x9],1)

It works, but is there a better (shorter/faster) way to do this?
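If existing call sites pass x1 ... x9 by keyword, one option is to keep the signature and collapse the chain by collecting the supplied inputs in a list. This is a sketch, not the asker's code: the helper name concat_given is hypothetical, and it behaves differently from the if/elif chain when a middle argument is None but a later one is set (the chain stops at the first None; this version skips it and keeps going).

```python
import tensorflow as tf


def concat_given(x0, x1=None, x2=None, x3=None, x4=None,
                 x5=None, x6=None, x7=None, x8=None, x9=None):
    """Hypothetical helper: concatenate every supplied input along axis 1."""
    # Keep the inputs in positional order and skip the None defaults.
    # "is not None" is an identity check, so it never triggers the
    # overloaded == operator of tensors or NumPy arrays.
    parts = [x for x in (x0, x1, x2, x3, x4, x5, x6, x7, x8, x9) if x is not None]
    # With a single input, tf.concat returns a tensor equal to x0.
    return tf.concat(parts, axis=1)


t = tf.ones((4, 1))       # e.g. a time column
xy = tf.zeros((4, 2))     # e.g. two space columns
print(concat_given(t).shape)        # (4, 1)
print(concat_given(t, xy).shape)    # (4, 3)
```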

Answer:

You can use Python's *args syntax, which collects any number of positional arguments into a tuple:

def f(*args):
    print([*args])

>>> f("test")
['test']
>>> f("foo", "bar", 42)
['foo', 'bar', 42]
>>> f()
[]

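The star also works in the other direction at the call site: a list whose length is only known at runtime can be unpacked into positional arguments, so a list of tensors could be passed as call(*tensors). A small pure-Python sketch; the function name describe is hypothetical.

```python
def describe(*args):
    # args is always a tuple, even when the function is called with no arguments.
    return type(args).__name__, len(args)


inputs = ["t", "x", "y"]          # length known only at runtime
print(describe(*inputs))          # ('tuple', 3)
print(describe())                 # ('tuple', 0)
```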
Applied to your method, that looks like this:

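import tensorflow as tf

def call(self, *args):
    # args is a tuple holding however many input tensors were passed in
    X = tf.concat(list(args), axis=1)
    # ... the rest of the network uses X as before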
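A related pattern, swapped in here rather than taken from the answer, is the common Keras convention of passing several inputs as one list (the built-in tf.keras.layers.Concatenate layer also takes a list). The sketch below assumes a subclassed tf.keras.Model; the class name ConcatNet, the units parameter and the Dense output layer are hypothetical additions for illustration.

```python
import tensorflow as tf


class ConcatNet(tf.keras.Model):
    """Hypothetical model: joins a variable number of (batch, k_i) inputs, then a Dense layer."""

    def __init__(self, units=1):
        super().__init__()
        self.dense = tf.keras.layers.Dense(units)

    def call(self, inputs):
        # inputs: a list or tuple of tensors sharing the batch size (axis 0).
        X = tf.concat(list(inputs), axis=1)
        return self.dense(X)


t = tf.ones((8, 1))
x = tf.zeros((8, 2))
y = tf.zeros((8, 3))
print(ConcatNet()([t, x]).shape)      # (8, 1)
print(ConcatNet()([t, x, y]).shape)   # (8, 1)
```

The Dense layer builds its kernel for the input width it sees on the first call, so a single instance should be used with a fixed number of inputs, i.e. one instance per problem.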