Executor 初探
0x00 Executor
Executor的主要功能如下:
- 无需创建任何的
Thread
对象,如果想执行一个线程的话,仅需要将这个线程任务实例(一个实现了Runnable
接口的类)递交给executor
执行即可,executor
将会自主管理线程并执行任务。 executor
会自动复用线程,类似于数据库连接池一样,其也会创建一个线程池,线程池里面有规划好的预先启动的线程(worker-threads
),新任务到达之后自动分配给线程池中的线程执行,在没有任务时线程池中的线程处于一种等待的状态。这样就避免了每次都要重新创建线程的时间消耗。- 易实现对
executor
内部资源的控制,我们可以在创建executor
时指定内部线程池的线程数量,如果有超出线程数的任务被安排进来了,其会首先进入一个任务队列,executor
中的工作线程完成一项线程任务之后,将会自动从这个任务队列中取出任务来执行。 - 必需使用代码显式结束
executor
的运行。这里一般是指调用executor
的shutdown
方法来结束线程池的运行。否则线程池中的线程将会处于一直等待的状态,在Java中主线程会等待所有的子线程结束之后才会自动结束,如果不手动调用shutdown
方法来结束线程池的运行,那么到最后主线程是不会主动退出的,会一直执行下去。所以一定不要忘了手动结束线程池的运行。
0x01 kNN算法(单线程串行版)
kNN算法是一个机器学习中的分类算法,为了简要地实现一下多线程编程,我们首先使用单线程来实现一个kNN算法。
kNN算法的基本流程可参考我写过的另一篇文章:kNN
我们这里使用的数据集为Bank Marketing Data Set,这是一家银行通过不断地电话采访获得的数据,其中包含被采访人的一些个人信息(比如年龄、工作以及婚姻状况等)以及一些银行账户信息(比如账户余额、是否负债、是否有信用卡违约记录等),我们需要使用这些信息来预测这个人是否是银行定期存款(term deposit)的潜在客户。
我们现在旨在学习Java并发相关知识,所以对于源数据我们直接用,不做过多但是必要的处理。同时对于距离的计算,我们简化一下其流程,我们直接把每个字段读入,也不管他是什么类型了,我们就直接取其字符串中每个字符的ASCII码加和的值作为其数值化之后的结果,然后计算其欧几里得距离。这样下来,原始40000行数据依旧会产生较大的计算量,可以明显地分辨出并行化前和并行化后的运行时间的差距来。但是,机器学习里有一句著名的话叫做garbage in garbage out,这样下来预测的精确度和投硬币的概率差不多,因为我们输出的是bool
值,这样下来预测的精确度也在50%左右,所以就是投硬币嘛!
过多但是必要的处理指的是使用随机森林填充原始数据集中标记为unknown的数据字段。这个过程是必要的,因为其中的default字段(
bool
值表示是否有信用卡违约记录)有超过8000行的缺失。
首先我们直接给出其单线程串行版的算法,代码如下:
1 | public boolean serialPredict(ArrayList<String> input) { |
我们使用如下代码来统计算法耗时:
1 | KNN knn = new KNN("TRAIN SET FILE PATH HERE", 10); |
单线程版的算法耗时为142s.
0x02 kNN算法(多线程并行版)
多线程这个算法,最简单的办法就是将原始数据集进行分组,每个线程执行一个分组内的任务,最后等待所有的线程全部结束之后,将所有线程产生的距离进行进行整合和排序。
由于所有的线程都是对原始数据集进行只读操作,并不进行写入或修改操作,那我们就没有必要对原始数据集的访问进行加锁,这样还可以提高程序的运行效率。同时我们还可以开辟一个distance
数组,用以记录输入数据对已知数据的距离信息,我们在进行多线程处理的时候,可以对这个数组进行分片,每个线程只能访问这个数组中的一个区间,且保证所有线程访问的区间并不重叠,而且我们保证对数组整体的排序和读取是在所有子线程全部完成之后才开始进行的,这样并发对这个数组写入也不需要加锁。
此外,将所有的距离计算任务平均划分给多个线程进行并行计算,而对距离的排序则需要等到所有线程的计算任务完成之后再统一排序,这就需要引入CountDownLatch
倒计时器,这个倒计时器中有2个核心方法,一个是countDown()
方法,用于对倒计时减一。另一个是await()
方法,用于阻塞一个线程直到倒计时器中的倒计时归零为止。CountDownLatch
在构造的时候需要指定一个初始化的倒计时值,假设我们的数据集中有40000组数据,我们就可以将这个CountDownLatch
的初始倒计时值设置为40000,每完成对一组数据的距离值计算,然后调用countDown()
方法对这个倒计时值减一。在主线程中使用await()
方法进行等待,当计时器的值归零后,对所有距离值进行排序,然后投票出最终的预测值。
由此,我们可以有如下代码:
1 | public boolean parallelPredict(ArrayList<String> input) { |
多线程版性能得到了大幅度的提升,大约81s就完成了所有的计算。性能提升42.95%由此可见还是非常明显的。
全部的代码可以参加我的GitHub.