What's the best way to load a ML model with Django

Time:06-22

I'm deploying a Machine Learning model (Named Entity Recognition) with Django. In short, the user chooses a field (Politics or Science, for example) and types text into a search box. The model then identifies the named entities in the text.

My problem is that the ML model (encoder) is loaded every time the view is triggered, which slows down each request. How can I optimize this so that it is loaded only once?

My views.py:

from django.shortcuts import render

def search_view(request):
    context = {}
    if request.method == 'POST':
        field = request.POST['field']
        query = request.POST['query']
        encoder = load_encoder(field)  # reloads the tokenizer from disk on every request
        results = Ner_model(query, encoder)
        context['result'] = results
    return render(request, 'ner/results.html', context)

The load_encoder function:

import os
from transformers import AutoTokenizer

def load_encoder(field):
    path_encoder = os.path.join(field, 'field_encoder')
    encoder = AutoTokenizer.from_pretrained(path_encoder)
    return encoder

Thanks!

Answer:

It's better to load all possible encoders and models once, up front, and then use the appropriate one for each request. For example, if you have two fields:

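import os
from transformers import AutoTokenizer

# Runs once, when the module is first imported, not on every request.
fields = ['field1', 'field2']  # placeholders: use your actual field names
encoders = {}
for field in fields:
    path_encoder = os.path.join(field, 'field_encoder')
    encoders[field] = AutoTokenizer.from_pretrained(path_encoder)

def load_encoder(field):
    # Look up the already-loaded encoder instead of reading it from disk again
    return encoders.get(field)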
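A lazy variant of the same idea, offered as a sketch rather than a drop-in replacement: decorating `load_encoder` with `functools.lru_cache` loads each field's encoder the first time it is requested and returns the cached object on every later call in the same process, so no list of fields is needed in advance. It assumes the `<field>/field_encoder` directory layout from the question; `_read_tokenizer` is a hypothetical helper name, and `transformers` is imported inside it only so the caching logic can be exercised without the library installed.

```python
import functools
import os


def _read_tokenizer(path_encoder):
    # Hypothetical helper: performs the actual disk read, as in the question.
    # Imported here so this module loads even where transformers is absent.
    from transformers import AutoTokenizer
    return AutoTokenizer.from_pretrained(path_encoder)


@functools.lru_cache(maxsize=None)
def load_encoder(field):
    # The first call for a given field reads from disk; later calls with the
    # same field return the cached tokenizer without running this body again.
    path_encoder = os.path.join(field, 'field_encoder')
    return _read_tokenizer(path_encoder)
```

The view can keep calling `load_encoder(field)` unchanged. The cache lives per process, so each worker process (for example under Gunicorn) holds its own copy, and `lru_cache` does not stop two threads from loading the same field concurrently on the very first request for it.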

Answer:

Inspired by @Masoud's suggestion, I defined two global variables: one to store the loaded encoder and one to keep track of the chosen field, even though global variables are generally discouraged.

That gives something like this:

encoder = None
old_field = ''

def search_view(request):
    global encoder
    global old_field
    context = {}

    if request.method == 'POST':
        field = request.POST['field']
        query = request.POST['query']
        # Reload only if nothing is cached yet or the user switched fields
        if encoder is None or old_field != field:
            encoder = load_encoder(field)
            old_field = field  # remember which field the cached encoder belongs to
        results = Ner_model(query, encoder)
        context['result'] = results

    return render(request, 'ner/results.html', context)
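This approach still reloads whenever users alternate between fields, and the two `global` statements are exactly what the answer is wary of. A minimal alternative sketch, using a hypothetical `EncoderCache` class: it keeps one encoder per field, receives the loading function as a parameter (for instance `AutoTokenizer.from_pretrained`, assuming the `<field>/field_encoder` layout from the question), and uses a lock so that concurrent requests for the same field on a threaded server trigger only one load.

```python
import os
import threading


class EncoderCache:
    """Hypothetical helper: keeps one loaded encoder per field for the process lifetime."""

    def __init__(self, loader):
        # loader: any callable taking a path and returning an encoder,
        # e.g. AutoTokenizer.from_pretrained in the question's setup (assumption)
        self._loader = loader
        self._encoders = {}
        self._lock = threading.Lock()

    def get(self, field):
        # Fast path: a plain dict read is safe here in CPython, so no lock is taken
        encoder = self._encoders.get(field)
        if encoder is not None:
            return encoder
        with self._lock:
            # Re-check inside the lock: another thread may have loaded it meanwhile
            encoder = self._encoders.get(field)
            if encoder is None:
                path_encoder = os.path.join(field, 'field_encoder')
                encoder = self._loader(path_encoder)
                self._encoders[field] = encoder
            return encoder
```

In views.py the cache would be created once at module level, e.g. `encoder_cache = EncoderCache(AutoTokenizer.from_pretrained)`, and the view would call `encoder_cache.get(field)` instead of `load_encoder(field)`. The state is still per process, so each worker process loads its own copy of every encoder it serves.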