Get the neural network weights out of a TensorFlow `Graph`

Time: 10-23

I'm using RLlib to train a reinforcement learning policy (PPO algorithm). I want to see the weights in the neural network underlying the policy.

After digging through RLlib's PPO object, I found the TensorFlow `Graph` object, and I thought I would find the weights of the neural network there. But I can't. The graph has ~1,000 nodes, and I can't for the life of me find where TensorFlow is hiding the actual weights. I was told to keep an eye out for `tf.Variable` objects, but I couldn't find any when I looked through the nodes. The closest thing I found were nodes of type `ReadVariableOp`, but there was no `tf.Variable` in them either. I did find a `tf.Tensor` in there, but I'm not sure whether it holds actual numbers, and if so, how to get them.
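Some background that may explain what you're seeing, as a minimal sketch assuming TensorFlow 2.x and its `tf.compat.v1` graph-mode API (as far as I know, this is the mode RLlib uses when `framework="tf"`, but check your setup). A `tf.Graph` stores only the definitions of operations: a resource variable shows up as ops such as `VarHandleOp` and `ReadVariableOp`, and the `tf.Tensor` they output is a symbolic handle with no numbers attached. The Python `tf.Variable` objects sit in the graph's variable collection, and their values only exist inside a `Session`. The toy layer below (the scope and variable names are hypothetical) shows both pieces:

```python
import tensorflow as tf

tf1 = tf.compat.v1

# Build a toy TF1-style graph with one dense layer. This is a hypothetical
# stand-in for the policy network, not what RLlib actually builds.
graph = tf.Graph()
with graph.as_default():
    with tf1.variable_scope("fc"):
        kernel = tf1.get_variable(
            "kernel", shape=[4, 2], initializer=tf1.ones_initializer())
        bias = tf1.get_variable(
            "bias", shape=[2], initializer=tf1.zeros_initializer())
    init_op = tf1.global_variables_initializer()

# The Graph only holds op definitions; the tf.Variable objects are
# registered in the GLOBAL_VARIABLES collection of that graph.
variables = graph.get_collection(tf1.GraphKeys.GLOBAL_VARIABLES)

# Concrete numbers exist only inside a Session once the variables have been
# initialized (or restored); sess.run returns each one as a NumPy array.
with tf1.Session(graph=graph) as sess:
    sess.run(init_op)
    values = {v.name: sess.run(v) for v in variables}

for name, value in values.items():
    print(name, value.shape)
```

With a real trained policy you would need the session the policy was built in rather than a fresh one, since a new session would re-initialize the variables and give you untrained values. The RLlib calls in the answer avoid dealing with any of this directly.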

Where do I find the weights of my neural network?

Answer:

In a single-agent setup, do this:

```python
# `algo` is the trained RLlib algorithm object (e.g. your PPO instance)
weights = algo.get_policy().get_state()["weights"]
```
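The result should be a dict mapping variable names to arrays (that is what I'd expect from RLlib's TF and Torch policies, but verify it against your Ray version). A minimal sketch for getting an overview of that dict, using a hypothetical `summarize_weights` helper and made-up layer names in place of a real policy:

```python
import numpy as np


def summarize_weights(weights):
    """Return (rows, total) for a dict mapping variable names to arrays.

    rows is a list of (name, shape) tuples sorted by name; total is the
    overall number of scalar parameters across all arrays.
    """
    rows = []
    total = 0
    for name in sorted(weights):
        arr = np.asarray(weights[name])  # accept lists or arrays alike
        rows.append((name, arr.shape))
        total += arr.size
    return rows, total


# Hypothetical stand-in for algo.get_policy().get_state()["weights"].
fake_weights = {
    "fc_1/kernel": np.zeros((4, 64)),
    "fc_1/bias": np.zeros(64),
    "fc_out/kernel": np.zeros((64, 2)),
    "fc_out/bias": np.zeros(2),
}

rows, total = summarize_weights(fake_weights)
for name, shape in rows:
    print(f"{name:20s} {shape}")
print("total parameters:", total)
```

With a trained algorithm you would pass the real dict instead: `summarize_weights(algo.get_policy().get_state()["weights"])`.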

In a multi-agent setup, you'll need to specify the policy name:

```python
# policy_name is the ID the policy was registered under in your multi-agent config
weights = algo.get_policy(policy_name).get_state()["weights"]
```
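With several policies you can repeat that call once per policy ID. A minimal sketch, where `collect_weights` and the stub policy class are hypothetical and only stand in for RLlib objects so the snippet runs on its own:

```python
import numpy as np


def collect_weights(get_policy, policy_ids):
    """Map each policy ID to its weights dict.

    get_policy is any callable that returns a policy object for an ID whose
    get_state() result has a "weights" entry (e.g. algo.get_policy).
    """
    return {pid: get_policy(pid).get_state()["weights"] for pid in policy_ids}


class _StubPolicy:
    # Hypothetical stand-in for an RLlib policy, used only for this demo.
    def __init__(self, n_out):
        self._weights = {"out/kernel": np.ones((3, n_out))}

    def get_state(self):
        return {"weights": self._weights}


_stubs = {"attacker": _StubPolicy(2), "defender": _StubPolicy(4)}
all_weights = collect_weights(_stubs.__getitem__, ["attacker", "defender"])
for pid, w in all_weights.items():
    print(pid, {name: arr.shape for name, arr in w.items()})
```

Against a real algorithm the call would look like `collect_weights(algo.get_policy, ["policy_a", "policy_b"])`, with your own policy IDs in place of those placeholders. Passing the getter in as a callable is only there to keep the sketch independent of RLlib.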