I'm using RLlib to train a reinforcement learning policy (PPO algorithm). I want to see the weights in the neural network underlying the policy.
After digging through RLlib's `PPO` object, I found the TensorFlow `Graph` object. I thought I would find the weights of the neural network there, but I can't. The graph has ~1,000 nodes, but I can't for the life of me find where TensorFlow is hiding the actual weights of the neural network. I looked through the nodes. I was told to keep an eye out for `tf.Variable` objects, but I couldn't find any. The closest things I could find are nodes of type `ReadVariableOp`, but I couldn't find a `tf.Variable` in them. I did find a `tf.Tensor` in there, but I'm not sure whether it holds actual numbers, and if so, how to get them.
Where do I find the weights of my neural network?
Answer:
In a single-agent setup, where `algo` is your PPO algorithm object, do this:
```python
weights = algo.get_policy().get_state()["weights"]
```
In a multi-agent setup, you'll need to specify the policy name:
```python
weights = algo.get_policy(policy_name).get_state()["weights"]
```
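To actually look at the numbers, it helps to know that the `"weights"` entry is (as far as I know, for both TF and Torch policies) a dict mapping variable names to NumPy arrays, so there is no need to go through the TF graph at all. Below is a minimal sketch, assuming that dict shape, that lists each variable's name, shape and parameter count. The helper names `summarize_weights` and `total_parameters` and the variable names in the demo dict are hypothetical stand-ins, not part of RLlib.

```python
import numpy as np


def summarize_weights(weights):
    """Return (name, shape, parameter count) for each entry of a weights dict.

    Assumes `weights` maps variable names to array-likes, e.g. the dict from
    policy.get_state()["weights"].
    """
    rows = []
    for name, value in weights.items():
        arr = np.asarray(value)  # works for NumPy arrays and plain lists
        rows.append((name, arr.shape, int(arr.size)))
    return rows


def total_parameters(weights):
    """Sum the parameter counts of every variable in the weights dict."""
    return sum(count for _, _, count in summarize_weights(weights))


if __name__ == "__main__":
    # Hypothetical stand-in for a real policy's weights dict.
    fake_weights = {
        "fc_1/kernel": np.zeros((4, 256)),
        "fc_1/bias": np.zeros(256),
    }
    for name, shape, count in summarize_weights(fake_weights):
        print(f"{name}: shape={shape}, params={count}")
    print("total:", total_parameters(fake_weights))
```

With a real algorithm, you would pass `algo.get_policy().get_state()["weights"]` to `summarize_weights`, and then index the dict by name (e.g. `weights[name]`) to get the raw array for any layer.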