Force quit python multithreading by Pool


I am using a multiprocessing Pool to run memory-consuming jobs. To avoid memory swapping, I want to kill all the worker processes when total memory usage exceeds a certain threshold.

I am on Python 3.9, running in a Jupyter notebook.

The following is sample code that is meant to terminate the workers when memory usage exceeds 50 GB.

import os
import numpy as np
from multiprocessing import Pool, cpu_count

def my_func(arg):
    a = MemoryConsumingObject()  # placeholder for the real memory-heavy work
    # second line of `free -g` output, "used" column
    memory_usage_giga = int(os.popen('free -g').readlines()[1].split()[1:][1])
    if memory_usage_giga > 50:
        # I want to quit the multiprocessing pool here! How do I quit?
        pass
    return a

my_list = list(np.arange(1, 1000))
with Pool(cpu_count()) as p:
    result = p.imap(my_func, my_list)
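As an aside, parsing the text output of free is fragile; on Linux the same numbers can be read directly from /proc/meminfo. A minimal sketch of that alternative (assuming Linux; the helper name used_memory_gib is my own, not from the question):

```python
def used_memory_gib():
    """Return system memory currently in use, in GiB (Linux only)."""
    info = {}
    with open('/proc/meminfo') as f:
        for line in f:
            key, rest = line.split(':', 1)
            info[key] = int(rest.split()[0])  # /proc/meminfo reports KiB
    used_kib = info['MemTotal'] - info['MemAvailable']
    return used_kib / (1024 ** 2)  # KiB -> GiB
```

This avoids spawning a subprocess on every call and returns a number rather than a string, so comparisons like `used_memory_gib() > 50` behave as intended.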

CodePudding user response:

You can do this using an Event shared between the workers. You can't easily terminate the pool from inside a worker (a worker has no handle on the pool object), but you can prevent the workers from doing real work, e.g. memory allocation, once some condition has been met. In the following script, I simulate the stop condition by generating a random number. Once the condition has been met, the remaining tasks are still dispatched to the workers, but each worker returns immediately without actually doing anything.

from os import getpid
from multiprocessing import Pool, cpu_count, Event
from random import random

event = Event()

def my_func(arg):
    pid = getpid()
    if event.is_set():
        print('%s received %s, but aborting as stop condition was seen' % (pid, arg))
        return None
    r = random()
    if r < 0.3:
        print('%s received %s, %.2f stop condition met, aborting and letting everyone else know' % (pid, arg, r))
        event.set()
        return None
    print("%s received: %s, %.2f stop condition not met, do work" % (pid, arg, r))
    # do the actual work here
    return r

my_list = list(range(10))

def main():
    with Pool(cpu_count()) as p:
        result = p.map(my_func, my_list)
        p.close()
        p.join()
        print(list(result))

if __name__ == '__main__':
    main()

An example run prints:

18786 received: 0, 0.67 stop condition not met, do work
18786 received: 1, 0.38 stop condition not met, do work
18786 received: 2, 0.51 stop condition not met, do work
18786 received: 3, 0.34 stop condition not met, do work
18786 received: 4, 0.72 stop condition not met, do work
18786 received: 5, 0.82 stop condition not met, do work
18786 received 6, 0.00 stop condition met, aborting and letting everyone else know
18786 received 7, but aborting as stop condition was seen
18786 received 8, but aborting as stop condition was seen
18786 received 9, but aborting as stop condition was seen
[0.6733142995176965, 0.3830673860039788, 0.5053762644489409, 0.3437097578584267, 0.7211013474170365, 0.816546904830295, None, None, None, None]

Note that the tasks might not run in quite the order you expect - it's concurrent processing, after all. In my case the CPU count is only 2, which is why a single worker process ends up handling everything here.
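If you really do want to force-quit the pool rather than just idle the workers, one option is to drive it with imap from the parent process, which does hold a handle on the pool, and call Pool.terminate() as soon as the event is set. A hypothetical sketch (the condition arg == 5 stands in for the memory check; the initializer pattern passes the event to workers so this also works under the spawn start method):

```python
from multiprocessing import get_context

def _init(ev):
    # runs once in each worker; stores the shared event in a global
    global _stop_event
    _stop_event = ev

def worker(arg):
    if arg == 5:               # stand-in for "memory limit exceeded"
        _stop_event.set()
        return None
    return arg * 2             # stand-in for the real work

def main():
    ctx = get_context()
    ev = ctx.Event()
    results = []
    with ctx.Pool(2, initializer=_init, initargs=(ev,)) as p:
        # imap yields results in submission order as they complete
        for r in p.imap(worker, range(20)):
            if ev.is_set():
                p.terminate()  # forcefully stop all remaining workers
                break
            results.append(r)
    return results
```

terminate() kills the workers without waiting for outstanding tasks, so any partially completed work in flight is lost; use it only when that is acceptable, as it is here where the goal is to stop allocating memory as fast as possible.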
