I'm trying to use multiprocessing to process images from multiple folders. When I pass multiple arguments to the pool function, I get a TypeError. My code is below:
import cv2
import numpy as np
from imutils import paths  # assumed source of paths.list_images
from multiprocessing import Pool, cpu_count

def augment(args):
    imgs = args[0]
    bpath = args[1]
    dictcount = args[2]
    for img in imgs:
        cv2.imread(img)

def main():
    basepath = 'some path'
    count_dict = {some dict}
    allImagepaths = sorted(list(paths.list_images(basepath)))
    procs = no_cores if no_cores > 0 else cpu_count()
    procids = list(range(0, procs))
    noImgsproc = len(allImagepaths) / float(procs)
    noImgsproc = int(np.ceil(noImgsproc))
    chunkpaths = list(chunk(allImagepaths, noImgsproc))
    payloads = []
    for (i, imgpaths) in enumerate(chunkpaths):
        data = {"id": i, "input_paths": imgpaths}
        payloads.append(data)
    pool = Pool(processes=procs)
    pool.map(augment, [[payloads, basepath, count_dict]])
When I run this, I get the following error:
  File "aug_imgs.py", line 306, in main
    pool.map(augment,[[payloads,basepath,count_dict]])
  File "C:\Users\rob\anaconda3\envs\retinanet\lib\multiprocessing\pool.py", line 364, in map
    return self._map_async(func, iterable, mapstar, chunksize).get()
  File "C:\Users\rob\anaconda3\envs\retinanet\lib\multiprocessing\pool.py", line 771, in get
    raise self._value
TypeError: list indices must be integers or slices, not str
However, if I pass only one argument, the code runs without any error. In that version augment() sets imgs = args, and I call:
pool.map(augment, payloads)
What am I doing wrong? Any help would be appreciated, thanks in advance.
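The traceback can be reproduced without a pool at all, because Pool.map has the same once-per-element calling rule as the built-in map. A minimal sketch, assuming the real augment() indexes the payload it receives with the string key "input_paths" (the posted version doesn't show that line); chunk() here is a hypothetical stand-in for the question's unshown helper, and "some path" and {} stand in for basepath and count_dict.

```python
from math import ceil


def chunk(seq, n):
    # Hypothetical stand-in for the question's chunk() helper (not shown):
    # yields consecutive slices of at most n items.
    for start in range(0, len(seq), n):
        yield seq[start:start + n]


def augment(args):
    # Assumption: the real augment() reads the "input_paths" key of the
    # payload it receives; the posted version does not show this line.
    imgs = args[0]
    return imgs["input_paths"]


image_paths = [f"img_{k}.jpg" for k in range(5)]
procs = 2
size = ceil(len(image_paths) / procs)  # same sizing as the question: 3 here
payloads = [{"id": i, "input_paths": p}
            for i, p in enumerate(chunk(image_paths, size))]

# The question's iterable is a list containing ONE list, so augment is
# called once and args[0] is the whole payloads list. Indexing a list with
# a str raises TypeError.
jobs = [[payloads, "some path", {}]]
error_message = None
try:
    list(map(augment, jobs))
except TypeError as exc:
    error_message = str(exc)
print(len(jobs), error_message)

# With one job per payload, args[0] is a single dict and the lookup works.
print(augment([payloads[0], "some path", {}]))
# ['img_0.jpg', 'img_1.jpg', 'img_2.jpg']
```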
Answer:
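pool.map calls the function once for each element of the iterable you give it. [[payloads, basepath, count_dict]] has only one element, so augment runs exactly once, in a single worker, and args[0] is the entire payloads list rather than one payload dict.

Instead, break your processing function down so it works on a single image, then pass the collection of images to the process pool. You can bind the additional constant parameters with functools.partial: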
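from functools import partial
from multiprocessing import Pool

import cv2

def augment(base_path, count_dict, image):
    # do something with base_path & count_dict
    cv2.imread(image)

def main():
    # assuming base_path, count_dict and images (a list of image paths) are defined here
    partial_augment = partial(augment, base_path, count_dict)  # fixes the first two args
    with Pool() as pool:
        pool.map(partial_augment, images)  # one augment call per image

if __name__ == "__main__":  # required on Windows, where workers re-import this module
    main()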
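If you'd rather keep the question's chunked payloads than go per-image, Pool.starmap unpacks one argument tuple per call, so you can build one (payload, base_path, count_dict) tuple per chunk. A minimal sketch: it swaps in multiprocessing.dummy.Pool, a thread-backed pool with the same interface, so the example runs anywhere without a __main__ guard; in real code you would use multiprocessing.Pool inside `if __name__ == "__main__":`. The payload contents are made up, and this augment() only counts paths instead of calling cv2.imread.

```python
from multiprocessing.dummy import Pool  # thread-backed; same API as multiprocessing.Pool


def augment(payload, base_path, count_dict):
    # Stand-in for the real work: the real version would call cv2.imread()
    # on each path in payload["input_paths"].
    return payload["id"], len(payload["input_paths"])


# Hypothetical payloads shaped like the ones built in the question.
payloads = [
    {"id": 0, "input_paths": ["a.jpg", "b.jpg", "c.jpg"]},
    {"id": 1, "input_paths": ["d.jpg", "e.jpg"]},
]
base_path = "some path"
count_dict = {}

# One argument tuple per payload, so each call receives a single chunk.
jobs = [(payload, base_path, count_dict) for payload in payloads]

with Pool(2) as pool:
    results = pool.starmap(augment, jobs)  # results keep the order of jobs

print(results)  # [(0, 3), (1, 2)]
```

With the per-image version above, manual chunking is usually unnecessary: Pool.map accepts a chunksize argument (e.g. pool.map(partial_augment, images, chunksize=...)) that batches items for each worker for you.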