Java多线程之CountDownLatch

ProjectDaedalus

共 6345字,需浏览 13分钟

 · 2021-11-30

这里就JUC包中的CountDownLatch类做相关介绍

abstract.jpeg

概述

JUC包中的CountDownLatch类是一个同步工具类,可实现线程间的通信。其典型方法如下所示

// 创建一个指定计数器值的CountDownLatch实例
public CountDownLatch(int count);

// 当前线程阻塞等待CountDownLatch实例的计数器值为0
public void await() throws InterruptedException;

// 支持超时的阻塞等待; 返回true: CountDownLatch实例的计数器值为0; 返回false: 超时
public boolean await(long timeout, TimeUnit unit);

// CountDownLatch实例的计数器值减1
public void countDown();

基本使用方法也很简单。首先创建一个指定计数器值的CountDownLatch实例,每当其他线程完成任务时就通过countDown方法将计数器值减1。这样当计数器的值为0时,之前由于调用await方法而被阻塞的线程就会结束等待,恢复执行

实践

CountDownLatch的典型应用场景,大体可分为两类:结束信号、开始信号

结束信号

主线程创建、启动N个异步任务,我们期望当这N个任务全部执行完毕结束后,主线程才可以继续往下执行。即将CountDownLatch作为任务的结束信号来使用。示例代码如下所示

public class CountDownLatchTest1 {

    @Test
    public void test1() throws InterruptedException {
        ExecutorService threadPool = Executors.newFixedThreadPool(5);
        CountDownLatch doneSignal = new CountDownLatch(3);

        Arrays.asList("Task 1","Task 2","Task 3")
                .stream()
                .map( name -> new Task(name, doneSignal) )
                .forEach( task -> threadPool.execute(task) );

        // 阻塞等待, 直到计数器变为0。即所有任务均完成
        doneSignal.await();
        System.out.println("所有任务均完成");
    }

    @AllArgsConstructor
    private static class Task implements Runnable{

        private String taskName;

        private CountDownLatch doneSignal;

        @Override
        public void run() {
            System.out.println(taskName + " 开始");
            // 模拟业务耗时
            try{
                Thread.sleep( RandomUtils.nextInt(5,9) * 1000 );
            }catch (Exception e) {
                System.out.println( "Happen Exception: " + e.getMessage());
            }
            System.out.println(taskName + " 完成");
            // 当前任务完成, 则计数器减一
            doneSignal.countDown();
        }
    }
}

测试结果如下所示,符合预期

figure 1.jpeg

开始信号

主线程创建N个异步任务,但这N个任务不能立即开始执行。而需要等待某个共同的前置任务(比如初始化任务)完成后,才允许这N个任务开始执行。即将CountDownLatch作为任务的开始信号来使用。示例代码如下所示

public class CountDownLatchTest2 {

    @Test
    public void test1() throws InterruptedException {
        ExecutorService threadPool = Executors.newFixedThreadPool(10);
        CountDownLatch startSignal = new CountDownLatch(1);

        Arrays.asList("Task 1","Task 2","Task 3")
                .stream()
                .map( name -> new Task(name, startSignal) )
                .forEach( task -> threadPool.execute(task) );

        // 执行初始化准备工作
        System.out.println("初始化准备工作开始");
        // 模拟业务耗时
        try{
            Thread.sleep( RandomUtils.nextInt(5,9) * 1000 );
        }catch (Exception e) {
            System.out.println( "Happen Exception: " + e.getMessage());
        }
        System.out.println("初始化准备工作结束");

        // 初始化准备工作完成, 则计数器减一
        startSignal.countDown();

        // 主线程等待所有任务执行完毕
        try{ Thread.sleep( 20*1000 ); } catch (Exception e) {}
        System.out.println("Game Over");
    }

    @AllArgsConstructor
    private static class Task implements Runnable{

        private String taskName;

        private CountDownLatch startSignal;

        @Override
        public void run() {
            try{
                // 阻塞等待, 直到计数器变为0。 即前置任务完成
                startSignal.await();
            }catch (InterruptedException e) {
                System.out.println( "Happen Exception: " + e.getMessage());
            }

            System.out.println(taskName + " 开始");
            // 模拟业务耗时
            try{
                Thread.sleep( RandomUtils.nextInt(5,9) * 1000 );
            }catch (Exception e) {
                System.out.println( "Happen Exception: " + e.getMessage());
            }
            System.out.println(taskName + " 完成");
        }
    }
}

测试结果如下所示,符合预期

figure 2.jpeg

基本原理

构造器

CountDownLatch类实现过程同样依赖于AQS。在构建CountDownLatch实例过程时,一方面,通过sync变量持有AQS的实现类Sync;另一方面,通过AQS的state字段来存储计数器值

public class CountDownLatch {
    private final Sync sync;

    public CountDownLatch(int count) {
        if (count < 0throw new IllegalArgumentException("count < 0");
        this.sync = new Sync(count);
    }

    private static final class Sync extends AbstractQueuedSynchronizer {
        Sync(int count) {
            setState(count);
        }   
    }
}

...

public abstract class AbstractQueuedSynchronizer extends AbstractOwnableSynchronizer implements java.io.Serializable {
    private volatile int state;

    protected final void setState(int newState) {
        state = newState;
    }
}

await方法

首先来看CountDownLatch的await方法。其委托sync调用AQS的acquireSharedInterruptibly方法,从方法名也可以看到其是对AQS中共享锁的使用。并根据当前计数器的值是否为0,来判断该线程是继续执行还是应该被阻塞。可以看到事实上AQS只是定义了是否需要阻塞线程的tryAcquireShared方法,具体的规则需要CountDownLatch类来进行实现

public class CountDownLatch {
    public void await() throws InterruptedException {
        sync.acquireSharedInterruptibly(1);
    }

    private static final class Sync extends AbstractQueuedSynchronizer {   
        // 判断当前计数器值是否为0, 是则返回1; 否则返回-1
        protected int tryAcquireShared(int acquires) {
            return (getState() == 0) ? 1 : -1;
        }   
    }
}

...

public abstract class AbstractQueuedSynchronizer extends AbstractOwnableSynchronizer implements java.io.Serializable {
    public final void acquireSharedInterruptibly(int arg) throws InterruptedException {
        // 线程被中断则直接抛出异常
        if (Thread.interrupted())
            throw new InterruptedException();
        
        if (tryAcquireShared(arg) < 0)
            // 当前计数器不为0, 需进入AQS的队列准备阻塞
            doAcquireSharedInterruptibly(arg);
    }
    
    // 需要子类去实现
    protected int tryAcquireShared(int arg) {
        throw new UnsupportedOperationException();
    }
}

当tryAcquireShared方法结果小于0时,即当前计数器不为0时,AQS如何通过doAcquireSharedInterruptibly方法实现阻塞呢?结合相关源码可以看到,首先通过addWaiter方法将当前线程包装为一个node实例,并将其加入AQS队列。在入队过程中需要注意,如果队列为空则其并不是直接将该node实例加入队列。而是先构造一个哨兵节点来入队,然后在enq方法下一轮for循环才将该node实例加入队列

public abstract class AbstractQueuedSynchronizer extends AbstractOwnableSynchronizer implements java.io.Serializable {
    private void doAcquireSharedInterruptibly(int arg) throws InterruptedException {
        // 将当前线程包装为node,加入AQS队列,并返回该node实例
        final Node node = addWaiter(Node.SHARED);
        boolean failed = true;
        try {
            for (;;) {
                // 获取 node 的前驱节点
                final Node p = node.predecessor();
                if (p == head) {
                    int r = tryAcquireShared(arg);
                    if (r >= 0) {
                        setHeadAndPropagate(node, r);
                        p.next = null// help GC
                        failed = false;
                        return;
                    }
                }
                if (shouldParkAfterFailedAcquire(p, node) &&
                    parkAndCheckInterrupt())
                    throw new InterruptedException();
            }
        } finally {
            if (failed)
                cancelAcquire(node);
        }
    }

    private Node addWaiter(Node mode) {
        // 将当前线程包装为一个node实例
        Node node = new Node(Thread.currentThread(), mode);
        Node pred = tail;
        // 队列的尾指针不为空, 说明队列不为空, 则利用尾插法将node入队
        if (pred != null) {
            node.prev = pred;
            if (compareAndSetTail(pred, node)) {
                pred.next = node;
                // 入队完毕, 直接返回该node
                return node;
            }
        }
        // 队列为空, 则先构建一个哨兵节点、入队,再将该node入队
        enq(node);
        return node;
    }

    private Node enq(final Node node) {
        for (;;) {
            Node t = tail;
            // 队尾指针为空, 则先进行队列的初始化
            if (t == null) { 
                // 构建一个哨兵节点并入队
                if (compareAndSetHead(new Node()))
                    tail = head;
            } else {
                // 将node入队 
                node.prev = t;
                if (compareAndSetTail(t, node)) {
                    t.next = node;
                    return t;
                }
            }
        }
    }
}

然后通过shouldParkAfterFailedAcquire方法修改前驱节点的waitStatus。如果前驱节点的waitStatus字段是初始值0的话,需在第一轮for循环中进入shouldParkAfterFailedAcquire方法时,通过compareAndSetWaitStatus(pred, ws, Node.SIGNAL)方法将前驱节点的waitStatus字段修改为Node.SIGNAL(即-1)。这样在开始下一轮for循环时,shouldParkAfterFailedAcquire方法即会返回true。进而执行parkAndCheckInterrupt方法,利用LockSupport.park完成线程阻塞

private static boolean shouldParkAfterFailedAcquire(Node pred, Node node) {
    // 获取前驱节点的waitStatus字段值
    int ws = pred.waitStatus;
    if (ws == Node.SIGNAL)
        return true;
    if (ws > 0) {
        do {
            node.prev = pred = pred.prev;
        } while (pred.waitStatus > 0);
        pred.next = node;
    } else {
        compareAndSetWaitStatus(pred, ws, Node.SIGNAL);
    }
    return false;
}

private final boolean parkAndCheckInterrupt() {
    LockSupport.park(this);
    return Thread.interrupted();
}

countDown方法

CountDownLatch的countDown方法类似。其同样是委托sync调用AQS的releaseShared方法。然后AQS执行tryReleaseShared方法,CountDownLatch类负责实现具体的规则逻辑。如果自减后当前计数器为0,则说明需要唤醒之前通过await方法而被阻塞的线程。然后通过AQS的doReleaseShared方法实现唤醒。具体地,其是从头节点的后继节点开始唤醒。因为前面已经说过,AQS队列的第一个节点(即头节点)只是一个哨兵节点

public class CountDownLatch {
    public void countDown() {
        sync.releaseShared(1);
    }

    private static final class Sync extends AbstractQueuedSynchronizer {
        protected boolean tryReleaseShared(int releases) {
            for (;;) {
                int c = getState();
                if (c == 0)
                    return false;
                int nextc = c-1;
                if (compareAndSetState(c, nextc))
                    return nextc == 0;
            }
        }  
    } 
}

...

public abstract class AbstractQueuedSynchronizer extends AbstractOwnableSynchronizer implements java.io.Serializable {
    public final boolean releaseShared(int arg) {
        if (tryReleaseShared(arg)) {
            doReleaseShared();
            return true;
        }
        return false;
    }
    
    // 需要子类去实现
    protected boolean tryReleaseShared(int arg) {
        throw new UnsupportedOperationException();
    }

    private void doReleaseShared() {
        for (;;) {
            Node h = head;
            if (h != null && h != tail) {
                int ws = h.waitStatus;
                if (ws == Node.SIGNAL) {
                    if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0))
                        continue;            // loop to recheck cases
                    // 唤醒头节点的后继节点
                    unparkSuccessor(h);
                }
                else if (ws == 0 &&
                         !compareAndSetWaitStatus(h, 0, Node.PROPAGATE))
                    continue;                // loop on failed CAS
            }
            if (h == head)                   // loop if head changed
                break;
        }
    }

    private void unparkSuccessor(Node node) {
        int ws = node.waitStatus;
        if (ws < 0)
            compareAndSetWaitStatus(node, ws, 0);
        Node s = node.next;
        if (s == null || s.waitStatus > 0) {
            s = null;
            for (Node t = tail; t != null && t != node; t = t.prev)
                if (t.waitStatus <= 0)
                    s = t;
        }
        if (s != null)
            LockSupport.unpark(s.thread);
    }
}

这里补充说明下,当上文由于调用await方法而被阻塞的线程唤醒后,其会在doAcquireSharedInterruptibly方法的for循环中恢复执行。此时由于tryAcquireShared方法的返回值r大于0满足条件,故其进入setHeadAndPropagate方法。在该方法中,其将自身重新设置为AQS的头节点。并通过doReleaseShared方法继续唤醒它的后继节点。从而实现将AQS队列被阻塞的线程全部唤醒

private void setHeadAndPropagate(Node node, int propagate) {
    Node h = head; // Record old head for check below
    setHead(node);

    if (propagate > 0 || h == null || h.waitStatus < 0 ||
        (h = head) == null || h.waitStatus < 0) {
        Node s = node.next;
        if (s == null || s.isShared())
            doReleaseShared();
    }
}

Note

CountDownLatch的计数器值只能在创建实例时进行设置,之后不可以对其进行重新设置。换言之,CountDownLatch是一次性的,当其使用完毕后将无法再次利用

参考文献

  1. Java并发编程之美 翟陆续、薛宾田著
浏览 21
点赞
评论
收藏
分享

手机扫一扫分享

举报
评论
图片
表情
推荐
点赞
评论
收藏
分享

手机扫一扫分享

举报