I'm a new student. I think the class code is correct. But the 'def update_vocab' is not applied. (Under bar is always appeared at update_vocab(q),update_vocab(a)) How can I fix this problem? Is the def update_vocab is wrong?
class sequence:
id_to_char = {}
char_to_id = {}
def update_vocab(txt):
chars = list(txt)
for i, char in enumerate(chars):
if char not in char_to_id:
tmp_id = len(char_to_id)
char_to_id[char] = tmp_id
id_to_char[tmp_id] = char
def load_data(file_name='addition.txt', seed=1984):
file_path = os.path.dirname(os.path.abspath('data.txt')) '/' file_name
if not os.path.exists(file_path):
print('No file: %s' % file_name)
return None
questions, answers = [], []
for line in open(file_path, 'r'):
idx = line.find('_')enter code here
questions.append(line[:idx])
answers.append(line[idx:-1])
for i in range(len(questions)):
q,a = questions[i], answers[i]
update_vocab(q)
update_vocab(a)
x = torch.zeros((len(questions), len(questions[0])), dtype=torch.int)
t = torch.zeros((len(questions), len(answers[0])), dtype=torch.int)
for i, sentence in enumerate(questions):
x[i] = [char_to_id[c] for c in list(sentence)]
for i, sentence in enumerate(answers):
t[i] = [char_to_id[c] for c in list(sentence)]
indices = torch.arange(len(x))
if seed is not None:
torch.random.seed(seed)
torch.random.shuffle(indices)
x = x[indices]
t = t[indices]
split_at = len(x) - len(x) // 10
(x_train, x_test) = x[:split_at], x[split_at:]
(t_train, t_test) = t[:split_at], t[split_at:]
return (x_train, t_train), (x_test, t_test)
def get_vocab():
return char_to_id, id_to_char'
CodePudding user response:
please add self, in param function. check
I try your code:
import os
import torch
class sequence:
def __init__(self):
self.id_to_char = {}
self.char_to_id = {}
def update_vocab(self, txt):
chars = list(txt)
for i, char in enumerate(chars):
if char not in self.char_to_id:
tmp_id = len(self.char_to_id)
self.char_to_id[char] = tmp_id
self.id_to_char[tmp_id] = char
def load_data(self, file_name='addition.txt', seed=1984):
file_path = os.path.dirname(os.path.abspath('data.txt')) '/' file_name
if not os.path.exists(file_path):
print('No file: %s' % file_name)
return None
questions, answers = [], []
for line in open(file_path, 'r'):
idx = line.find('_')
questions.append(line[:idx])
answers.append(line[idx:-1])
for i in range(len(questions)):
q,a = questions[i], answers[i]
update_vocab(q)
update_vocab(a)
x = torch.zeros((len(questions), len(questions[0])), dtype=torch.int)
t = torch.zeros((len(questions), len(answers[0])), dtype=torch.int)
for i, sentence in enumerate(questions):
x[i] = [self.char_to_id[c] for c in list(sentence)]
for i, sentence in enumerate(answers):
t[i] = [self.char_to_id[c] for c in list(sentence)]
indices = torch.arange(len(x))
if seed is not None:
torch.random.seed(seed)
torch.random.shuffle(indices)
x = x[indices]
t = t[indices]
split_at = len(x) - len(x) // 10
(x_train, x_test) = x[:split_at], x[split_at:]
(t_train, t_test) = t[:split_at], t[split_at:]
return (x_train, t_train), (x_test, t_test)
def get_vocab(self):
return self.char_to_id, self.id_to_char
seq = sequence()
seq.update_vocab('a')
print(seq.get_vocab())
CodePudding user response:
Wrong indentation: method load_data() always return None, try:
...
if not os.path.exists(file_path):
print('No file: %s' % file_name)
return None
...