Alink 是阿里巴巴基于实时计算引擎 Flink 研发的新一代机器学习算法平台,是业界首个同时支持批式算法、流式算法的机器学习平台。本文将为你们展示Alink如何划分训练数据集和测试数据集。java
两分法算法
通常作预测分析时,会将数据分为两大部分。一部分是训练数据,用于构建模型,一部分是测试数据,用于检验模型。数组
三分法dom
但有时候模型的构建过程当中也须要检验模型/辅助模型构建,这时会将训练数据再分为两个部分:1)训练数据;2)验证数据(Validation Data)。因此这种状况下会把数据分为三部分。机器学习
Training set是用来训练模型或肯定模型参数的,如ANN中权值等;ide
Validation set是用来作模型选择(model selection),即作模型的最终优化及肯定,如ANN的结构;函数
Test set则纯粹是为了测试已经训练好的模型的推广能力。固然test set并不能保证模型的正确性,他只是说类似的数据用此模型会得出类似的结果。学习
实际应用测试
实际应用中,通常只将数据集分红两类,即training set 和test set,大多数文章并不涉及validation set。咱们这里也不涉及。你们经常使用的sklearn的train_test_split函数就是将矩阵随机划分为训练子集和测试子集,并返回划分好的训练集测试集样本和训练集测试集标签。优化
首先咱们给出示例代码,而后会深刻剖析:
public class SplitExample { public static void main(String[] args) throws Exception { String url = "iris.csv"; String schema = "sepal_length double, sepal_width double, petal_length double, petal_width double, category string"; //这里是批处理 BatchOperator data = new CsvSourceBatchOp().setFilePath(url).setSchemaStr(schema); SplitBatchOp spliter = new SplitBatchOp().setFraction(0.8); spliter.linkFrom(data); BatchOperator trainData = spliter; BatchOperator testData = spliter.getSideOutput(0); // 这里是流处理 CsvSourceStreamOp dataS = new CsvSourceStreamOp().setFilePath(url).setSchemaStr(schema); SplitStreamOp spliterS = new SplitStreamOp().setFraction(0.4); spliterS.linkFrom(dataS); StreamOperator train_data = spliterS; StreamOperator test_data = spliterS.getSideOutput(0); } }
SplitBatchOp是分割批处理的主要类,具体构建DAG的工做是在其linkFrom完成的。
整体思路比较简单:
numTarget = totCount * fraction
task_n_count * fraction
totSelect = task_1_count * fraction + task_2_count * fraction + ... task_n_count * fraction
numTarget - totSelect
加入到某一个task中。若是要分割数据,首先必须知道数据集的记录数。好比这个DataSet的记录是1万个?仍是十万个?由于数据集可能会很大,因此这一步操做也使用了并行处理,即把数据分区,而后经过mapPartition操做获得每个分区上元素的数目。
DataSet<Tuple2<Integer, Long>> countsPerPartition = DataSetUtils.countElementsPerPartition(rows); //返回哪一个task有哪些记录数 DataSet<long[]> numPickedPerPartition = countsPerPartition .mapPartition(new CountInPartition(fraction)) //计算总数 .setParallelism(1) .name("decide_count_of_each_partition");
由于每一个分区就对应了一个task,因此咱们也能够认为,这是获取了每一个task的记录数。
具体工做是在 DataSetUtils.countElementsPerPartition 中完成的。返回类型是<index of this subtask, record count in this subtask>,好比3号task拥有30个记录。
public static <T> DataSet<Tuple2<Integer, Long>> countElementsPerPartition(DataSet<T> input) { return input.mapPartition(new RichMapPartitionFunction<T, Tuple2<Integer, Long>>() { @Override public void mapPartition(Iterable<T> values, Collector<Tuple2<Integer, Long>> out) throws Exception { long counter = 0; for (T value : values) { counter++; //计算本task的记录总数 } out.collect(new Tuple2<>(getRuntimeContext().getIndexOfThisSubtask(), counter)); } }); }
计算总数的工做实际上是在下一阶段算子中完成的。
接下来的工做主要是在 CountInPartition.mapPartition 完成的,其做用是随机决定每一个task选择多少个记录。
这时候就不须要并行了,因此 .setParallelism(1)
获得了每一个分区记录数以后,咱们遍历每一个task的记录数,而后累积获得总记录数 totCount(就是从上而下计算出来的总数)。
public void mapPartition(Iterable<Tuple2<Integer, Long>> values, Collector<long[]> out) throws Exception { long totCount = 0L; List<Tuple2<Integer, Long>> buffer = new ArrayList<>(); for (Tuple2<Integer, Long> value : values) { //遍历输入的全部分区记录 totCount += value.f1; //f1是Long类型的记录数 buffer.add(value); } ... //后续代码在下面分析。 }
而后CountInPartition.mapPartition函数中会随机决定每一个task会选择的记录数。mapPartition的参数 Iterable<Tuple2<Integer, Long>> values 就是前一阶段的结果 :一个元祖<task id, 每一个task的记录数目>。
把这些元祖结合在一块儿,记录在buffer这个列表中。
buffer = {ArrayList@8972} size = 4 0 = {Tuple2@8975} "(3,38)" // 3号task,其对应的partition记录数是38个。 1 = {Tuple2@8976} "(2,0)" 2 = {Tuple2@8977} "(0,38)" 3 = {Tuple2@8978} "(1,74)"
系统的task数目就是buffer大小。
int npart = buffer.size(); // num tasks
而后,根据”记录总数“计算出来 “随机训练数据的个数numTarget”。好比总数1万,应该随机分配20%,因而numTarget就应该是2千。这个数字之后会用到。
long numTarget = Math.round((totCount * fraction));
获得每一个task的记录数目,好比是上面buffer中的 38,0,38,仍是74,记录在 eachCount 中。
for (Tuple2<Integer, Long> value : buffer) { eachCount[value.f0] = value.f1; }
获得每一个task中随机选中的训练记录数,记录在 eachSelect 中。就是每一个task目前 “记录数字 * fraction”。好比3号task记录数是38个,应该选20%,则38*20%=8个。
而后把这些task本身的“随机训练记录数”再累加起来获得 totSelect(就是从下而上计算出来的总数)。
long totSelect = 0L; for (int i = 0; i < npart; i++) { eachSelect[i] = Math.round(Math.floor(eachCount[i] * fraction)); totSelect += eachSelect[i]; }
请注意,这时候 totSelect 和 以前计算的numTarget就有具体细微出入了,就是理论上的一个数字,可是咱们 从上而下 计算 和 从下而上 计算,其结果可能不同。经过下面咱们能够看出来。
numTarget = all count * fraction totSelect = task_1_count * fraction + task_2_count * fraction + ...
因此咱们下一步要处理这个细微出入,就获得remain,这是"整体算出来的随机数目" numTarget 和 "从全部task选中的随机训练记录数累积" totSelect 的差。
if (totSelect < numTarget) { long remain = numTarget - totSelect; remain = Math.min(remain, totCount - totSelect);
若是恰好个数相等,则就正常分配。
if (remain == totCount - totSelect) {
若是数目不等,随机决定把"多出来的remain"加入到eachSelect数组中的随便一个记录上。
for (int i = 0; i < Math.min(remain, npart); i++) { int taskId = shuffle.get(i); while (eachSelect[taskId] >= eachCount[taskId]) { taskId = (taskId + 1) % npart; } eachSelect[taskId]++; }
最后给出全部信息
long[] statistics = new long[npart * 2]; for (int i = 0; i < npart; i++) { statistics[i] = eachCount[i]; statistics[i + npart] = eachSelect[i]; } out.collect(statistics); // 咱们这里是4核,因此前面四项是eachCount,后面是eachSelect statistics = {long[8]@9003} 0 = 38 //eachCount 1 = 38 2 = 36 3 = 38 4 = 31 //eachSelect 5 = 31 6 = 28 7 = 30
这些信息是做为广播变量存储起来的,立刻下面就会用到。
.withBroadcastSet(numPickedPerPartition, "counts")
CountInPartition.PickInPartition函数中会随机在每一个task选择记录。
首先获得task数目 和 以前存储的广播变量(就是以前刚刚存储的)。
int npart = getRuntimeContext().getNumberOfParallelSubtasks(); List<long[]> bc = getRuntimeContext().getBroadcastVariable("counts");
分离count和select。
long[] eachCount = Arrays.copyOfRange(bc.get(0), 0, npart); long[] eachSelect = Arrays.copyOfRange(bc.get(0), npart, npart * 2);
获得总task数目
int taskId = getRuntimeContext().getIndexOfThisSubtask();
获得本身 task 对应的 count, select
long count = eachCount[taskId]; long select = eachSelect[taskId];
添加本task对应的记录,随机洗牌打乱顺序
for (int i = 0; i < count; i++) { shuffle.add(i); //就是把count内的数字加到数组 } Collections.shuffle(shuffle, new Random(taskId)); //洗牌打乱顺序 // suffle举例 shuffle = {ArrayList@8987} size = 38 0 = {Integer@8994} 17 1 = {Integer@8995} 8 2 = {Integer@8996} 33 3 = {Integer@8997} 34 4 = {Integer@8998} 20 5 = {Integer@8999} 0 6 = {Integer@9000} 26 7 = {Integer@9001} 27 8 = {Integer@9002} 23 9 = {Integer@9003} 28 10 = {Integer@9004} 9 11 = {Integer@9005} 16 12 = {Integer@9006} 13 13 = {Integer@9007} 2 14 = {Integer@9008} 5 15 = {Integer@9009} 31 16 = {Integer@9010} 15 17 = {Integer@9011} 22 18 = {Integer@9012} 18 19 = {Integer@9013} 35 20 = {Integer@9014} 36 21 = {Integer@9015} 12 22 = {Integer@9016} 7 23 = {Integer@9017} 21 24 = {Integer@9018} 14 25 = {Integer@9019} 1 26 = {Integer@9020} 10 27 = {Integer@9021} 30 28 = {Integer@9022} 29 29 = {Integer@9023} 19 30 = {Integer@9024} 25 31 = {Integer@9025} 32 32 = {Integer@9026} 37 33 = {Integer@9027} 4 34 = {Integer@9028} 11 35 = {Integer@9029} 6 36 = {Integer@9030} 3 37 = {Integer@9031} 24
随机选择,把选择后的再排序回来
for (int i = 0; i < select; i++) { selected[i] = shuffle.get(i); //这时候select看起来是按照顺序选择,可是实际上suffle里面已是乱序 } Arrays.sort(selected); //此次再排序 // selected举例,一共30个 selected = {int[30]@8991} 0 = 0 1 = 1 2 = 2 3 = 5 4 = 7 5 = 8 6 = 9 7 = 10 8 = 12 9 = 13 10 = 14 11 = 15 12 = 16 13 = 17 14 = 18 15 = 19 16 = 20 17 = 21 18 = 22 19 = 23 20 = 26 21 = 27 22 = 28 23 = 29 24 = 30 25 = 31 26 = 33 27 = 34 28 = 35 29 = 36
发送选择的数据
if (numEmits < selected.length && iRow == selected[numEmits]) { out.collect(row); numEmits++; }
output是训练数据集,SideOutput是测试数据集。由于这两个数据集在Alink内部都是Table类型,因此直接使用了SQL算子 minusAll
来完成分割。
this.setOutput(out, in.getSchema()); this.setSideOutputTables(new Table[]{in.getOutputTable().minusAll(this.getOutputTable())});
训练是在SplitStreamOp类完成的,其经过linkFrom完成了模型的构建。
流处理依赖SplitStream 和 SelectTransformation 这两个类来完成分割流。具体并无创建一个物理操做,而只是影响了上游算子如何与下游算子联系,如何选择记录。
SplitStream <Row> splited = in.getDataStream().split(new RandomSelectorOp(getFraction()));
首先,用RandomSelectorOp来随机决定输出时候选择哪一个流。咱们能够看到,这里就是随便起了"a", "b" 这两个名字而已。
class RandomSelectorOp implements OutputSelector <Row> { private double fraction; private Random random = null; @Override public Iterable <String> select(Row value) { if (null == random) { random = new Random(System.currentTimeMillis()); } List <String> output = new ArrayList <String>(1); output.add((random.nextDouble() < fraction ? "a" : "b")); //随机选取数字分配,随意起的名字 return output; } }
其次,获得那两个随机生成的流。
DataStream <Row> partA = splited.select("a"); DataStream <Row> partB = splited.select("b");
最后把这两个流分别设置为output和sideOutput。
this.setOutput(partA, in.getSchema()); //训练集 this.setSideOutputTables(new Table[]{ DataStreamConversionUtil.toTable(getMLEnvironmentId(), partB, in.getSchema())}); //验证集
最后返回自己,这时候SplitStreamOp拥有两个成员变量:
this.output就是训练集。
this.sideOutPut就是验证集。
return this;