I have a FastAPI app that runs predictions and returns the output as a response. I have also implemented input checking: if the user supplies unsupported input, it should return "invalid SMILES".
The problem is that the response dictionary is not replaced; keys from the previous prediction are still in it.
When I run a prediction, I get this response:
{"result":{"interaction_map":[[15.0,5.0,14.0,15.0,15.0],[19.0,7.0,20.0,19.0,19.0],[13.0,6.0,18.0,13.0,13.0],[15.0,5.0,14.0,15.0,15.0],[15.0,5.0,14.0,15.0,15.0]],"predictions":-3.405024290084839}}
But when I give invalid input, I get this response:
{"result":{"interaction_map":[[15.0,5.0,14.0,15.0,15.0],[19.0,7.0,20.0,19.0,19.0],[13.0,6.0,18.0,13.0,13.0],[15.0,5.0,14.0,15.0,15.0],[15.0,5.0,14.0,15.0,15.0]],"predictions":"invalid SMILES"}}
But I expect this response:
{"predictions":"invalid SMILES"}
This is the code I am using:
response = {}

async def predictions(solute, solvent):
    m = Chem.MolFromSmiles(solute,sanitize=False)
    n = Chem.MolFromSmiles(solvent,sanitize=False)
    if (m == None or n == None):
        response['predictions']= 'invalid SMILES'
        print('invalid SMILES')
    else:
        mol = Chem.MolFromSmiles(solute)
        mol = Chem.AddHs(mol)
        solute = Chem.MolToSmiles(mol)
        solute_graph = get_graph_from_smile(solute)
        mol = Chem.MolFromSmiles(solvent)
        mol = Chem.AddHs(mol)
        solvent = Chem.MolToSmiles(mol)
        solvent_graph = get_graph_from_smile(solvent)
        delta_g, interaction_map = model([solute_graph.to(device), solvent_graph.to(device)])
        interaction_map_one = torch.trunc(interaction_map)
        response["interaction_map"] = (interaction_map_one.detach().numpy()).tolist()
        response["predictions"] = delta_g.item()

@app.get('/predict_solubility')
async def post():
    return {'result': response}

@app.get('/predict')
async def predict(background_tasks: BackgroundTasks,solute,solvent):
    background_tasks.add_task(predictions,solute,solvent)
    return {'success'}
Answer:
The problem is that response is a global variable, so the keys written into it by the first request are still there on the second request. An invalid input only overwrites 'predictions'; the old 'interaction_map' is never removed.
One quick fix would be to clear the response dict at the start of predictions, the background task queued by /predict:
async def predictions(solute, solvent):
    response.clear()
    # ... rest of the function unchanged
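To see why the stale key survives, here is a minimal sketch with no FastAPI, RDKit or model involved. `fake_predictions` and its `valid` flag are hypothetical stand-ins for the real function and the SMILES check, and the numbers are placeholders:

```python
response = {}  # shared module-level dict, as in the question


def fake_predictions(valid, clear_first=False):
    """Hypothetical stand-in for predictions(); `valid` replaces the SMILES check."""
    if clear_first:
        response.clear()  # drop keys left over from an earlier call
    if not valid:
        response["predictions"] = "invalid SMILES"
    else:
        response["interaction_map"] = [[15.0, 5.0]]  # placeholder values
        response["predictions"] = -3.4


# Without clearing: a valid call followed by an invalid one.
fake_predictions(valid=True)
fake_predictions(valid=False)
stale = dict(response)  # still carries "interaction_map" from the first call

# With clearing: the same sequence, but the invalid call clears first.
fake_predictions(valid=True)
fake_predictions(valid=False, clear_first=True)
cleared = dict(response)  # only the "predictions" key remains
```

Here `stale` still contains the old interaction map, while `cleared` holds only the error message.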
But in general it seems problematic to me that a single request to /predict sets global state that might get overwritten by another request, rather than returning a job id that you can use to check on that specific job/task.
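A rough sketch of the job id idea, using only the standard library so it can be run on its own. All names here (`jobs`, `submit_job`, `run_job`, `get_job`, `fake_predict`) are hypothetical and not part of FastAPI or of the code in the question. The in-memory dict is an assumption that only holds for a single process:

```python
import uuid

jobs = {}  # job_id -> {"status": ..., "result": ...}; in-memory, single process only


def submit_job():
    """Register a new job and return its id (what /predict could return)."""
    job_id = uuid.uuid4().hex
    jobs[job_id] = {"status": "pending", "result": None}
    return job_id


def run_job(job_id, solute, solvent, predict_fn):
    """Background task: store a fresh result under this job id only."""
    try:
        result = predict_fn(solute, solvent)
        jobs[job_id] = {"status": "done", "result": result}
    except ValueError as exc:
        jobs[job_id] = {"status": "failed", "result": {"predictions": str(exc)}}


def get_job(job_id):
    """What a status endpoint could return for a given job id."""
    return jobs.get(job_id, {"status": "unknown", "result": None})


def fake_predict(solute, solvent):
    """Hypothetical stand-in for the real model call; empty input counts as invalid."""
    if not solute or not solvent:
        raise ValueError("invalid SMILES")
    return {"interaction_map": [[15.0, 5.0]], "predictions": -3.4}


good = submit_job()
bad = submit_job()
run_job(good, "CCO", "O", fake_predict)
run_job(bad, "", "O", fake_predict)
```

In the FastAPI app, /predict would presumably call `submit_job()`, queue `run_job` with `background_tasks.add_task(...)`, and return the job id, while a second endpoint would take the job id and return `get_job(job_id)`. Because every job writes its own dict, an invalid request can no longer pick up the interaction map of an earlier one.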