I have a small demo with a Bean class and a BeanRegister class. The Bean class has two methods, preInit() and postInit(). BeanRegister is a Thread subclass that holds a Bean as a field. Here is my code:
import java.util.Random;

public class Main { // enclosing class name assumed; the original snippet omitted it
    public static void main(String[] args) {
        Bean beanA = new Bean();
        BeanRegister beanRegister1 = new BeanRegister(beanA);
        BeanRegister beanRegister2 = new BeanRegister(beanA);
        beanRegister1.start();
        beanRegister2.start();
    }

    private static class BeanRegister extends Thread {
        private final Bean bean;

        public BeanRegister(Bean bean) {
            this.bean = bean;
        }

        @Override
        public void run() {
            try {
                bean.preInit();
                bean.postInit();
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
        }
    }

    private static class Bean {
        public void preInit() throws InterruptedException {
            Thread.sleep(new Random().nextInt(1000) * 5); // random delay of up to about 5 seconds
            System.out.println("PreInit " + Thread.currentThread().getName());
        }

        public void postInit() throws InterruptedException {
            System.out.println("PostInit " + Thread.currentThread().getName());
        }
    }
}
The problem I'm facing is synchronization. I want postInit() to be blocked in every thread until all threads have finished executing preInit(). Only once every thread has completed preInit() should the threads be allowed to execute postInit(). Any ideas how to do this properly?
Answer:
You can use a CountDownLatch shared across all threads.
Some theory first: what is a CountDownLatch?
It's a simple concurrency utility that you initialize with a count, let's say N. It then offers two main methods:
countDown() => decrements the count by one each time it is called (N, then N-1, and so on down to zero; calls made once the count is already zero have no effect).
await() => blocks the current thread until the count reaches zero (the await(long, TimeUnit) overload lets you specify a timeout).
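To make those two methods concrete, here is a small single-threaded sketch that walks a latch from N = 2 down to zero and records what each call returns. The class name LatchBasics and the demo() helper are hypothetical, chosen only for illustration; getCount() and the timed await(long, TimeUnit) are standard CountDownLatch methods.

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

public class LatchBasics {
    // Walks a latch from N = 2 down to zero and records what each call returns.
    static String demo() {
        CountDownLatch latch = new CountDownLatch(2);
        StringBuilder log = new StringBuilder();
        try {
            log.append(latch.getCount()).append(' ');   // 2

            latch.countDown();                           // 2 -> 1
            log.append(latch.getCount()).append(' ');   // 1

            // Timed await gives up after 100 ms and returns false: the count is still 1.
            log.append(latch.await(100, TimeUnit.MILLISECONDS)).append(' ');

            latch.countDown();                           // 1 -> 0: waiting threads are released
            latch.countDown();                           // already 0: no effect
            log.append(latch.getCount()).append(' ');   // 0

            // With the count at zero, await returns immediately (true for the timed form).
            log.append(latch.await(100, TimeUnit.MILLISECONDS));
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("interrupted while waiting", e);
        }
        return log.toString();
    }

    public static void main(String[] args) {
        System.out.println(demo()); // prints: 2 1 false 0 true
    }
}
```

Note that the latch is single-use: once it reaches zero it stays there and cannot be reset.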
The great advantage of this class is that the synchronization is handled for you: actions a thread performs before calling countDown() happen-before the actions that follow a successful return from await() in another thread. The waiting threads are therefore guaranteed to see those effects without you managing any memory barriers yourself.
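A minimal sketch of that visibility guarantee, assuming a hypothetical class LatchVisibility: each worker writes to a plain (non-volatile, unlocked) array slot before counting down, and the main thread reads the array only after await() returns. The latch's happens-before edge is the only thing that makes those writes visible.

```java
import java.util.concurrent.CountDownLatch;

public class LatchVisibility {
    // Each worker writes one slot of a plain (non-volatile) array, then counts down.
    // The main thread reads the array only after await() returns.
    static int sumAfterLatch(int workers) {
        int[] results = new int[workers];
        CountDownLatch done = new CountDownLatch(workers);
        for (int i = 0; i < workers; i++) {
            final int id = i;
            new Thread(() -> {
                results[id] = (id + 1) * 10; // ordinary write, no lock or volatile
                done.countDown();            // publishes the write above
            }).start();
        }
        try {
            done.await(); // every write made before a countDown() is now visible here
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("interrupted while waiting", e);
        }
        int sum = 0;
        for (int r : results) {
            sum += r;
        }
        return sum;
    }

    public static void main(String[] args) {
        System.out.println(sumAfterLatch(3)); // prints 60 (10 + 20 + 30)
    }
}
```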
So now, based on your code, start by making the preInit and postInit methods of Bean take a CountDownLatch as a parameter:
private static class Bean {
    public void preInit(CountDownLatch latch) throws InterruptedException {
        Thread.sleep(new Random().nextInt(1000) * 5);
        System.out.println("PreInit " + Thread.currentThread().getName());
        latch.countDown(); //<-- each time a preInit ends, decrease the count by 1
    }

    public void postInit(CountDownLatch latch) throws InterruptedException {
        latch.await(); //<-- even once called, wait until the count reaches zero before doing any work
        System.out.println("PostInit " + Thread.currentThread().getName());
    }
}
Specifically, preInit counts the latch down, while postInit waits for it to reach zero before actually starting.
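To check that this pattern really orders the two phases, here is a sketch that generalizes it to any number of threads and records the global order of events. It assumes a hypothetical class LatchOrdering, and the strings "pre" and "post" stand in for preInit() and postInit(). Because each "pre" is recorded before its countDown() and each "post" after await() returns, every "pre" entry precedes every "post" entry regardless of the random delays.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ThreadLocalRandom;

public class LatchOrdering {
    // Starts `threads` workers that each run a "pre" step and a "post" step,
    // sharing one latch whose count equals the number of workers.
    static List<String> run(int threads) {
        CountDownLatch latch = new CountDownLatch(threads);
        ConcurrentLinkedQueue<String> events = new ConcurrentLinkedQueue<>();
        List<Thread> workers = new ArrayList<>();
        for (int i = 0; i < threads; i++) {
            Thread t = new Thread(() -> {
                try {
                    // Uneven delay stands in for preInit() work of varying length.
                    Thread.sleep(ThreadLocalRandom.current().nextInt(50));
                    events.add("pre");
                    latch.countDown(); // this thread's preInit is done
                    latch.await();     // postInit waits for every thread's preInit
                    events.add("post");
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                }
            });
            workers.add(t);
            t.start();
        }
        joinAll(workers);
        return new ArrayList<>(events);
    }

    private static void joinAll(List<Thread> workers) {
        for (Thread t : workers) {
            try {
                t.join();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                return;
            }
        }
    }

    public static void main(String[] args) {
        // prints [pre, pre, pre, pre, post, post, post, post]
        System.out.println(run(4));
    }
}
```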
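Then, in your calling function, create a new CountDownLatch(2) (where 2 is the number of threads that call preInit) and pass it down the call stack. The count must match the number of threads: if it is higher, it never reaches zero and every postInit blocks forever.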
public static void main(String[] args) {
    Bean beanA = new Bean();
    CountDownLatch latch = new CountDownLatch(2); // one count per BeanRegister thread
    BeanRegister beanRegister1 = new BeanRegister(beanA, latch);
    BeanRegister beanRegister2 = new BeanRegister(beanA, latch);
    beanRegister1.start();
    beanRegister2.start();
}

private static class BeanRegister extends Thread {
    private final Bean bean;
    private final CountDownLatch latch;

    public BeanRegister(Bean bean, CountDownLatch latch) {
        this.bean = bean;
        this.latch = latch;
    }

    @Override
    public void run() {
        try {
            bean.preInit(latch);
            bean.postInit(latch);
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
    }
}
Sample output:
PreInit Thread-1
PreInit Thread-0
PostInit Thread-1
PostInit Thread-0
Process finished with exit code 0
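For comparison, the JDK also ships java.util.concurrent.CyclicBarrier, a different synchronizer from the one this answer uses. It targets this exact "all threads reach a point, then all continue" pattern: each thread makes a single await() call, and unlike a CountDownLatch the barrier can be reused for another round. The sketch below swaps the latch for a barrier under the same assumptions as before (a hypothetical class BarrierOrdering, with "pre" and "post" standing in for preInit() and postInit()); it is an alternative, not a correction of the latch approach.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.CyclicBarrier;

public class BarrierOrdering {
    // All `threads` parties must call await() before any of them is released.
    static List<String> run(int threads) {
        CyclicBarrier barrier = new CyclicBarrier(threads);
        ConcurrentLinkedQueue<String> events = new ConcurrentLinkedQueue<>();
        List<Thread> workers = new ArrayList<>();
        for (int i = 0; i < threads; i++) {
            Thread t = new Thread(() -> {
                try {
                    events.add("pre");  // stands in for preInit()
                    barrier.await();    // wait until every party has finished preInit
                    events.add("post"); // stands in for postInit()
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                } catch (BrokenBarrierException e) {
                    e.printStackTrace(); // another party was interrupted or timed out
                }
            });
            workers.add(t);
            t.start();
        }
        for (Thread t : workers) {
            try {
                t.join();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                break;
            }
        }
        return new ArrayList<>(events);
    }

    public static void main(String[] args) {
        // prints [pre, pre, pre, post, post, post]
        System.out.println(run(3));
    }
}
```

The ordering holds because actions a thread performs before await() happen-before the actions that follow a successful return from await() in the other threads.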