I have an object of a custom class that I am trying to serialize and permanently store.
When I serialize it, store it, load it and use it in the same run, it works fine. It only messes up when I've ended the process and then try to load it again from the pickle file. This is the code that works fine:
first_model = NgramModel(3, name="debug")
for paragraph in text:
    first_model.train(paragraph_to_sentences(text))
    # paragraph_to_sentences just uses regex to do the equivalent of splitting by punctuation
print(first_model.context_options)
# context_options is a dict (counter)
first_model = NgramModel.load_existing_model("debug")
# load_existing_model loads the pickle file. Look in the class code
print(first_model.context_options)
However, when I run this alone, it prints an empty counter:
first_model = NgramModel.load_existing_model("debug")
print(first_model.context_options)
This is a shortened version of the class file (the only two methods that touch the pickle/dill are update_pickle_state and load_existing_model):
import os
import dill
from itertools import count
from collections import Counter
from os import path

class NgramModel:
    context_options: dict[tuple, set[str]] = {}
    ngram_count: Counter[tuple] = Counter()
    n = 0
    pickle_path: str = None
    num_paragraphs = 0
    num_sentences = 0

    def __init__(self, n: int, **kwargs):
        self.n = n
        self.pickle_path = NgramModel.pathify(kwargs.get('name', NgramModel.gen_pickle_name()))  # use name if exists else generate random name

    def train(self, paragraph_as_list: list[str]):
        '''really the central method that coordinates everything else. Takes a list of sentences, generates data(n-grams) from each, updates the fields, and saves the instance (self) to a pickle file'''
        self.num_paragraphs += 1
        for sentence in paragraph_as_list:
            self.num_sentences += 1
            generated = self.generate_Ngrams(sentence)
            self.ngram_count.update(generated)
            for ngram in generated:
                self.add_to_set(ngram)
        self.update_pickle_state()

    def update_pickle_state(self):
        '''saves instance to pickle file'''
        file = open(self.pickle_path, "wb")
        dill.dump(self, file)
        file.close()

    @staticmethod
    def load_existing_model(name: str):
        '''returns object from pickle file'''
        path = NgramModel.pathify(name)
        file = open(path, "rb")
        obj: NgramModel = dill.load(file)
        return obj

    def generate_Ngrams(self, string: str):
        '''ref: https://www.analyticsvidhya.com/blog/2021/09/what-are-n-grams-and-how-to-implement-them-in-python/'''
        words = string.split(" ")
        words = ["<start>"] * (self.n - 1) + words + ["<end>"] * (self.n - 1)
        list_of_tup = []
        for i in range(len(words) + 1 - self.n):
            list_of_tup.append((tuple(words[i + j] for j in range(self.n - 1)), words[i + self.n - 1]))
        return list_of_tup

    def add_to_set(self, ngram: tuple[tuple[str, ...], str]):
        if ngram[0] not in self.context_options:
            self.context_options[ngram[0]] = set()
        self.context_options[ngram[0]].add(ngram[1])

    @staticmethod
    def pathify(name):
        '''converts name to path'''
        return f"models/{name}.pickle"

    @staticmethod
    def gen_pickle_name():
        for i in count():
            new_name = f"unnamed-pickle-{i}"
            if not path.exists(NgramModel.pathify(new_name)):
                return new_name
All the other fields print properly and are complete and correct, except the two dicts.
CodePudding user response:
The problem is that context_options is a mutable class member, not an instance member (the same goes for ngram_count). If I had to guess, dill is only pickling instance members, since the class definition holds class members: the code mutates the class-level dict in place and never assigns self.context_options, so the dict never becomes part of the instance's own state. That would account for why you see a "filled-out" context_options when you're working in the same shell but not when you load fresh: you're using the dirtied class member in the former case.
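To make that guess concrete, here is a minimal sketch using a hypothetical stand-in class (Model and shared are made-up names, not part of the question's code). It does not call dill at all; it only inspects the instance's __dict__, which is what pickle saves by default for a plain instance, and I'm assuming dill behaves the same way when the class lives in an importable module.

```python
class Model:
    # Class attribute: a single dict shared by the class and all its instances.
    shared: dict = {}

    def __init__(self, n: int):
        # Instance attribute: stored in this instance's own __dict__.
        self.n = n


m = Model(3)
# Mutates the class-level dict in place; no instance attribute is created.
m.shared["key"] = "value"

# The per-instance state, i.e. what default pickling of the instance would store.
instance_state = vars(m)

print(instance_state)  # {'n': 3}
print(Model.shared)    # {'key': 'value'}
```

The mutated dict only shows up on the class, so a fresh process that re-imports the class starts again from the empty dict.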
It's for stuff like this that you generally don't want to use mutable class members (or similarly, mutable default values in function signatures). More typical is to use something like context_options: dict[tuple, set[str]] = None and then check if it's None in the __init__ to set it to a default value, e.g., an empty dict. Alternatively, you could use a @dataclass and provide a field initializer, i.e.
import dataclasses

@dataclasses.dataclass
class NgramModel:
    context_options: dict[tuple, set[str]] = dataclasses.field(default_factory=dict)
    ...
You can observe what I mean about it being a mutable class member with, for instance...
if __name__ == '__main__':
    ng = NgramModel(3, name="debug")
    print(ng.context_options)  # {}
    ng.context_options[("foo", "bar")] = {"baz", "qux"}
    print(ng.context_options)  # {('foo', 'bar'): {'baz', 'qux'}}
    ng2 = NgramModel(3, name="debug")
    print(ng2.context_options)  # {('foo', 'bar'): {'baz', 'qux'}}
I would expect a brand new ng2 to have the same context that the brand new ng had - empty (or whatever an appropriate default is).
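For completeness, the None-then-__init__ approach mentioned earlier might look roughly like this. It's a stripped-down sketch, not your full class: I've left out the pickle path handling and the training methods, and I'm assuming ngram_count should get the same treatment as context_options.

```python
from collections import Counter


class NgramModel:
    # Class-level defaults are immutable placeholders only.
    context_options = None
    ngram_count = None
    n = 0

    def __init__(self, n: int):
        self.n = n
        # Fresh containers per instance: they live in the instance's __dict__,
        # so they are part of the state that gets pickled with the object.
        if self.context_options is None:
            self.context_options = {}
        if self.ngram_count is None:
            self.ngram_count = Counter()


ng = NgramModel(3)
ng.context_options[("foo", "bar")] = {"baz", "qux"}

ng2 = NgramModel(3)
print(ng2.context_options)             # {}
print("context_options" in vars(ng))   # True
```

With this layout a second instance no longer sees the first one's data, and both containers should survive a dump/load across processes because they belong to the instance rather than the class.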