Java多线程编程-Fork/Join

Fork/Join 简介

Fork/Join 框架在 JDK1.7 引入,给我的感觉主要解决类似递归类任务,分而治之,最适合的是计算密集型任务。将一个大的任务拆分成多个小的任务,多个任务之间又有一些的关联和依赖关系,并行执行提高 CPU 的使用率。

ForkJoinPool采用分治+work-stealing的思想。可以让我们很方便地将一个大任务拆散成小任务,并行地执行,提高CPU的使用率。

我们通过 Fork 和 Join 这两个单词来理解下 Fork/Join 框架,Fork 就是把一个大任务切分为若干子任务并行的执行,Join 就是合并这些子任务的执行结果,最后得到这个大任务的结果。比如计算 1+2+。。+10000,可以分割成 10 个子任务,每个子任务分别对 1000 个数进行求和,最终汇总这 10 个子任务的结果。Fork/Join 的运行流程图如下

image

工作窃取算法

work-stealing 算法翻译过来也就是工作窃取算法,比如 ForkJoinPool 基本思想如下:

image

  • ForkJoinPool 的每个工作线程都会维护一个工作队列,这个队列是一个双端队列,里面存放的对象就是任务。
  • 每个工作线程在运行中产生的新的任务(比如我让一个线程算1+…+1000,它分别让两个线程来帮他算,一个算1+…+500,一个算501+…+1000,自己等结果,这就产生了新的任务,也就是在线程里面 fork() 的时候),会把任务放在队列的末尾,并且工作线程在处理任务的时候,采用先进后出(LIFO)的方式,也就是先处理队尾的任务来执行。
  • 每个工作线程在处理自己工作队列的同时,会尝试窃取一个任务,窃取的位置位于其他工作队列的队首,也就是工作线程窃取其他工作线程的任务采用 FIFO 方式。
  • 遇到 join() 时,如果需要 join() 的任务尚未完成,则会先处理其他任务,并等待其完成。
  • 即没有自己的任务也没有可以窃取的任务时,进入休眠。

总结来说:当我们把一个大的任务分割成多个子任务,把这些子任务放到不同的队列里面去,为每个队列创建一个线程来执行队列里面的任务,线程和队列一一对应,但是有些任务简单啊,比如这个 A 队列里面的任务就先全部被处理完了,但是 B 队列还剩了一大堆活,A 队列的对应的线程 A 就看不下去了,我不能闲着,它就去窃取一个 B 队列里面的任务,所以在使用双端队列的情况下就十分的方便,B 线程取 B队列结尾的任务来干,A 队列取 队首的任务来干。

工作窃取算法的优点就在于了充分利用线程进行并行计算,并且减少了线程间的竞争,但是也有缺点,就是在某些情况下还是存在竞争的情况,比如双端队列只有一个任务的时候。并且也消耗了更多的系统资源,比如创建多个线程和多个双端队列。

Fork/Join 框架的使用

Fork/Join 经过上面的介绍,我们大概也知道它最适合的就是用来大任务分割并行处理。

第一步就是分割任务,首先需要有一个 Fork 类来把大任务分割成小任务,有些子任务还是还不是足够的小,所以可能还需要继续的分割。

第二步就是子任务执行结果的合并,分割的子任务分别放在双端队列里,然后几个启动线程分别从双端队列里获取任务执行。子任务执行完的结果都统一放在一个队列里,启动一个线程从队列里拿数据,然后合并这些数据。

对应 Java 中的 Fork/Join 使用如下的类:

  • ForkJoinTask: 我们要创建 ForkJoin 任务,提供在任务中进行 fork()join() 操作的机制,通常情况下我们只需要使用 ForkJoinTask 的子类。
    • RecursiveAction:用于没有返回结果的任务。
    • RecursiveTask :用于有返回结果的任务。
  • ForkJoinPoolForkJoinTask 需要通过 ForkJoinPool 来执行,任务分割出的子任务会添加到当前工作线程所维护的双端队列中,进入队列的头部。当一个工作线程的队列里暂时没有任务时,它会随机从其他工作线程的队列的尾部获取一个任务。

我们实现一个简单的需求,比如我现在需要计算 1+2+…+100,我们使用 Fork/Join 框架来做这个事。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
public class ForkJoinTest {

private ForkJoinPool forkJoinPool;

//一个有返回结果的任务
private static class SumTask extends RecursiveTask<Long> {

//需要计算的数字数组,这里数组是1-100的所有数字
private long[] numbers;
private int from;
private int to;

public SumTask(long[] numbers, int from, int to) {
this.numbers = numbers;
this.from = from;
this.to = to;
}

@Override
protected Long compute() {
//比如 5 - 1,相差小于6,就直接算不继续分解为子任务了
if (to -from < 6) {
long total = 0;
for(int i = from; i <= to; i++) {
total += i;
}
return total;
} else {
// 计算的数字太多,平均拆分成两个子任务
int middle = (from + to) / 2;
SumTask leftTask = new SumTask(numbers, from, middle);
SumTask rightTask = new SumTask(numbers, middle + 1, to);
leftTask.fork();
rightTask.fork();
return leftTask.join() + rightTask.join();
}
}
}

public ForkJoinTest() {
// 也可以使用公用的 ForkJoinPool:
// pool = ForkJoinPool.commonPool()
forkJoinPool = new ForkJoinPool();
}

public long sumUp(long[] numbers) {
return forkJoinPool.invoke(new SumTask(numbers, 0, numbers.length-1));
}

public static void main(String[] args) {
long[] numbers = LongStream.rangeClosed(1, 50).toArray();
System.out.println(new ForkJoinTest().sumUp(numbers));
}

}

计算的结果自然是正确的,通过这个例子让我们再来进一步了解 ForkJoinTaskForkJoinTask 与一般的任务的主要区别在于它需要实现 compute 方法,在这个方法里,首先需要判断任务是否足够小,如果足够小就直接执行任务。如果不足够小,就必须分割成两个子任务,每个子任务在调用 fork 方法时,又会进入 compute 方法,看看当前子任务是否需要继续分割成孙任务,如果不需要继续分割,则执行当前子任务并返回结果。使用 join 方法会等待子任务执行完并得到其结果。

fork() 和 join() 的作用

  • fork(): 开启一个新线程(或是重用线程池内的空闲线程),将任务交给该线程处理。
  • join(): 等待该任务的处理线程处理完毕,获得返回值。

Fork/Join 框架的异常处理

ForkJoinTask 任务可能会在执行的时候抛出异常, ForkJoinTask 提供了 isCompletedAbnormally() 方法来检查任务是否抛出异常或者已经被取消了,并且可以使用 getException() 方法获取异常。

getException 方法返回 Throwable 对象,如果任务被取消了则返回 CancellationException。如果任务没有完成或者没有抛出异常则返回 null。

1
2
3
if (leftTask.isCompletedAbnormally()) {
System.out.println(leftTask.getException());
}

Fork/Join 框架实现的原理

Fork/Join Framework 的原理首先可以看下 Doug Lea 的论文 《A Java Fork/Join Framework》,Doug Lea,JUC 包下面的类。。好像都是他写的。

上面简单介绍过工作窃取算法,根据 ForkJoinPool 的源代码发现,默认的工作线程数是 Runtime.getRuntime().availableProcessors(),会来取当前宿主机的核心数来作为线程数,但是像我们上面那样开了那么多的 ForkJoinTask,肯定就会有很多会被放到队列里面去,等到某个线程忙完了来帮忙。

fork

fork() 做的工作只有一件事,既是把任务推入当前工作线程的工作队列里。可以参看以下的源代码:

1
2
3
4
5
6
7
8
public final ForkJoinTask<V> fork() {
Thread t;
if ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread)
((ForkJoinWorkerThread)t).workQueue.push(this);
else
ForkJoinPool.common.externalPush(this);
return this;
}

join

join() 调用了 doJoin() 方法,通过 doJoin() 方法得到当前任务的状态来判断返回什么结果,任务状态有四种:已完成(NORMAL),被取消(CANCELLED),信号(SIGNAL)和出现异常(EXCEPTIONAL)

  • 如果任务状态是已完成,则直接返回任务结果。
  • 如果任务状态是被取消,则直接抛出 CancellationException。
  • 如果任务状态是抛出异常,则直接抛出对应的异常。
1
2
3
4
5
6
public final V join() {
int s;
if ((s = doJoin() & DONE_MASK) != NORMAL)
reportException(s);
return getRawResult();
}

让我们再来分析下 doJoin() 方法的实现代码

在 doJoin() 方法里,首先通过查看任务的状态,看任务是否已经执行完了,如果执行完了,则直接返回任务状态,如果没有执行完,则从任务数组里取出任务并执行。如果任务顺利执行完成了,则设置任务状态为 NORMAL,如果出现异常,则纪录异常,并将任务状态设置为 EXCEPTIONAL

1
2
3
4
5
6
7
8
9
private int doJoin() {
int s; Thread t; ForkJoinWorkerThread wt; ForkJoinPool.WorkQueue w;
return (s = status) < 0 ? s :
((t = Thread.currentThread()) instanceof ForkJoinWorkerThread) ?
(w = (wt = (ForkJoinWorkerThread)t).workQueue).
tryUnpush(this) && (s = doExec()) < 0 ? s :
wt.pool.awaitJoin(w, this, 0L) :
externalAwaitDone();
}

整体流程图如下:

  1. 检查调用 join() 的线程是否是 ForkJoinThread 线程。如果不是(例如 main 线程),则阻塞当前线程,等待任务完成。如果是,则不阻塞。
  2. 查看任务的完成状态,如果已经完成,直接返回结果。
  3. 如果任务尚未完成,但处于自己的工作队列内,则完成它。
  4. 如果任务已经被其他的工作线程偷走,则窃取这个小偷的工作队列内的任务(以 FIFO 方式),执行,以期帮助它早日完成欲 join 的任务。
  5. 如果偷走任务的小偷也已经把自己的任务全部做完,正在等待需要 join 的任务时,则找到小偷的小偷,帮助它完成它的任务。
  6. 递归地执行第5步。

image

上面的大概就是 fork(), join() 大概的执行逻辑,但是最开始的任务会被 push 到哪个线程的工作队列里面去的?

这个就需要去看看 submit() 方法是怎么实现的了。

submit

其实除了前面介绍过的每个工作线程自己拥有的工作队列以外,ForkJoinPool 自身也拥有工作队列,这些工作队列的作用是用来接收由外部线程(非 ForkJoinThread 线程)提交过来的任务,而这些工作队列被称为 submitting queue 。

submit() 和 fork() 其实没有本质区别,只是提交对象变成了 submitting queue 而已(还有一些同步,初始化的操作)。submitting queue 和其他 work queue 一样,是工作线程”窃取“的对象,因此当其中的任务被一个工作线程成功窃取时,就意味着提交的任务真正开始进入执行阶段。

更多

在 JDK8 中 lamdba 有个 stream 操作 parallelStream,底层也是使用ForkJoinPool实现的;

我们可以通过 Executors.newWorkStealingPool(int parallelism) 快速创建 ForkJoinPool 线程池,无参默认使用CPU数量的线程数执行任务;