Goal: instantiate unet_learner() using weights. weights is a str that I bring in from a user-defined .yaml file; hence eval(). file_path and training are classes that hold parameters.
Code:
import numpy as np
from fastai.vision.all import *

def train(dls, file_path, training):
    labels = np.loadtxt(file_path.labels, dtype=str)
    weights = torch.tensor(eval(training.weights))
    print('#################')
    print(weights)
    print(type(weights))
    print('#################')
    learner = unet_learner(dls, training.architecture, loss_func=CrossEntropyLossFlat(
        axis=1,
        weight=weights)
    )
    return learner.load(file_path.weights)
Wrapping weights in torch.tensor() again in the parameter line doesn't help; I get the same error.
Traceback:
(venv) me@ubuntu-pcs:~/PycharmProjects/project$ python pdl1_lung_train/main.py
/home/me/miniconda3/envs/venv/lib/python3.7/site-packages/torch/cuda/__init__.py:52: UserWarning: CUDA initialization: Found no NVIDIA driver on your system. Please check that you have an NVIDIA GPU and installed a driver from http://www.nvidia.com/Download/index.aspx (Triggered internally at /opt/conda/conda-bld/pytorch_1607370156314/work/c10/cuda/CUDAFunctions.cpp:100.)
  return torch._C._cuda_getDeviceCount() > 0
#################
tensor([0.4000, 0.9000])
<class 'torch.Tensor'>
#################
Traceback (most recent call last):
  File "pdl1_lung_train/main.py", line 27, in <module>
    main(ROOT)
  File "pdl1_lung_train/main.py", line 19, in main
    learner = train(dls, file_path, training)
  File "/home/me/PycharmProjects/project/pdl1_lung_train/train.py", line 16, in train
    weight=weights))
  File "/home/me/miniconda3/envs/venv/lib/python3.7/site-packages/fastai/vision/learner.py", line 267, in unet_learner
    model = create_unet_model(arch, n_out, img_size, pretrained=pretrained, **kwargs)
  File "/home/me/miniconda3/envs/venv/lib/python3.7/site-packages/fastai/vision/learner.py", line 243, in create_unet_model
    model = arch(pretrained)
TypeError: 'str' object is not callable
Please let me know if I should add any other information to the post.
CodePudding user response:
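I might be wrong, but I think your training.architecture is a string. According to the unet_learner documentation, the architecture argument has to be a callable. The traceback agrees: fastai calls arch(pretrained), and calling a str raises "'str' object is not callable". So pass the architecture function itself (for example resnet34) rather than its name as a string.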
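If the architecture name has to stay a string in the .yaml file, one option is to look it up in a small dictionary of allowed callables instead of passing the string through. The sketch below is only an illustration: resolve_architecture, ARCHITECTURES and fake_resnet34 are hypothetical names, and fake_resnet34 is a stand-in so the snippet runs without fastai installed. In your project the dictionary would presumably hold the real functions, such as resnet34 from fastai.vision.all.

```python
from typing import Callable, Dict


def resolve_architecture(name: str, registry: Dict[str, Callable]) -> Callable:
    """Return the callable registered under `name`, or raise a clear error."""
    try:
        return registry[name]
    except KeyError:
        known = ", ".join(sorted(registry))
        raise ValueError(
            f"Unknown architecture {name!r}; expected one of: {known}"
        ) from None


# Stand-in for a real architecture function (hypothetical, for illustration).
# With fastai you would register the real callables instead, e.g.
#   from fastai.vision.all import resnet18, resnet34
#   ARCHITECTURES = {"resnet18": resnet18, "resnet34": resnet34}
def fake_resnet34(pretrained=True):
    return f"resnet34(pretrained={pretrained})"


ARCHITECTURES = {"resnet34": fake_resnet34}

# The string as it would arrive from the .yaml file.
arch = resolve_architecture("resnet34", ARCHITECTURES)
```

In train() you would then pass resolve_architecture(training.architecture, ARCHITECTURES) as the second argument to unet_learner instead of training.architecture. A dictionary lookup also avoids calling eval() on user-supplied text, and a typo in the config gives a readable error rather than a failure deep inside fastai.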