《Java7并发编程实战手册》学习笔记(六)——Fork/Join框架

此篇博客为我的学习笔记,若有错误欢迎你们指正

本次内容

  1. 建立Fork/Join线程池
  2. 合并任务的结果
  3. 异步运行任务
  4. 在任务中抛出异常
  5. 取消任务

Java为咱们提供了ExecutorService接口的另外一种实现——Fork/Join框架(分解/合并框架),这个框架能帮助咱们更简单的用分治技术解决问题。使用Fork/Join框架,在执行一个任务时,咱们首先判断这个任务的规模是否大于咱们制定的标准,若是大于就将这个任务分解(fork)为规模更小的任务去执行;最后再将执行完成小任务层层合并(join)为大任务并返回,原理图以下:java

Fork/Join框架与咱们以前使用的执行器框架的主要区别在于前者实现了 工做窃取算法。当咱们使用 join()方法使一个主任务等待它所建立的子任务完成时,执行任务的线程(工做者线程)并不会由于等待其余任务的完成而进入休眠状态,而是随机的去其余线程所维护的双端队列末尾取出一个任务来执行,这就极大的提高了工做效率。固然,为了达到上述目标,在使用Fork/Join框架时有如下限制:

  • 任务只能使用fork()join()等一些专门为Fork/Join框架准备的方法进行同步。若是使用了其余同步机制,工做者线程会真正的进入阻塞状态而且不会窃取其余线程的任务来执行
  • 任务不能执行I/O操做
  • 任务不能够抛出非运行时异常

Fork/Join框架的核心是由如下两个类组成的:算法

  • ForkJoinPool:这个类也实现了ExecutorExecutorService接口,和咱们以前使用过的ThreadPoolExecutor类有些相似,主要区别在于这个类实现了工做窃取算法。获取ForkJoinPool对象的方法主要有如下几种,咱们能够根据不一样的需求进行选择:数组

    1. 首先是ForkJoinPool类的构造方法:app

      • ForkJoinPool():无参构造方法,调用此方法得到的ForkJoinPool对象将执行默认的配置。其并行级别为当前JVM能够调用的CPU内核数量
      • ForkJoinPool(int parallelism):经过这个构造方法能够指定线程池的并行级别,可是咱们传入的参数应该是大于0且小于等于JVM能够调用的CPU内核数量的
      • ForkJoinPool(int parallelism, ForkJoinWorkerThreadFactory factory, UncaughtExceptionHandler handler, boolean asyncMode) 此方法参数较多,下面会分别记录:
        1. parallelism:并行级别。Fork/Join框架将根据这个参数来设定框架内并行执行的线程数。注意,并非框架中最大的线程数量
        2. factory:线程工厂。咱们能够编写本身的Fork/Join县城工厂类,和之构造的线程工厂不一样。构造Fork/Join线程工厂类须要咱们实现ForkJoinWorkerThreadFactory接口而不是ThreadFactory接口
        3. handler:异常捕获处理器。当执行中的任务向上抛出异常时,就会被处理器捕获
        4. asyncMode:工做模式。Fork/Join框架中的每个工做线程都维护着一个双端队列用于装载任务,参数为true则表示队列中的任务先进先出(FIFO),为false则表示后进先出(LIFO)
    2. ForkJoinPool类的静态方法commonPool()一样能够得到ForkJoinPool对象。值得注意的是,调用此方法得到的是Java预约义的线程池,这能够减小资源的消耗由于咱们再也不须要每提交一个任务就建立一个新的线程池了,也就是说每次咱们调用此方法所得到的对象引用实际上都指向同一个线程池,能够发现执行下面的代码打印出的值将为true框架

      ForkJoinPool forkJoinPool1 = ForkJoinPool.commonPool();
          ForkJoinPool forkJoinPool2 = ForkJoinPool.commonPool();
          System.out.println(forkJoinPool1 == forkJoinPool2);
      复制代码

      另外,在调用经此方法得到的ForkJoinPool对象的shutdown()方法时,线程池并不会关闭dom

    3. 使用Executors类的静态方法newWorkStealingPool()或者此方法的另外一种实现newWorkStealingPool(int parallelism)异步

  • ForkJoinTask:此类实现了Future接口,是在ForkJoinPool中执行的任务的基类。为了使用Fork/Join框架执行任务,一般状况下咱们须要实现如下两个ForkJoinTask子类的其中一个async

    • RecursiveAciton:用于任务没有返回结果的场景
    • RecursiveTask:用于任务有返回结果的场景

    在继承上面两个类后,咱们最好在本身的类中加上这样一个属性:
    private static final long serialVersionUID = 1L;
    这是由于RecursiveActionRecursiveTask类均继承了ForkJoinTask类,而ForkJoinTask类又实现了Serializable接口。若是咱们不显示的声明这个属性,那么Java会根据当前类的属性、方法给出一个默认值。当咱们修改了类的属性或方法后,这个值会发生变化。这样一来,咱们在将修改以前进行过序列化的类进行反序列化时就会出现错误。因此咱们最好显示的声明这一属性。ide

1.建立Fork/Join线程池

使用Fork/Join框架,咱们最好参考JavaAPI手册为咱们推荐的代码结构学习

if (problem size > default size) {
    tasks = divide(task);
    execute(tasks);
} else {
    resolve problem using another algorthm;
}
复制代码

下面是在此小节中须要了解的方法:

  • ForkJoinPool类:
    1. execute(ForkJoinTask<?> task):无返回值。调用此方法向线程池提交一个任务,注意这个方法是异步的,调用后线程不会等待而是直接向下执行。execute(Runnable task)是另外一种实现,提交一个Runnable类型的任务给线程池,在这种状况下线程池不会使用工做窃取算法
    2. invoke(ForkJoinTask<T> task):此方法最好和execute(ForkJoinTask<?> task)方法对比来看。区别在与这个方法是同步的,调用后会直到任务执行结束后才返回。返回值即为任务返回的结果
    3. 由于ForkJoinPool类实现了ExecutorService接口,因此也实现了invokeAll()invokeAny()方法。这些方法以前都已经使用过,参数为Callable类型的任务列表。可是当咱们向ForkJoinPool发送Runnable或Callable类型的任务时,线程池并不会使用工做窃取算法,所以咱们不推荐这样作
  • ForkJoinTask类:
    1. adapt():传入一个Runnable或Callable对象,返回一个ForkJoinTask对象
    2. invokeAll():传入ForkJoinTask对象列表或数个ForkJoinTask对象。这个方法是同步的,当主任务在等待子任务时,执行主任务的工做线程会开始执行另外一个等待执行的任务。值得注意的是,因传入参数不一样这个方法的返回值也有所区别。直接传入ForkJoinTask对象的话此方法没有返回值;传入ForkJoinTask对象列表的话返回值也为传入的ForkJoinTask对象列表,而且通过调试咱们能够发现传入和返回的两个列表对象的引用实际是指向同一个对象

范例实现

在这个范例中,咱们将对全部商品使用分治技术进行涨价操做。因为任务不须要有返回值,咱们的任务类继承了RecursiveAciton
商品类:

package day06.code_1;

public class Product {

    //商品名称
    private String name;

    //商品价格
    private double price;

    public String getName() {
        return name;
    }

    public void setName(String name) {
        this.name = name;
    }

    public double getPrice() {
        return price;
    }

    public void setPrice(double price) {
        this.price = price;
    }
}
复制代码

商品列表生成类:

package day06.code_1;

import java.util.ArrayList;
import java.util.List;

public class ProductListGenerator {

    //根据传入的大小建立一个产品集合
    public List<Product> generate(int size) {
        //建立一个集合
        ArrayList<Product> products = new ArrayList<>();
        for (int i = 0; i < size; i++) {
            //建立产品
            Product product = new Product();
            //设置名字
            product.setName("Product " + i);
            //统一设置初始价格为10,方便检查程序的正确性
            product.setPrice(10);
            //装入集合
            products.add(product);
        }
        //返回集合
        return products;
    }

}
复制代码

任务类:

package day06.code_1;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.RecursiveAction;

public class Task extends RecursiveAction {

    //必备参数
    private static final long serialVersionUID = 1L;

    //产品集合
    private List<Product> products;

    //起始位置
    private int first;

    //终止位置
    private int last;

    //价格增百分比
    private double increment;

    public Task(List<Product> products, int first, int last, double increment) {
        this.products = products;
        this.first = first;
        this.last = last;
        this.increment = increment;
    }

    @Override
    protected void compute() {
        //若是任务数量小于10
        if (last - first < 10) {
            //执行涨价操做
            updatePrices();
        } else {
            //若是任务数量大于10则将任务均分
            int middle = (first + last) / 2;
            //打印分割任务提示语
            System.out.printf("Task: Pending tasks:%s\n",
                    getQueuedTaskCount());
            //根据新分配的范围建立两个任务
            Task t1 = new Task(products, first, middle + 1, increment);
            Task t2 = new Task(products, middle + 1, last, increment);
            //执行
            invokeAll(t1, t2);
        }
    }

    private void updatePrices() {
        //遍历集合为每个商品作涨价操做
        for (int i = first; i < last; i++) {
            Product product = products.get(i);
            product.setPrice(product.getPrice() * (1 + increment));
        }
    }
}
复制代码

main方法:

package day06.code_1;

import java.util.List;
import java.util.concurrent.ForkJoinPool;

public class Main {

    public static void main(String[] args) {
        //建立产品生成对象
        ProductListGenerator generator = new ProductListGenerator();
        //经过产品生成器获得大小为10000的产品集合
        List<Product> products = generator.generate(10000);
        //建立一个任务
        Task task = new Task(products, 0, 10000, 0.20);
        //建立线程池
        ForkJoinPool pool = new ForkJoinPool();
        //调用线程池的方法执行任务
        pool.execute(task);
        do {
            //打印线程池中当前正在执行任务的线程数量
            System.out.printf("Main: Thread Count: %d\n",
                    pool.getActiveThreadCount());
            //打印线程池中窃取的工做数量
            System.out.printf("Main: Thread Steal: %d\n",
                    pool.getStealCount());
            //打印线程池的并行级别
            System.out.printf("Main: Parallelism: %d\n",
                    pool.getParallelism());
            //休眠5秒
            try {
                Thread.sleep(5);
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
            //等待任务结束
        } while (!task.isDone());
        //关闭线程池
        pool.shutdown();
        //判断任务是否抛出了异常
        if (task.isCompletedNormally()) {
            //打印任务无异常完成的提示信息
            System.out.printf("Main: The process has completed normally\n");
        }
        //检查商品是否已正确涨价
        for (int i = 0; i < products.size(); i++) {
            Product product = products.get(i);
            if (product.getPrice() != 12) {
                System.out.printf("Product %s: %f\n",
                        product.getName(), product.getPrice());
            }
        }
        //打印程序结束提示语
        System.out.println("Main: End of the program\n");
    }

}
复制代码

2.合并任务的结果

使用Fork/Join框架执行带有返回值的任务时必须继承RecursiveTask类并使用JavaAPI文档推荐的结构:

if (problem size > default size) {
    tasks = Divide(task);
    execute(tasks);
    groupResults();
    return results;
} else {
    resolve problem;
    return result;
}
复制代码

如下几个ForkJoinTask类中的方法咱们须要了解:

  1. fork():无参数、返回值。此方法用于向线程池异步的发送一个任务,发送完成后将会马上返回并向下执行
  2. get():一直等待直到得到任务返回的结果。另外一种实现为get(long timeout, TimeUnit unit),若是等待时间超时后任务还未返回结果,则方法直接返回null。get方法能够被中断。若是任务抛出运行时异常,get方法会返回ExecutionException异常
  3. join():一直等待直到得到任务返回的结果。此方法和get()方法有些相似,区别在于join()方法不能被中断。若是中断调用了该方法的线程,join()方法将抛出InterruptedException异常。另外,任务抛出运行时异常时,join()方法会返回RuntimeWxception异常
    以上三个方法中,第一个与第二或三个方法组合常常用来实现异步运行任务这一需求

范例实现

在这个范例中,咱们将统计一个指定词汇在文档中出现的次数。咱们会不断切割任务直到每一个任务仅搜索100个之内的词汇
DocumentMock(文档生成类):

package day06.code_2;


import java.util.Random;

public class DocumentMock {

    //从如下词汇中选择词语组成文档
    private String words[] = {
            "the", "hello", "goodbye", "packt", "java",
            "thread", "pool", "random", "class", "main"
    };

    public String[][] generateDocument(int numLines, int numWords, String word) {
        //记录指定词汇出现的次数,便于后期判断程序对错
        int counter = 0;
        //建立二维数组
        String[][] document = new String[numLines][numWords];
        //随机数生成器
        Random random = new Random();
        //填充数组
        for (int i = 0; i < numLines; i++) {
            for (int j = 0; j < numWords; j++) {
                //随机选取词汇并填充
                int index = random.nextInt(words.length);
                document[i][j] = words[index];
                //若是是指定词汇,计数器加一
                if (document[i][j] == word) {
                    counter++;
                }

            }
        }
        //打印指定词汇出现的次数
        System.out.printf("DocumentMock: The word appears " +
                "%d times in the document\n", counter);
        //返回文档
        return document;
    }

}
复制代码

DocumentTask(文档任务类):

package day06.code_2;

import java.util.ArrayList;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.RecursiveTask;

public class DocumentTask extends RecursiveTask<Integer> {

    //必备参数
    private static final long serialVersionUID = 1L;

    //文档
    private String[][] document;

    //起始、结束位置
    private int start, end;

    //待查找的词汇
    private String word;

    public DocumentTask(String[][] document, int start, int end, String word) {
        this.document = document;
        this.start = start;
        this.end = end;
        this.word = word;
    }

    @Override
    protected Integer compute() {
        //初始化计数器
        int result = 0;
        //若是行数小于10
        if (end - start < 10) {
            //处理每一行的数据
            result = processLines(document, start, end, word);
        } else {
            //行数大于10则进行任务分割
            int mid = (start + end) / 2;
            DocumentTask task1 = new DocumentTask(document, start, mid, word);
            DocumentTask task2 = new DocumentTask(document, mid, end, word);
            //提交任务(同步)
            invokeAll(task1, task2);
            try {
                //处理子任务返回的结果
                result = groupResults(task1.get(), task2.get());
            } catch (ExecutionException | InterruptedException e) {
                e.printStackTrace();
            }
        }
        //返回结果
        return result;
    }

    //将子任务结果相加后返回
    private int groupResults(Integer number1, Integer number2) {
        return number1 + number2;
    }

    private int processLines(String[][] document, int start, int end, String word) {
        //建立装载行任务的集合
        ArrayList<LineTask> tasks = new ArrayList<>();
        //建立行任务
        for (int i = start; i < end; i++) {
            LineTask task = new LineTask(document[i], 0, document[i].length, word);
            tasks.add(task);
        }
        //执行全部任务
        invokeAll(tasks);
        //初始化计数器
        int result = 0;
        //从任务中获取结果
        for (int i = 0; i < tasks.size(); i++) {
            LineTask task = tasks.get(i);
            try {
                result = result + task.get();
            } catch (ExecutionException | InterruptedException e) {
                e.printStackTrace();
            }
        }
        //返回结果
        return result;
    }

}
复制代码

LineTask(单行任务类):

package day06.code_2;

import java.util.concurrent.ExecutionException;
import java.util.concurrent.RecursiveTask;

public class LineTask extends RecursiveTask<Integer> {

    //必备参数
    private static final long serialVersionUID = 1L;

    //行数据
    private String line[];

    //起始、结束位置
    private int start, end;

    //待查找的词汇
    private String word;

    public LineTask(String[] line, int start, int end, String word) {
        this.line = line;
        this.start = start;
        this.end = end;
        this.word = word;
    }

    @Override
    protected Integer compute() {
        //初始化计数器
        int result = 0;
        //若是一行的数据小于100
        if (end - start < 100) {
            //查找指定词汇的数量
            result = count(line, start, end, word);
        } else {
            //分割任务
            int mid = (start + end) / 2;
            LineTask task1 = new LineTask(line, start, mid, word);
            LineTask task2 = new LineTask(line, mid, end, word);
            //执行
            invokeAll(task1, task2);
            //获取子任务的结果
            try {
                result = groupResults(task1.get(), task2.get());
            } catch (ExecutionException | InterruptedException e) {
                e.printStackTrace();
            }
        }
        return result;
    }

    //将子任务结果相加后返回
    private Integer groupResults(Integer number1, Integer number2) {
        return number1 + number2;
    }

    private int count(String[] line, int start, int end, String word) {
        //初始化计数器
        int counter = 0;
        //查找每个元素是否为指定的词汇
        for (int i = start; i < end; i++) {
            if (line[i].equals(word)) {
                counter++;
            }
        }
        //休眠10毫秒
        try {
            Thread.sleep(10);
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
        //返回结果
        return counter;
    }
}
复制代码

main方法:

package day06.code_2;

import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.TimeUnit;

public class Main {

    public static void main(String[] args) {
        //建立文档生成器
        DocumentMock mock = new DocumentMock();
        //生成文档
        String[][] document = mock.generateDocument(100, 1000, "the");
        //建立文档搜素任务
        DocumentTask task = new DocumentTask(document, 0, 100, "the");
        //建立线程池
        ForkJoinPool pool = new ForkJoinPool();
        //异步执行文档搜索任务
        pool.execute(task);
        //每隔一秒打印一次线程池的状态直到任务执行结束
        do {
            System.out.println("****************************************");
            //并行级别
            System.out.printf("Main: Parallelism: %d\n",
                    pool.getParallelism());
            //正在工做的线程
            System.out.printf("Main: Active Threads: %d\n",
                    pool.getActiveThreadCount());
            //已提交的任务数量(不包括还没有执行的)
            System.out.printf("Main: Task Count: %d\n",
                    pool.getQueuedTaskCount());
            //窃取工做的数量
            System.out.printf("Main: Steal Count: %d\n",
                    pool.getStealCount());
            System.out.println("****************************************");
            try {
                TimeUnit.SECONDS.sleep(1);
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
        } while (!task.isDone());
        //关闭线程池
        pool.shutdown();
        //打印待查找关键词的数量
        try {
            System.out.printf("Main: The word appears %d in the document",
                    task.get());
        } catch (ExecutionException | InterruptedException e) {
            e.printStackTrace();
        }
    }

}
复制代码

3.异步运行任务

当咱们采用异步的方式向线程池发送任务时,方法将当即返回,代码也将继续向下执行,不过咱们提交的任务会继续执行。在第二小节中咱们已经将异步运行任务的相关方法记录了,就不在此赘述

范例实现

在这个范例中咱们将查找指定的文件夹内是否有咱们要查找的文件
FolderProcessor类(文件查找任务类):

package day06.code_3;

import java.io.File;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.RecursiveTask;

public class FolderProcessor extends RecursiveTask<List<String>> {

    //必备参数
    private static final long serialVersionUID = 1L;

    //文件夹路径
    private String path;

    //文件后缀名
    private String extension;

    public FolderProcessor(String path, String extension) {
        this.path = path;
        this.extension = extension;
    }

    @Override
    protected List<String> compute() {
        //建立一个集合用于装载文件路径
        ArrayList<String> list = new ArrayList<>();
        //建立集合用于装载任务
        ArrayList<FolderProcessor> tasks = new ArrayList<>();
        //建立文件对象
        File file = new File(path);
        //获得文件夹下的所有文件
        File[] content = file.listFiles();
        //判断是否为空
        if (content != null) {
            //遍历集合
            for (int i = 0; i < content.length; i++) {
                //若是是文件夹就建立任务继续查找
                if (content[i].isDirectory()) {
                    FolderProcessor task = new FolderProcessor
                            (content[i].getAbsolutePath(), extension);
                    //异步执行任务
                    task.fork();
                    //将任务保存进集合
                    tasks.add(task);
                } else {
                    //检查文件是否符合要求,符合的话就装入集合
                    if (checkFile(content[i].getName())) {
                        list.add(content[i].getAbsolutePath());
                    }
                }
            }
        }
        //若是文件集合容量超过50了就打印
        if (tasks.size() > 50) {
            System.out.printf("%s: %d tasks run\n",
                    file.getAbsolutePath(), tasks.size());
        }
        //整合子任务返回的结果
        addResultsFromTasks(list, tasks);
        //返回结果
        return list;
    }

    private void addResultsFromTasks(List<String> list, List<FolderProcessor> tasks) {
        //遍历任务集合
        for (FolderProcessor item : tasks) {
            //取得全部子任务返回的结果并装进集合中
            list.addAll(item.join());
        }
    }

    //检查文件后缀名是否符合要求
    private boolean checkFile(String name) {
        return name.endsWith(extension);
    }
}
复制代码

main方法:

package day06.code_3;

import java.util.List;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.TimeUnit;

public class Main {

    public static void main(String[] args) {
        //建立线程池
        ForkJoinPool pool = new ForkJoinPool();
        //建立三个任务并异步执行
        FolderProcessor system = new FolderProcessor("C:\\", "exe");
        FolderProcessor program = new FolderProcessor("D:\\", "exe");
        FolderProcessor data = new FolderProcessor("F:\\", "exe");
        pool.execute(system);
        pool.execute(program);
        pool.execute(data);
        //在任务没有都结束以前不断循环打印线程池的信息
        do {
            System.out.println("***************************************");
            System.out.printf("Main: Parallelism: %d\n",
                    pool.getParallelism());
            System.out.printf("Main: Active Threads: %d\n",
                    pool.getActiveThreadCount());
            System.out.printf("Main: Task Count: %d\n",
                    pool.getQueuedTaskCount());
            System.out.printf("Main: Steal Count: %d\n",
                    pool.getStealCount());
            System.out.println("***************************************");
            try {
                TimeUnit.SECONDS.sleep(1);
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
        } while ((!system.isDone()) || (!program.isDone()) || (!data.isDone()));
        //关闭线程池
        pool.shutdown();
        //获取并打印每个任务返回的结果
        List<String> result;
        result = system.join();
        System.out.printf("System: %d files found\n", result.size());
        result = program.join();
        System.out.printf("Program: %d files found\n", result.size());
        result = data.join();
        System.out.printf("Data: %d files found\n", result.size());
    }

}
复制代码

4.在任务中抛出异常

ForkJoinTask类的compute()方法中不容许抛出非运行时异常,可是咱们仍能够抛出运行时异常。然而,当任务抛出运行时异常时,ForkJoinPoolForkJoinTask类的行为和咱们期待的并不相同。程序不会结束运行,异常信息也不会打印出来。只有当咱们去获取任务的结果时,异常才会抛出。须要注意的是,当子任务抛出异常时,它的父任务也会受到影响。如下ForkJoinTask类中的几个方法会对咱们获取异常信息有必定帮助:

  1. isCompletedAbnormally():若是主任务或它的子任务抛出了异常,此方法将返回true
  2. isCompletedNormally():若是主任务及它的子任务均正常完成了,此方法返回true
  3. getException():调用此方法来得到任务抛出的异常对象

范例实现

在这个范例中,咱们将对一个数组进行搜索。搜索任务中若是包含了索引3,则抛出运行时异常
Task类:

package day06.code_4;

import java.util.concurrent.RecursiveTask;
import java.util.concurrent.TimeUnit;

public class Task extends RecursiveTask<Integer> {

    //数组
    private int[] array;

    //起始、终止位置
    private int start, end;

    public Task(int[] array, int start, int end) {
        this.array = array;
        this.start = start;
        this.end = end;
    }

    @Override
    protected Integer compute() {
        //打印搜索范围的信息
        System.out.printf("Task: Start from %d to %d\n",
                start, end);
        //若是搜索范围小于10
        if (end - start < 10) {
            //判断是否包含索引三
            if ((3 > start) && (3 < end)) {
                //抛出运行时异常
                throw new RuntimeException("This task throws an Exception: " +
                        "Task from " + start + " to " + end);
            }
            //休眠1秒
            try {
                TimeUnit.SECONDS.sleep(1);
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
        } else {
            //分割任务
            int mid = (start + end) / 2;
            Task task1 = new Task(array, start, mid);
            Task task2 = new Task(array, mid, end);
            //执行
            invokeAll(task1, task2);
        }
        //打印任务结束语
        System.out.printf("Task: End from %d to %d\n", start, end);
        return 0;

    }
}
复制代码

main方法:

package day06.code_4;

import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.TimeUnit;

public class Main {

    public static void main(String[] args) {
        //建立数组
        int[] array = new int[100];
        //建立任务
        Task task = new Task(array, 0, 100);
        //建立线程池
        ForkJoinPool pool = new ForkJoinPool();
        //执行任务
        pool.execute(task);
        //关闭线程池
        pool.shutdown();
        //休眠,直至线程池中的任务所有完成
        try {
            pool.awaitTermination(1, TimeUnit.DAYS);
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
        //判断任务是否存在异常
        if (task.isCompletedAbnormally()) {
            //打印异常提示语
            System.out.printf("Main: An exception has ocurred\n");
            //打印获取到的异常对象
            System.out.printf("Main: %s\n", task.getException());
        }
        //打印任务结果
        System.out.printf("Main: Result: %d", task.join());
    }
}
复制代码

5.取消任务

ForkJoinTask类提供了cancel(boolean mayInterruptIfRunning)方法来达到取消任务的目的。和以前咱们用到过的FutureTask类不一样的是,ForkJoinTask类的cancel()方法只能取消未被执行的任务。JavaAPI文档指出,在ForkJoinTask类的默认实现中,传入的参数并无起到做用,这就致使已经开始执行和已经执行结束的任务都不能被取消。取消成功返回true,不然返回false。另外,ForkJoinPool类中并无提供任务用于取消任务的方法。

范例实现

在这个范例中,咱们将在数组中寻找一个数字,找到后就取消其余的搜索任务。
ArrayGenerator(数组生成类):

package day06.code_5;

import java.util.Random;

public class ArrayGenerator {

    public int[] generateArray(int size) {
        //根据传入的参数生成一个数组
        int[] array = new int[size];
        //建立随机数生成器对象
        Random random = new Random();
        //对数组进行初始化
        for (int i = 0; i < size; i++) {
            array[i] = random.nextInt(10);
        }
        //返回数组
        return array;
    }

}
复制代码

TaskManager(任务管理类,该类将帮助咱们取消其余任务):

package day06.code_5;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ForkJoinTask;

public class TaskManager {

    //任务集合
    private List<ForkJoinTask<Integer>> tasks;

    public TaskManager() {
        tasks = new ArrayList<>();
    }

    //向集合中添加任务
    public void addTask(ForkJoinTask<Integer> task) {
        tasks.add(task);
    }

    public void cancelTasks(ForkJoinTask<Integer> cancelTask) {
        //取消除传入的任务之外的其余全部任务
        for (ForkJoinTask<Integer> task : tasks) {
            if (task != cancelTask) {
                //取消任务
                task.cancel(true);
                //打印取消信息
                ((SearchNumberTask) task).writeCancelMessage();
            }
        }
    }
}
复制代码

SearchNumberTask(搜索数字任务类):

package day06.code_5;


import java.util.concurrent.RecursiveTask;
import java.util.concurrent.TimeUnit;

public class SearchNumberTask extends RecursiveTask<Integer> {

    //待搜索的数组
    private int[] numbers;

    //搜索范围
    private int start, end;

    //目标数字
    private int number;

    //任务管理器
    private TaskManager manager;

    //未查询到目标数字时返回的常量
    private static final int NOT_FOUND = -1;

    //必要参数
    private static final long serialVersionUID = 1L;

    public SearchNumberTask(int[] numbers, int start, int end, int number, TaskManager manager) {
        this.numbers = numbers;
        this.start = start;
        this.end = end;
        this.number = number;
        this.manager = manager;
    }

    @Override
    protected Integer compute() {
        //打印任务开始提示信息
        System.out.printf("Task: %d : %d\n", start, end);
        int ret;
        //若是搜索范围大于10
        if (end - start > 10) {
            //调用切割任务的方法
            ret = launchTasks();
        } else {
            //查找目标数字
            ret = lookForNumber();
        }
        //返回结果
        return ret;
    }

    private int launchTasks() {
        //切割任务
        int mid = (start + end) / 2;
        //建立两个新的任务在将其加入任务集合后执行
        SearchNumberTask task1 = new SearchNumberTask(numbers, start, mid, number, manager);
        SearchNumberTask task2 = new SearchNumberTask(numbers, mid, end, number, manager);
        manager.addTask(task1);
        manager.addTask(task2);
        task1.fork();
        task2.fork();
        //返回值
        int returnValue;
        //获取任务1的结果
        returnValue = task1.join();
        //若是查询到了就返回索引
        if (returnValue != -1) {
            return returnValue;
        }
        //不然返回任务2的结果
        return task2.join();

    }

    private int lookForNumber() {
        //遍历搜索范围内的数组
        for (int i = start; i < end; i++) {
            //若是是目标数字
            if (numbers[i] == number) {
                //打印查找成功提示语
                System.out.printf("Task: Number %d found in position %d\n",
                        number, i);
                //调用任务管理器的方法取消其余任务
                manager.cancelTasks(this);
                //返回目标数字的索引
                return i;
            }
            //休眠1秒
            try {
                TimeUnit.SECONDS.sleep(1);
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
        }
        //没有查询到,返回常量
        return NOT_FOUND;
    }

    public void writeCancelMessage() {
        //打印任务取消的提示信息
        System.out.printf("Task: Cancelled task from %d to %d\n",
                start, end);
    }
}
复制代码

main方法:

package day06.code_5;

import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.TimeUnit;

public class Main {

    public static void main(String[] args) {
        //建立数组生成器
        ArrayGenerator generator = new ArrayGenerator();
        //获得一个容量为1000的数组
        int[] array = generator.generateArray(1000);
        //建立任务管理器
        TaskManager manager = new TaskManager();
        //建立线程池
        ForkJoinPool pool = new ForkJoinPool();
        //建立搜素数字任务
        SearchNumberTask task = new SearchNumberTask
                (array, 0, 1000, 5, manager);
        //将任务发送给线程池执行
        pool.execute(task);
        //关闭线程池
        pool.shutdown();
        //等待线程池将全部未取消的任务执行完毕
        try {
            pool.awaitTermination(1, TimeUnit.DAYS);
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
        //打印程序结束信息
        System.out.println("Main: The program has finished");
    }

}
复制代码
相关文章
相关标签/搜索