I am following an example from https://github.com/google/lightweight_mmm, but instead of the default scaler setting in the example, which divides by the mean:
media_scaler = preprocessing.CustomScaler(divide_operation=jnp.mean)
I need to use this lambda function instead:
lambda x: jnp.mean(x[x > 0])
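For a column that contains zeros, such as hypothetical weeks with no spend in a channel, the two operations give different scale factors. A quick check on a made-up array, using plain jax.numpy only:

```python
import jax.numpy as jnp

# Made-up media column: two periods with zero spend, two with spend.
x = jnp.array([0.0, 2.0, 0.0, 4.0])

# The lambda from the question: mean over the strictly positive entries only.
positive_mean = lambda x: jnp.mean(x[x > 0])

print(jnp.mean(x))       # mean over all four entries: 1.5
print(positive_mean(x))  # mean over the two nonzero entries: 3.0
```

Dividing by 3.0 instead of 1.5 keeps the zeros from dragging the scale factor down.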
How can this be done? I have tried a couple of things, but since I am a complete beginner, I feel lost.
Here is what I have tried:
lambda x: jnp.mean(x[x > 0])
media_scaler = preprocessing.CustomScaler(divide_operation=x)
and
lambda x: jnp.mean(x[x > 0])
media_scaler = preprocessing.CustomScaler(divide_operation=lambda)
Neither of these works.
Answer:
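Assign the lambda to a name and pass that name as divide_operation. In your first attempt the lambda is created and immediately discarded, so x is not defined when you pass it (NameError); in the second, lambda is a keyword rather than a value, so Python raises a SyntaxError. Passing the lambda inline, divide_operation=lambda x: jnp.mean(x[x > 0]), works as well. This should do it: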
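import jax.numpy as jnp
from lightweight_mmm import preprocessing
div = lambda x: jnp.mean(x[x > 0])  # mean of the strictly positive entries only
media_scaler = preprocessing.CustomScaler(divide_operation=div)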
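One hedged caveat: x[x > 0] returns an array whose size depends on the data, and JAX cannot trace that kind of boolean indexing under transformations such as jax.jit (it raises NonConcreteBooleanIndexError). I have not verified whether CustomScaler traces divide_operation internally, but if you run into that error, a masked version computes the same value without boolean indexing. The name positive_mean and the sample data below are illustrative assumptions, not part of lightweight_mmm:

```python
import jax
import jax.numpy as jnp


def positive_mean(x):
    """Mean of the strictly positive entries of x, without boolean indexing."""
    mask = x > 0
    # Zero out non-positive entries, then divide by the number of positives.
    # If x has no positive entries this is 0 / 0, i.e. nan, like the lambda.
    return jnp.sum(jnp.where(mask, x, 0.0)) / jnp.sum(mask)


# Made-up media data: rows are time periods, columns are channels.
media = jnp.array([[0.0, 1.0],
                   [2.0, 0.0],
                   [4.0, 3.0]])

# Per-channel (column-wise) scale factors; the function also works under jit.
scales = jnp.apply_along_axis(positive_mean, 0, media)
print(scales)                               # [3. 2.]
print(jax.jit(positive_mean)(media[:, 0]))  # 3.0
```

It is passed to the scaler the same way as div: preprocessing.CustomScaler(divide_operation=positive_mean).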