(手机横屏看源码更方便)java
注:java源码分析部分如无特殊说明均基于 java8 版本。程序员
注:本文基于ForkJoinPool分治线程池类。面试
随着在硬件上多核处理器的发展和普遍使用,并发编程成为程序员必须掌握的一门技术,在面试中也常常考查面试者并发相关的知识。算法
今天,咱们就来看一道面试题:编程
如何充分利用多核CPU,计算很大数组中全部整数的和?数组
咱们最容易想到就是单线程相加,一个for循环搞定。并发
若是进一步优化,咱们会天然而然地想到使用线程池来分段相加,最后再把每一个段的结果相加。框架
Yes,就是咱们今天的主角——ForkJoinPool,可是它要怎么实现呢?彷佛没怎么用过哈^^dom
OK,剖析完了,咱们直接来看三种实现,不墨迹,直接上菜。ide
/** * 计算1亿个整数的和 */ public class ForkJoinPoolTest01 { public static void main(String[] args) throws ExecutionException, InterruptedException { // 构造数据 int length = 100000000; long[] arr = new long[length]; for (int i = 0; i < length; i++) { arr[i] = ThreadLocalRandom.current().nextInt(Integer.MAX_VALUE); } // 单线程 singleThreadSum(arr); // ThreadPoolExecutor线程池 multiThreadSum(arr); // ForkJoinPool线程池 forkJoinSum(arr); } private static void singleThreadSum(long[] arr) { long start = System.currentTimeMillis(); long sum = 0; for (int i = 0; i < arr.length; i++) { // 模拟耗时,本文由公从号“彤哥读源码”原创 sum += (arr[i]/3*3/3*3/3*3/3*3/3*3); } System.out.println("sum: " + sum); System.out.println("single thread elapse: " + (System.currentTimeMillis() - start)); } private static void multiThreadSum(long[] arr) throws ExecutionException, InterruptedException { long start = System.currentTimeMillis(); int count = 8; ExecutorService threadPool = Executors.newFixedThreadPool(count); List<Future<Long>> list = new ArrayList<>(); for (int i = 0; i < count; i++) { int num = i; // 分段提交任务 Future<Long> future = threadPool.submit(() -> { long sum = 0; for (int j = arr.length / count * num; j < (arr.length / count * (num + 1)); j++) { try { // 模拟耗时 sum += (arr[j]/3*3/3*3/3*3/3*3/3*3); } catch (Exception e) { e.printStackTrace(); } } return sum; }); list.add(future); } // 每一个段结果相加 long sum = 0; for (Future<Long> future : list) { sum += future.get(); } System.out.println("sum: " + sum); System.out.println("multi thread elapse: " + (System.currentTimeMillis() - start)); } private static void forkJoinSum(long[] arr) throws ExecutionException, InterruptedException { long start = System.currentTimeMillis(); ForkJoinPool forkJoinPool = ForkJoinPool.commonPool(); // 提交任务 ForkJoinTask<Long> forkJoinTask = forkJoinPool.submit(new SumTask(arr, 0, arr.length)); // 获取结果 Long sum = forkJoinTask.get(); forkJoinPool.shutdown(); System.out.println("sum: " + sum); System.out.println("fork join elapse: " + (System.currentTimeMillis() - start)); } private static class SumTask extends RecursiveTask<Long> { private long[] arr; private int from; private int to; public SumTask(long[] arr, int from, int to) { this.arr = arr; this.from = from; this.to = to; } @Override protected Long compute() { // 小于1000的时候直接相加,可灵活调整 if (to - from <= 1000) { long sum = 0; for (int i = from; i < to; i++) { // 模拟耗时 sum += (arr[i]/3*3/3*3/3*3/3*3/3*3); } return sum; } // 分红两段任务,本文由公从号“彤哥读源码”原创 int middle = (from + to) / 2; SumTask left = new SumTask(arr, from, middle); SumTask right = new SumTask(arr, middle, to); // 提交左边的任务 left.fork(); // 右边的任务直接利用当前线程计算,节约开销 Long rightResult = right.compute(); // 等待左边计算完毕 Long leftResult = left.join(); // 返回结果 return leftResult + rightResult; } } }
彤哥偷偷地告诉你,实际上计算1亿个整数相加,单线程是最快的,个人电脑大概是100ms左右,使用线程池反而会变慢。
因此,为了演示ForkJoinPool的牛逼之处,我把每一个数都/3*3/3*3/3*3/3*3/3*3
了一顿操做,用来模拟计算耗时。
来看结果:
sum: 107352457433800662 single thread elapse: 789 sum: 107352457433800662 multi thread elapse: 228 sum: 107352457433800662 fork join elapse: 189
能够看到,ForkJoinPool相对普通线程池仍是有很大提高的。
问题:普通线程池可否实现ForkJoinPool这种计算方式呢,即大任务拆中任务,中任务拆小任务,最后再汇总?
你能够试试看(-᷅_-᷄)
OK,下面咱们正式进入ForkJoinPool的解析。
把一个规模大的问题划分为规模较小的子问题,而后分而治之,最后合并子问题的解获得原问题的解。
(1)分割原问题:
(2)求解子问题:
(3)合并子问题的解为原问题的解。
在分治法中,子问题通常是相互独立的,所以,常常经过递归调用算法来求解子问题。
(1)二分搜索
(2)大整数乘法
(3)Strassen矩阵乘法
(4)棋盘覆盖
(5)归并排序
(6)快速排序
(7)线性时间选择
(8)汉诺塔
ForkJoinPool是 java 7 中新增的线程池类,它的继承体系以下:
ForkJoinPool和ThreadPoolExecutor都是继承自AbstractExecutorService抽象类,因此它和ThreadPoolExecutor的使用几乎没有多少区别,除了任务变成了ForkJoinTask之外。
这里又运用到了一种很重要的设计原则——开闭原则——对修改关闭,对扩展开放。
可见整个线程池体系一开始的接口设计就很好,新增一个线程池类,不会对原有的代码形成干扰,还能利用原有的特性。
fork()方法相似于线程的Thread.start()方法,可是它不是真的启动一个线程,而是将任务放入到工做队列中。
join()方法相似于线程的Thread.join()方法,可是它不是简单地阻塞线程,而是利用工做线程运行其它任务。当一个工做线程中调用了join()方法,它将处理其它任务,直到注意到目标子任务已经完成了。
无返回值任务。
有返回值任务。
无返回值任务,完成任务后能够触发回调。
ForkJoinPool内部使用的是“工做窃取”算法实现的。
(1)每一个工做线程都有本身的工做队列WorkQueue;
(2)这是一个双端队列,它是线程私有的;
(3)ForkJoinTask中fork的子任务,将放入运行该任务的工做线程的队头,工做线程将以LIFO的顺序来处理工做队列中的任务;
(4)为了最大化地利用CPU,空闲的线程将从其它线程的队列中“窃取”任务来执行;
(5)从工做队列的尾部窃取任务,以减小竞争;
(6)双端队列的操做:push()/pop()仅在其全部者工做线程中调用,poll()是由其它线程窃取任务时调用的;
(7)当只剩下最后一个任务时,仍是会存在竞争,是经过CAS来实现的;
(1)最适合的是计算密集型任务,本文由公从号“彤哥读源码”原创;
(2)在须要阻塞工做线程时,可使用ManagedBlocker;
(3)不该该在RecursiveTask<r>的内部使用ForkJoinPool.invoke()/invokeAll();
(1)ForkJoinPool特别适合于“分而治之”算法的实现;
(2)ForkJoinPool和ThreadPoolExecutor是互补的,不是谁替代谁的关系,两者适用的场景不一样;
(3)ForkJoinTask有两个核心方法——fork()和join(),有三个重要子类——RecursiveAction、RecursiveTask和CountedCompleter;
(4)ForkjoinPool内部基于“工做窃取”算法实现;
(5)每一个线程有本身的工做队列,它是一个双端队列,本身从队列头存取任务,其它线程从尾部窃取任务;
(6)ForkJoinPool最适合于计算密集型任务,但也可使用ManagedBlocker以便用于阻塞型任务;
(7)RecursiveTask内部能够少调用一次fork(),利用当前线程处理,这是一种技巧;
ManagedBlocker怎么使用?
答:ManagedBlocker至关于明确告诉ForkJoinPool框架要阻塞了,ForkJoinPool就会启另外一个线程来运行任务,以最大化地利用CPU。
请看下面的例子,本身琢磨哈^^。
/** * 斐波那契数列 * 一个数是它前面两个数之和 * 1,1,2,3,5,8,13,21 */ public class Fibonacci { public static void main(String[] args) { long time = System.currentTimeMillis(); Fibonacci fib = new Fibonacci(); int result = fib.f(1_000).bitCount(); time = System.currentTimeMillis() - time; System.out.println("result,本文由公从号“彤哥读源码”原创 = " + result); System.out.println("test1_000() time = " + time); } public BigInteger f(int n) { Map<Integer, BigInteger> cache = new ConcurrentHashMap<>(); cache.put(0, BigInteger.ZERO); cache.put(1, BigInteger.ONE); return f(n, cache); } private final BigInteger RESERVED = BigInteger.valueOf(-1000); public BigInteger f(int n, Map<Integer, BigInteger> cache) { BigInteger result = cache.putIfAbsent(n, RESERVED); if (result == null) { int half = (n + 1) / 2; RecursiveTask<BigInteger> f0_task = new RecursiveTask<BigInteger>() { @Override protected BigInteger compute() { return f(half - 1, cache); } }; f0_task.fork(); BigInteger f1 = f(half, cache); BigInteger f0 = f0_task.join(); long time = n > 10_000 ? System.currentTimeMillis() : 0; try { if (n % 2 == 1) { result = f0.multiply(f0).add(f1.multiply(f1)); } else { result = f0.shiftLeft(1).add(f1).multiply(f1); } synchronized (RESERVED) { cache.put(n, result); RESERVED.notifyAll(); } } finally { time = n > 10_000 ? System.currentTimeMillis() - time : 0; if (time > 50) System.out.printf("f(%d) took %d%n", n, time); } } else if (result == RESERVED) { try { ReservedFibonacciBlocker blocker = new ReservedFibonacciBlocker(n, cache); ForkJoinPool.managedBlock(blocker); result = blocker.result; } catch (InterruptedException e) { throw new CancellationException("interrupted"); } } return result; // return f(n - 1).add(f(n - 2)); } private class ReservedFibonacciBlocker implements ForkJoinPool.ManagedBlocker { private BigInteger result; private final int n; private final Map<Integer, BigInteger> cache; public ReservedFibonacciBlocker(int n, Map<Integer, BigInteger> cache) { this.n = n; this.cache = cache; } @Override public boolean block() throws InterruptedException { synchronized (RESERVED) { while (!isReleasable()) { RESERVED.wait(); } } return true; } @Override public boolean isReleasable() { return (result = cache.get(n)) != RESERVED; } } }
欢迎关注个人公众号“彤哥读源码”,查看更多源码系列文章, 与彤哥一块儿畅游源码的海洋。