I'm developing a project in PyTorch that requires returning the weights (state_dict) of a PyTorch model as a Flask endpoint response. To explain it better, the simplest code would be:
@app.route('/send_weights', methods=['GET', 'POST'])
def send_weights():
    model_weights = model.state_dict()  # a dict[str, torch.Tensor]
    return model_weights  # Flask tries to serialize the returned dict as JSON and fails
However, it is not that simple, because torch.Tensor is not JSON serializable. So I tried converting the tensors to lists (which are JSON serializable), and that works:
@app.route('/send_weights', methods=['GET', 'POST'])
def send_weights():
    model_weights = model.state_dict()  # a dict[str, torch.Tensor]
    # Turn every tensor into nested Python lists so the dict is JSON serializable
    model_weights = {k: v.tolist() for k, v in model_weights.items()}
    return model_weights
However, this process is very slow and doesn't meet my requirements. I also tried converting the tensors to bytes, but that runs into the same problem: bytes is not JSON serializable either. So I suspect a JSON response won't be the solution. I'm not a Flask expert, but I've read about Flask's send_file method; however, I'm not sure how to use it in this case (or even whether it is a viable solution), since I don't have a mimetype for the dictionary.
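Both dead ends can be reproduced without PyTorch. The stdlib-only sketch below uses array.array('f') as a hypothetical stand-in for a float32 tensor's storage (an assumption, not the real state_dict): the tolist() route writes every 4-byte float as a long decimal string, which inflates the payload several times over, and json.dumps rejects raw bytes outright.

```python
import array
import json
import random

random.seed(0)
# Hypothetical stand-in for a float32 tensor's storage: 10,000 random values.
values = array.array("f", (random.random() for _ in range(10_000)))

raw = values.tobytes()                 # packed binary: 4 bytes per float32
as_json = json.dumps(values.tolist())  # text: one decimal string per value

print(len(raw), len(as_json))

# Raw bytes cannot be placed in a JSON document at all.
try:
    json.dumps({"weight": raw})
    error_message = None
except TypeError as exc:
    error_message = str(exc)
print(error_message)
```

The size gap is only part of the cost; the per-element conversion to Python floats and then to text also takes time, which a binary format avoids.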
Does anybody know a better way to do this?
I've just found a solution. It returns the weights as a file: the state_dict is serialized into an in-memory binary buffer with torch.save and sent with Flask's send_file method using the mimetype 'application/octet-stream', as appears in this question. The final endpoint code is:
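import io

import torch
from flask import send_file

@app.route('/send_weights', methods=['GET', 'POST'])
def send_weights():
    model_weights = model.state_dict()
    to_send = io.BytesIO()  # in-memory binary buffer instead of a file on disk
    torch.save(model_weights, to_send, _use_new_zipfile_serialization=False)
    to_send.seek(0)  # rewind so send_file streams the buffer from the start
    return send_file(to_send, mimetype='application/octet-stream')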
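To check that the binary route is worth it, this sketch serializes the same state_dict both ways and compares the sizes. It assumes a hypothetical nn.Linear(256, 256) in place of the real model, so the exact numbers will differ for other models; the point is that torch.save stores each float32 in 4 bytes, while the tolist() JSON spells each one out in decimal.

```python
import io
import json

import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical stand-in for the real model: a single 256x256 linear layer.
model = nn.Linear(256, 256)
state = model.state_dict()

# Binary route: what the send_file endpoint streams to the client.
buffer = io.BytesIO()
torch.save(state, buffer, _use_new_zipfile_serialization=False)
binary_size = len(buffer.getvalue())

# JSON route: the tolist() approach from the question.
json_size = len(json.dumps({k: v.tolist() for k, v in state.items()}))

print(binary_size, json_size)
```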
And to load it on the other side:
weights = torch.load(io.BytesIO(response.content))
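The two halves can be exercised together in one process. This is a minimal sketch, assuming hypothetical nn.Linear(4, 2) models for the server and client copies, and using Flask's built-in test client in place of an HTTP library such as requests. It leaves torch.save's serialization format at its default and finishes by loading the received weights into the client model with load_state_dict.

```python
import io

import torch
import torch.nn as nn
from flask import Flask, send_file

app = Flask(__name__)
# Hypothetical models: the server's copy and a freshly initialized client copy.
server_model = nn.Linear(4, 2)
client_model = nn.Linear(4, 2)

@app.route("/send_weights", methods=["GET", "POST"])
def send_weights():
    to_send = io.BytesIO()
    torch.save(server_model.state_dict(), to_send)
    to_send.seek(0)
    return send_file(to_send, mimetype="application/octet-stream")

# The test client's response exposes the body via get_data()
# (with requests it would be response.content instead).
response = app.test_client().get("/send_weights")
weights = torch.load(io.BytesIO(response.get_data()))
client_model.load_state_dict(weights)

synced = all(
    torch.equal(client_model.state_dict()[k], v)
    for k, v in server_model.state_dict().items()
)
print(response.status_code, response.mimetype, synced)
```

Against a real server, only the source of the bytes changes: the client reads response.content from the HTTP response and passes it through the same torch.load and load_state_dict calls.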
Hope it helps somebody.