Cannot correctly stop thread pool


Here is my ThreadPool implementation. I've tried it from a simple main function and cannot stop it correctly: the destructor is called before the threads have started, and the whole program ends in a deadlock (on t.join()) because the condition variable was notified before the thread reached the wait call.

Any ideas how to fix it? Or is there a better way to implement it?
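The race boils down to something like this stripped-down illustration (not my actual pool code): if notify_one() fires before the worker thread reaches wait(), and the wait has no predicate, the notification is lost and join() can block forever.

#include <condition_variable>
#include <mutex>
#include <thread>

int main() {
  std::mutex m;
  std::condition_variable cv;

  std::thread worker([&] {
    std::unique_lock<std::mutex> lock(m);
    cv.wait(lock); // no predicate: a notify that arrives before this
                   // line is lost and the thread keeps sleeping
  });

  cv.notify_one(); // may run before the worker starts waiting
  worker.join();   // in that case this never returns
}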

ThreadPool.cpp

#include <atomic>
#include <condition_variable>
#include <functional>
#include <future>
#include <iostream>
#include <mutex>
#include <thread>
#include <vector>


namespace Concurrency {
template <typename RetType>
class StandardThreadPool : public ThreadPool<RetType> {
private:
  typedef std::function<RetType()> taskType;
  ThreadSafeQueue<std::packaged_task<RetType()>> queue;

  std::mutex queueMutex;
  std::condition_variable queueCondition;

  std::vector<std::thread> poolThreads;
  std::atomic<bool> stopThreadsFlag{false};

  void threadWork() {
    std::cout << "thread:" << std::this_thread::get_id() << " started\n";
    std::unique_lock<std::mutex> lock(queueMutex);
    while (true) {
      queueCondition.wait(lock);

      if (stopThreadsFlag.load())
        break;

      auto task = queue.Pop();

      if (task)
        (*task)();
    }
    std::cout << "thread:" << std::this_thread::get_id() << " finished\n";
  }

  void initThreadPool() {
    poolThreads.resize(StandardThreadPool<RetType>::maxThreads);
  }

  void startThreads() {
    for (int i = 0; i < StandardThreadPool<RetType>::maxThreads; i++) {
      poolThreads[i] =
          std::thread(&StandardThreadPool<RetType>::threadWork, this);
    }
  }

  void terminateThreads() {
    stopThreadsFlag.store(true);
    queueCondition.notify_all();

    for (auto &t : poolThreads) {
      t.join();
    }
  }

public:
  StandardThreadPool(int maxThreads) : ThreadPool<RetType>(maxThreads) {
    initThreadPool();
    startThreads();
  }

  std::future<RetType> virtual Push(taskType &&task) override {
    std::packaged_task<RetType()> pt = std::packaged_task<RetType()>(task);
    auto future = pt.get_future();
    queue.Push(std::move(pt));

    queueCondition.notify_one();
    return future;
  }

  ~StandardThreadPool<RetType>() {
    std::cout << "destructor called\n";
    terminateThreads(); }
};

} // namespace Concurrency


namespace Concurrency {
template <typename T> class ThreadSafeQueue {
private:
  struct node {
    std::shared_ptr<T> data;
    std::unique_ptr<node> next;
  };

  std::mutex headMutex;
  std::mutex tailMutex;

  std::unique_ptr<node> head;
  node *tail;

  node *getTail() {
    std::lock_guard<std::mutex> lock(tailMutex);
    return tail;
  }

  std::unique_ptr<node> popHead() {
    std::lock_guard<std::mutex> lock(headMutex);

    if (head.get() == getTail())
      return nullptr;

    std::unique_ptr<node> oldHead(std::move(head));
    head = std::move(oldHead->next);

    return oldHead;
  }

public:
  ThreadSafeQueue() : head(new node), tail(head.get()) {}

  std::shared_ptr<T> Pop() {
    auto oldHead = popHead();

    return oldHead ? oldHead->data : nullptr;
  }

  void Push(T &newValue) {
    auto newData = std::make_shared<T>(std::forward<T>(newValue));
    std::unique_ptr<node> pNew(new node);

    auto newTail = pNew.get();

    std::lock_guard<std::mutex> lock(tailMutex);
    tail->data = newData;
    tail->next = std::move(pNew);

    tail = newTail;
  }

  void Push(T &&newValue) {
    auto newData = std::make_shared<T>(std::move(newValue));
    std::unique_ptr<node> pNew(new node);

    auto newTail = pNew.get();

    std::lock_guard<std::mutex> lock(tailMutex);
    tail->data = newData;
    tail->next = std::move(pNew);

    tail = newTail;
  }

  ThreadSafeQueue(const ThreadSafeQueue &) = delete;
  ThreadSafeQueue &operator=(const ThreadSafeQueue &) = delete;
};
} // namespace Concurrency
#include <functional>
#include <future>

namespace Concurrency {

template <typename RetType> class ThreadPool {
public:
  int maxThreads;

public:
  typedef std::function<RetType()> taskType;
  ThreadPool(int maxThreads):maxThreads(maxThreads){}

  virtual std::future<RetType> Push(taskType &&newTask) = 0;

  ThreadPool(const ThreadPool &) = delete;
  ThreadPool(const ThreadPool &&) = delete;
};
} // namespace Concurrency

main.cpp

int main() {
  Concurrency::StandardThreadPool<int> th(1);
  auto fun = []() {
    std::cout << "function running\n";
    return 2;
  };

  th.Push(fun);

  return EXIT_SUCCESS;
}

Answer:

First, a correct threadsafe queue.

#include <atomic>
#include <condition_variable>
#include <deque>
#include <mutex>
#include <optional>

template<class T>
struct threadsafe_queue {
  [[nodiscard]] std::optional<T> pop() {
    auto l = lock();
    cv.wait(l, [&]{ return is_aborted() || !data.empty(); });
    if (is_aborted())
      return {};
    auto r = std::move(data.front());
    data.pop_front();
    cv.notify_all(); // for wait_until_empty
    return r; // might need std::move here, depending on compiler version
  }
  bool push(T t) {
    auto l = lock();
    if (is_aborted()) return false;
    data.push_back(std::move(t));
    cv.notify_one();
    return true;
  }
  void set_abort_flag() {
    auto l = lock(); // still need this
    aborted = true;
    data.clear();
    cv.notify_all();
  }
  [[nodiscard]] bool is_aborted() const { return aborted; }
  void wait_until_empty() {
    auto l = lock();
    cv.wait(l, [&]{ return data.empty(); });
  }
private:
  std::unique_lock<std::mutex> lock() {
    return std::unique_lock<std::mutex>(m);
  }
  std::condition_variable cv;
  std::mutex m;
  std::atomic<bool> aborted{false};
  std::deque<T> data;
};

This handles abort and the like internally.
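For illustration only (not part of the answer), a minimal single-producer / single-consumer sketch of how this queue is meant to be driven:

#include <iostream>
#include <thread>

int main() {
  threadsafe_queue<int> q;

  std::thread consumer([&] {
    // pop() blocks until data arrives or the queue is aborted;
    // an empty optional signals shutdown.
    while (auto v = q.pop())
      std::cout << "got " << *v << "\n";
  });

  for (int i = 0; i < 3; ++i)
    q.push(i);

  q.wait_until_empty(); // let the consumer drain everything
  q.set_abort_flag();   // wake the consumer so it can exit
  consumer.join();
}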

Our threadpool then becomes:

#include <cstddef>
#include <future>
#include <thread>
#include <type_traits>
#include <vector>

struct threadpool {
  explicit threadpool(std::size_t count)
  {
    for (std::size_t i = 0; i < count; ++i) {
      threads.emplace_back([&]{
        // abort handled by empty pop:
        while( auto f = queue.pop() ) {
          (*f)();
        }
      });
    }
  }
  void set_abort_flag() {
    queue.set_abort_flag();
  }
  [[nodiscard]] bool is_aborted() const {
    return queue.is_aborted();
  }
  ~threadpool() {
    queue.wait_until_empty();
    queue.set_abort_flag(); // get threads to leave the queue
    for (std::thread& t:threads)
      t.join();
  }
  template<class F,
    class R=typename std::result_of<F()>::type
  >
  std::future<R> push_task( F f ) {
    std::packaged_task<R()> task( std::move(f) );
    auto ret = task.get_future();
    if (queue.push( std::packaged_task<void()>(std::move(task)) )) // wait, this works?  Yes it does.
      return ret;
    else
      return {}; // cannot push, already aborted
  }
private:
  // yes, void.  This is evil but it works
  threadsafe_queue<std::packaged_task<void()>> queue;
  std::vector<std::thread> threads;
};

If std::optional is not available (pre-C++17), you can swap it for std::unique_ptr, at the cost of a bit more runtime overhead.

The trick here is that a std::packaged_task<void()> can store a std::packaged_task<R()>. And we don't need the return value in the queue. So one thread pool can handle any number of different return values in tasks -- it doesn't care.
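As a quick sketch of what that type erasure buys you (my example, not the answer's): one pool can serve tasks with different return types, and each future still delivers the right one.

#include <iostream>
#include <string>

int main() {
  threadpool pool(2);

  // The queue only ever sees std::packaged_task<void()>, but each
  // caller gets back a future of its own return type.
  auto fi = pool.push_task([] { return 42; });
  auto fs = pool.push_task([] { return std::string("done"); });

  std::cout << fi.get() << "\n"; // 42
  std::cout << fs.get() << "\n"; // done
}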

I only join the threads on threadpool destruction. I could do it after an abort as well.

Destroying a threadpool waits until all tasks are complete. Note that aborting a threadpool may not abort tasks in progress. One thing that you probably want to add is the option of passing an abort API/flag to the tasks, so they can abort early if asked.

Getting this to industrial scale is hard, because ideally all blocking in a task would also pay attention to the abort possibility.
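One possible shape for that, purely as a sketch under my own naming (run_cancellable and the shared flag are illustrative, not part of the answer): hand the task a reference to an abort flag and have long-running work poll it at convenient points.

#include <atomic>
#include <chrono>
#include <future>
#include <thread>

// Sketch: a cooperatively cancellable task. The caller owns the flag
// and flips it when it wants the work to stop early.
std::future<int> run_cancellable(threadpool& pool, std::atomic<bool>& abort) {
  return pool.push_task([&abort] {
    for (int step = 0; step < 1000; ++step) {
      if (abort.load())   // cancellation point
        return -1;        // caller sees "aborted" via the future
      std::this_thread::sleep_for(std::chrono::milliseconds(1));
    }
    return 0;             // finished normally
  });
}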

Live example.

You could add a 2nd cv to notify after pops, which only wait_until_empty waits on. That might save you some spurious wakeups.
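Sketched out (my variant, not the posted code), that second condition variable would look like this: pops signal cv_empty only when the queue actually drains, and wait_until_empty() waits on that instead of sharing cv with the poppers.

#include <atomic>
#include <condition_variable>
#include <deque>
#include <mutex>
#include <optional>

template<class T>
struct threadsafe_queue2 {
  [[nodiscard]] std::optional<T> pop() {
    auto l = lock();
    cv.wait(l, [&]{ return is_aborted() || !data.empty(); });
    if (is_aborted())
      return {};
    auto r = std::move(data.front());
    data.pop_front();
    if (data.empty())
      cv_empty.notify_all(); // only drain-waiters are woken
    return r;
  }
  bool push(T t) {
    auto l = lock();
    if (is_aborted()) return false;
    data.push_back(std::move(t));
    cv.notify_one();
    return true;
  }
  void set_abort_flag() {
    auto l = lock();
    aborted = true;
    data.clear();
    cv.notify_all();
    cv_empty.notify_all(); // the queue is now empty as well
  }
  [[nodiscard]] bool is_aborted() const { return aborted; }
  void wait_until_empty() {
    auto l = lock();
    cv_empty.wait(l, [&]{ return data.empty(); });
  }
private:
  std::unique_lock<std::mutex> lock() {
    return std::unique_lock<std::mutex>(m);
  }
  std::condition_variable cv;       // poppers block here
  std::condition_variable cv_empty; // wait_until_empty blocks here
  std::mutex m;
  std::atomic<bool> aborted{false};
  std::deque<T> data;
};

push() still notifies cv exactly as before; only the drain path changes, so poppers are no longer woken by other pops.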
