saving weights of a tensorflow model in Databricks

Time: 01-07

In a Databricks notebook running on Cluster1, when I do

path='dbfs:/Shared/P1-Prediction/Weights_folder/Weights'
model.save_weights(path)

and then immediately try

ls 'dbfs:/Shared/P1-Prediction/Weights_folder'

I see the actual weights file in the output display

But when I run the exact same command, ls 'dbfs:/Shared/P1-Prediction/Weights_folder', in a different Databricks notebook running on Cluster2, I get the error

ls: cannot access 'dbfs:/Shared/P1-Prediction/Weights_folder': No such file or directory

I am not able to interpret this. Does that mean "save_weights" is saving the weights in the cluster's memory and not in an actual physical location? If so, is there a solution for it? Any help is highly appreciated.

CodePudding user response:

TensorFlow uses Python's local file API, which doesn't understand the dbfs:/... scheme. You need to change the path to the /dbfs/... FUSE mount instead of dbfs:/..., so the weights land on DBFS (visible from any cluster) rather than on the driver's local disk.
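A minimal sketch of the fix: translate the dbfs:/ URI into its /dbfs FUSE equivalent before handing it to TensorFlow. The helper name to_fuse_path is made up for illustration; only the path rewrite itself matters.

```python
def to_fuse_path(dbfs_uri: str) -> str:
    """Translate a dbfs:/ URI into the /dbfs FUSE mount path that
    Python's local file API (and hence TensorFlow) can work with."""
    prefix = "dbfs:/"
    if dbfs_uri.startswith(prefix):
        return "/dbfs/" + dbfs_uri[len(prefix):]
    return dbfs_uri  # already a local-style path; leave it alone

path = to_fuse_path("dbfs:/Shared/P1-Prediction/Weights_folder/Weights")
# path is now "/dbfs/Shared/P1-Prediction/Weights_folder/Weights"
# model.save_weights(path)  # writes to DBFS, so other clusters can see it
```

With the /dbfs/... path, a plain ls /dbfs/Shared/P1-Prediction/Weights_folder from any cluster should show the weights files.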

But really, it may be better to log the model with MLflow; in that case you can easily load it back for inference on any cluster. See the documentation and maybe this example.
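A hedged sketch of the MLflow route, assuming an MLflow 2.x `mlflow.tensorflow` flavor and that `model` is the Keras model from the question (this requires a cluster with mlflow and tensorflow installed):

```python
import mlflow

# Log the whole model (architecture + weights) as an MLflow artifact.
with mlflow.start_run() as run:
    mlflow.tensorflow.log_model(model, artifact_path="model")

# Later, from any cluster that can reach the tracking server:
# loaded = mlflow.tensorflow.load_model(f"runs:/{run.info.run_id}/model")
# predictions = loaded.predict(some_input)
```

Because MLflow stores the artifact in the tracking server's artifact store rather than on a cluster's local disk, the cross-cluster visibility problem goes away entirely.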
