使用基于Apache Spark的随机森林方法预测贷款风险

 

转载 2016年07月20日 14:19:56
http://blog.csdn.net/mr__fang/article/details/51967852

原文:Predicting Loan Credit Risk using Apache Spark Machine Learning Random Forests 
做者:Carol McDonald,MapR解决方案架构师 
翻译:KK4SBB 
责编:周建丁(zhoujd@csdn.net)javascript

在本文中,我将向你们介绍如何使用Apache Spark的spark.ml库中的随机森林算法来对银行信用贷款的风险作分类预测。Spark的spark.ml库基于DataFrame,它提供了大量的接口,帮助用户建立和调优机器学习工做流。结合dataframe使用spark.ml,可以实现模型的智能优化,从而提高模型效果。html

分类算法

分类算法是一类监督式机器学习算法,它根据已知标签的样本(如已经明确交易是否存在欺诈)来预测其它样本所属的类别(如是否属于欺诈性的交易)。分类问题须要一个已经标记过的数据集和预先设计好的特征,而后基于这些信息来学习给新样本打标签。所谓的特征便是一些“是与否”的问题。标签就是这些问题的答案。在下面这个例子里,若是某个动物的行走姿态、游泳姿式和叫声都像鸭子,那么就给它打上“鸭子”的标签。java

咱们来看一个银行信贷的信用风险例子:git

  • 咱们须要预测什么? 
    • 某我的是否会按时还款
    • 这就是标签:此人的信用度
  • 你用来预测的“是与否”问题或者属性是什么? 
    • 申请人的基本信息和社会身份信息:职业,年龄,存款储蓄,婚姻状态等等……
    • 这些就是特征,用来构建一个分类模型,你从中提取出对分类有帮助的特征信息。

决策树模型

决策树是一种基于输入特征来预测类别或是标签的分类模型。决策树的工做原理是这样的,它在每一个节点都须要计算特征在该节点的表达式值,而后基于运算结果选择一个分支通往下一个节点。下图展现了一种用来预测信用风险的决策树模型。每一个决策问题就是模型的一个节点,“是”或者“否”的答案是通往子节点的分支。github

  • 问题1:帐户余额是否大于200元? 
    • 问题2:当前就任时间是否超过1年? 
      • 不可信赖

图片描述

随机森林模型

融合学习算法结合了多个机器学习的算法,从而获得了效果更好的模型。随机森林是分类和回归问题中一类经常使用的融合学习方法。此算法基于训练数据的不一样子集构建多棵决策树,组合成一个新的模型。预测结果是全部决策树输出的组合,这样可以减小波动,而且提升预测的准确度。对于随机森林分类模型,每棵树的预测结果都视为一张投票。得到投票数最多的类别就是预测的类别。算法

图片描述

基于Spark机器学习工具来分析信用风险问题

咱们使用德国人信用度数据集,它按照一系列特征属性将人分为信用风险好和坏两类。咱们能够得到每一个银行贷款申请者的如下信息:sql

图片描述

存放德国人信用数据的csv文件格式以下:shell

1,1,18,4,2,1049,1,2,4,2,1,4,2,21,3,1,1,3,1,1,1
1,1,9,4,0,2799,1,3,2,3,1,2,1,36,3,1,2,3,2,1,1
1,2,12,2,9,841,2,4,2,2,1,4,1,23,3,1,1,2,1,1,1

在这个背景下,咱们会构建一个由决策树组成的随机森林模型来预测是否守信用的标签/类别,基于如下特征:apache

  • 标签 -> 守信用或者不守信用(1或者0)
  • 特征 -> {存款余额,信用历史,贷款目的等等}

软件

本教程将使用Spark 1.6.1数组

按照教程指示,登陆MapR沙箱,用户名为user01,密码为mapr。将样本数据文件复制到你的沙箱主目录下/user/user01 using scp。(注意,你可能须要先更新Spark的版本)打开spark shell:

$spark-shell --master local[1]

加载并解析csv数据文件

首先,咱们须要引入机器学习相关的包。

import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.ml.feature.VectorAssembler
import sqlContext.implicits._
import sqlContext._
import org.apache.spark.ml.tuning.{ ParamGridBuilder, CrossValidator }
import org.apache.spark.ml.{ Pipeline, PipelineStage }

咱们用一个Scala的case类来定义Credit的属性,对应于csv文件中的一行。

// define the Credit Schema case class Credit( creditability: Double, balance: Double, duration: Double, history: Double, purpose: Double, amount: Double, savings: Double, employment: Double, instPercent: Double, sexMarried: Double, guarantors: Double, residenceDuration: Double, assets: Double, age: Double, concCredit: Double, apartment: Double, credits: Double, occupation: Double, dependents: Double, hasPhone: Double, foreign: Double )

下面的函数解析一行数据文件,将值存入Credit类中。类别的索引值减去了1,所以起始索引值为0.

 // function to create a Credit class from an Array of Double def parseCredit(line: Array[Double]): Credit = { Credit( line(0), line(1) - 1, line(2), line(3), line(4) , line(5), line(6) - 1, line(7) - 1, line(8), line(9) - 1, line(10) - 1, line(11) - 1, line(12) - 1, line(13), line(14) - 1, line(15) - 1, line(16) - 1, line(17) - 1, line(18) - 1, line(19) - 1, line(20) - 1 ) }  // function to transform an RDD of Strings into an RDD of Double def parseRDD(rdd: RDD[String]): RDD[Array[Double]] = { rdd.map(_.split(",")).map(_.map(_.toDouble)) }

接下去,咱们导入germancredit.csv文件中的数据,存为一个String类型的RDD。而后咱们对RDD作map操做,将RDD中的每一个字符串通过ParseRDDR函数的映射,转换为一个Double类型的数组。紧接着是另外一个map操做,使用ParseCredit函数,将每一个Double类型的RDD转换为Credit对象。toDF()函数将Array[[Credit]]类型的RDD转为一个Credit类的Dataframe。

// load the data into a RDD val creditDF= parseRDD(sc.textFile("germancredit.csv")).map(parseCredit).toDF().cache() creditDF.registerTempTable("credit") DataFrame的printSchema()函数将各个字段含义以树状的形式打印到控制台输出。 // Return the schema of this DataFrame creditDF.printSchema root |-- creditability: double (nullable = false) |-- balance: double (nullable = false) |-- duration: double (nullable = false) |-- history: double (nullable = false) |-- purpose: double (nullable = false) |-- amount: double (nullable = false) |-- savings: double (nullable = false) |-- employment: double (nullable = false) |-- instPercent: double (nullable = false) |-- sexMarried: double (nullable = false) |-- guarantors: double (nullable = false) |-- residenceDuration: double (nullable = false) |-- assets: double (nullable = false) |-- age: double (nullable = false) |-- concCredit: double (nullable = false) |-- apartment: double (nullable = false) |-- credits: double (nullable = false) |-- occupation: double (nullable = false) |-- dependents: double (nullable = false) |-- hasPhone: double (nullable = false) |-- foreign: double (nullable = false) // Display the top 20 rows of DataFrame creditDF.show +-------------+-------+--------+-------+-------+------+-------+----------+-----------+----------+----------+-----------------+------+----+----------+---------+-------+----------+----------+--------+-------+ |creditability|balance|duration|history|purpose|amount|savings|employment|instPercent|sexMarried|guarantors|residenceDuration|assets| age|concCredit|apartment|credits|occupation|dependents|hasPhone|foreign| +-------------+-------+--------+-------+-------+------+-------+----------+-----------+----------+----------+-----------------+------+----+----------+---------+-------+----------+----------+--------+-------+ | 1.0| 0.0| 18.0| 4.0| 2.0|1049.0| 0.0| 1.0| 4.0| 1.0| 0.0| 3.0| 1.0|21.0| 2.0| 0.0| 0.0| 2.0| 0.0| 0.0| 0.0| | 1.0| 0.0| 9.0| 4.0| 0.0|2799.0| 0.0| 2.0| 2.0| 2.0| 0.0| 1.0| 0.0|36.0| 2.0| 0.0| 1.0| 2.0| 1.0| 0.0| 0.0| | 1.0| 1.0| 12.0| 2.0| 9.0| 841.0| 1.0| 3.0| 2.0| 1.0| 0.0| 3.0| 0.0|23.0| 2.0| 0.0| 0.0| 1.0| 0.0| 0.0| 0.0| | 1.0| 0.0| 12.0| 4.0| 0.0|2122.0| 0.0| 2.0| 3.0| 2.0| 0.0| 1.0| 0.0|39.0| 2.0| 0.0| 1.0| 1.0| 1.0| 0.0| 1.0| | 1.0| 0.0| 12.0| 4.0| 0.0|2171.0| 0.0| 2.0| 4.0| 2.0| 0.0| 3.0| 1.0|38.0| 0.0| 1.0| 1.0| 1.0| 0.0| 0.0| 1.0| | 1.0| 0.0| 10.0| 4.0| 0.0|2241.0| 0.0| 1.0| 1.0| 2.0| 0.0| 2.0| 0.0|48.0| 2.0| 0.0| 1.0| 1.0| 1.0| 0.0| 1.0| | 1.0| 0.0| 8.0| 4.0| 0.0|3398.0| 0.0| 3.0| 1.0| 2.0| 0.0| 3.0| 0.0|39.0| 2.0| 1.0| 1.0| 1.0| 0.0| 0.0| 1.0| | 1.0| 0.0| 6.0| 4.0| 0.0|1361.0| 0.0| 1.0| 2.0| 2.0| 0.0| 3.0| 0.0|40.0| 2.0| 1.0| 0.0| 1.0| 1.0| 0.0| 1.0| | 1.0| 3.0| 18.0| 4.0| 3.0|1098.0| 0.0| 0.0| 4.0| 1.0| 0.0| 3.0| 2.0|65.0| 2.0| 1.0| 1.0| 0.0| 0.0| 0.0| 0.0| | 1.0| 1.0| 24.0| 2.0| 3.0|3758.0| 2.0| 0.0| 1.0| 1.0| 0.0| 3.0| 3.0|23.0| 2.0| 0.0| 0.0| 0.0| 0.0| 0.0| 0.0| | 1.0| 0.0| 11.0| 4.0| 0.0|3905.0| 0.0| 2.0| 2.0| 2.0| 0.0| 1.0| 0.0|36.0| 2.0| 0.0| 1.0| 2.0| 1.0| 0.0| 0.0| | 1.0| 0.0| 30.0| 4.0| 1.0|6187.0| 1.0| 3.0| 1.0| 3.0| 0.0| 3.0| 2.0|24.0| 2.0| 0.0| 1.0| 2.0| 0.0| 0.0| 0.0| | 1.0| 0.0| 6.0| 4.0| 3.0|1957.0| 0.0| 3.0| 1.0| 1.0| 0.0| 3.0| 2.0|31.0| 2.0| 1.0| 0.0| 2.0| 0.0| 0.0| 0.0| | 1.0| 1.0| 48.0| 3.0| 10.0|7582.0| 1.0| 0.0| 2.0| 2.0| 0.0| 3.0| 3.0|31.0| 2.0| 1.0| 0.0| 3.0| 0.0| 1.0| 0.0| | 1.0| 0.0| 18.0| 2.0| 3.0|1936.0| 4.0| 3.0| 2.0| 3.0| 0.0| 3.0| 2.0|23.0| 2.0| 0.0| 1.0| 1.0| 0.0| 0.0| 0.0| | 1.0| 0.0| 6.0| 2.0| 3.0|2647.0| 2.0| 2.0| 2.0| 2.0| 0.0| 2.0| 0.0|44.0| 2.0| 0.0| 0.0| 2.0| 1.0| 0.0| 0.0| | 1.0| 0.0| 11.0| 4.0| 0.0|3939.0| 0.0| 2.0| 1.0| 2.0| 0.0| 1.0| 0.0|40.0| 2.0| 1.0| 1.0| 1.0| 1.0| 0.0| 0.0| | 1.0| 1.0| 18.0| 2.0| 3.0|3213.0| 2.0| 1.0| 1.0| 3.0| 0.0| 2.0| 0.0|25.0| 2.0| 0.0| 0.0| 2.0| 0.0| 0.0| 0.0| | 1.0| 1.0| 36.0| 4.0| 3.0|2337.0| 0.0| 4.0| 4.0| 2.0| 0.0| 3.0| 0.0|36.0| 2.0| 1.0| 0.0| 2.0| 0.0| 0.0| 0.0| | 1.0| 3.0| 11.0| 4.0| 0.0|7228.0| 0.0| 2.0| 1.0| 2.0| 0.0| 3.0| 1.0|39.0| 2.0| 1.0| 1.0| 1.0| 0.0| 0.0| 0.0| +-------------+-------+--------+-------+-------+------+-------+----------+-----------+----------+----------+-----------------+------+----+----------+---------+-------+----------+----------+--------+-------+ 

dataframe初始化以后,你能够用SQL命令查询数据了。下面是一些使用Scala DataFrame接口查询数据的例子:

计算数值型数据的统计信息,包括计数、均值、标准差、最小值和最大值。

 // computes statistics for balance creditDF.describe("balance").show +-------+-----------------+ |summary| balance| +-------+-----------------+ | count| 1000| | mean| 1.577| | stddev|1.257637727110893| | min| 0.0| | max| 3.0| +-------+-----------------+  // compute the avg balance by creditability (the label) creditDF.groupBy("creditability").avg("balance").show +-------------+------------------+ |creditability| avg(balance)| +-------------+------------------+ | 1.0|1.8657142857142857| | 0.0|0.9033333333333333| +-------------+------------------+

你能够用某个表名将DataFrame注册为一张临时表,而后用SQLContext提供的sql方法执行SQL命令。下面是几个用sqlContext查询的例子:

sqlContext.sql("SELECT creditability, avg(balance) as avgbalance, avg(amount) as avgamt, avg(duration) as avgdur FROM credit GROUP BY creditability ").show +-------------+------------------+------------------+------------------+ |creditability| avgbalance| avgamt| avgdur| +-------------+------------------+------------------+------------------+ | 1.0|1.8657142857142857| 2985.442857142857|19.207142857142856| | 0.0|0.9033333333333333|3938.1266666666666| 24.86| +-------------+------------------+------------------+------------------+

提取特征

为了构建一个分类模型,你首先须要提取对分类最有帮助的特征。在德国人信用度的数据集里,每条样本用两个类别来标记——1(可信)和0(不可信)。

每一个样本的特征包括如下的字段:

  • 标签 -> 是否可信:0或者1
  • 特征 -> {“存款”,“期限”,“历史记录”,“目的”,“数额”,“储蓄”,“是否在职”,“婚姻”,“担保人”,“居住时间”,“资产”,“年龄”,“历史信用”,“居住公寓”,“贷款”,“职业”,“监护人”,“是否有电话”,“外籍”}

定义特征数组

图片描述

图片来自:学习Spark

 

为了在机器学习算法中使用这些特征,这些特征通过了变换,存入特征向量中,即一组表示各个维度特征值的数值向量。

下图中,用VectorAssembler方法将每一个维度的特征都作变换,返回一个新的dataframe。

//define the feature columns to put in the feature vector
    val featureCols = Array("balance", "duration", "history", "purpose", "amount", "savings", "employment", "instPercent", "sexMarried", "guarantors", "residenceDuration", "assets", "age", "concCredit", "apartment", "credits", "occupation", "dependents", "hasPhone", "foreign" ) //set the input and output column names val assembler = new VectorAssembler().setInputCols(featureCols).setOutputCol("features") //return a dataframe with all of the feature columns in a vector column val df2 = assembler.transform( creditDF) // the transform method produced a new column: features. df2.show +-------------+-------+--------+-------+-------+------+-------+----------+-----------+----------+----------+-----------------+------+----+----------+---------+-------+----------+----------+--------+-------+--------------------+ |creditability|balance|duration|history|purpose|amount|savings|employment|instPercent|sexMarried|guarantors|residenceDuration|assets| age|concCredit|apartment|credits|occupation|dependents|hasPhone|foreign| features| +-------------+-------+--------+-------+-------+------+-------+----------+-----------+----------+----------+-----------------+------+----+----------+---------+-------+----------+----------+--------+-------+--------------------+ | 1.0| 0.0| 18.0| 4.0| 2.0|1049.0| 0.0| 1.0| 4.0| 1.0| 0.0| 3.0| 1.0|21.0| 2.0| 0.0| 0.0| 2.0| 0.0| 0.0| 0.0|(20,[1,2,3,4,6,7,...|

接着,咱们使用StringIndexer方法返回一个Dataframe,增长了信用度这一列做为标签。

//  Create a label column with the StringIndexer val labelIndexer = new StringIndexer().setInputCol("creditability").setOutputCol("label") val df3 = labelIndexer.fit(df2).transform(df2) // the transform method produced a new column: label. df3.show +-------------+-------+--------+-------+-------+------+-------+----------+-----------+----------+----------+-----------------+------+----+----------+---------+-------+----------+----------+--------+-------+--------------------+-----+ |creditability|balance|duration|history|purpose|amount|savings|employment|instPercent|sexMarried|guarantors|residenceDuration|assets| age|concCredit|apartment|credits|occupation|dependents|hasPhone|foreign| features|label| +-------------+-------+--------+-------+-------+------+-------+----------+-----------+----------+----------+-----------------+------+----+----------+---------+-------+----------+----------+--------+-------+--------------------+-----+ | 1.0| 0.0| 18.0| 4.0| 2.0|1049.0| 0.0| 1.0| 4.0| 1.0| 0.0| 3.0| 1.0|21.0| 2.0| 0.0| 0.0| 2.0| 0.0| 0.0| 0.0|(20,[1,2,3,4,6,7,...| 0.0|

下图中,数据集被分为训练数据和测试数据两个部分,70%的数据用来训练模型,30%的数据用来测试模型。

// split the dataframe into training and test data val splitSeed = 5043 val Array(trainingData, testData) = df3.randomSplit(Array(0.7, 0.3), splitSeed)

训练模型

图片描述

接着,咱们按照下列参数训练一个随机森林分类器:

  • maxDepth:每棵树的最大深度。增长树的深度能够提升模型的效果,可是会延长训练时间。
  • maxBins:连续特征离散化时选用的最大分桶个数,而且决定每一个节点如何分裂。
  • impurity:计算信息增益的指标
  • auto:在每一个节点分裂时是否自动选择参与的特征个数
  • seed:随机数生成种子

模型的训练过程就是将输入特征和这些特征对应的样本标签相关联的过程。

// create the classifier,  set parameters for training val classifier = new RandomForestClassifier().setImpurity("gini").setMaxDepth(3).setNumTrees(20).setFeatureSubsetStrategy("auto").setSeed(5043) // use the random forest classifier to train (fit) the model val model = classifier.fit(trainingData) // print out the random forest trees model.toDebugString res20: String = res5: String = "RandomForestClassificationModel (uid=rfc_6c4ceb92ba78) with 20 trees Tree 0 (weight 1.0): If (feature 0 <= 1.0) If (feature 10 <= 0.0) If (feature 3 <= 6.0) Predict: 0.0 Else (feature 3 > 6.0) Predict: 0.0 Else (feature 10 > 0.0) If (feature 12 <= 63.0) Predict: 0.0 Else (feature 12 > 63.0) Predict: 0.0 Else (feature 0 > 1.0) If (feature 13 <= 1.0) If (feature 3 <= 3.0) Predict: 0.0 Else (feature 3 > 3.0) Predict: 1.0 Else (feature 13 > 1.0) If (feature 7 <= 1.0) Predict: 0.0 Else (feature 7 > 1.0) Predict: 0.0 Tree 1 (weight 1.0): If (feature 2 <= 1.0) If (feature 15 <= 0.0) If (feature 11 <= 0.0) Predict: 0.0 Else (feature 11 > 0.0) Predict: 1.0 Else (feature 15 > 0.0) If (feature 11 <= 0.0) Predict: 0.0 Else (feature 11 > 0.0) Predict: 1.0 Else (feature 2 > 1.0) If (feature 12 <= 31.0) If (feature 5 <= 0.0) Predict: 0.0 Else (feature 5 > 0.0) Predict: 0.0 Else (feature 12 > 31.0) If (feature 3 <= 4.0) Predict: 0.0 Else (feature 3 > 4.0) Predict: 0.0 Tree 2 (weight 1.0): If (feature 8 <= 1.0) If (feature 6 <= 2.0) If (feature 4 <= 10875.0) Predict: 0.0 Else (feature 4 > 10875.0) Predict: 1.0 Else (feature 6 > 2.0) If (feature 1 <= 36.0) Predict: 0.0 Else (feature 1 > 36.0) Predict: 1.0 Else (feature 8 > 1.0) If (feature 5 <= 0.0) If (feature 4 <= 4113.0) Predict: 0.0 Else (feature 4 > 4113.0) Predict: 1.0 Else (feature 5 > 0.0) If (feature 11 <= 2.0) Predict: 0.0 Else (feature 11 > 2.0) Predict: 0.0 Tree 3 ...

测试模型

接下来,咱们对测试数据进行预测。

// run the  model on test features to get predictions
    val predictions = model.transform(testData) 
    //As you can see, the previous model transform produced a new columns: rawPrediction, probablity and prediction.
    predictions.show

    +-------------+-------+--------+-------+-------+------+-------+----------+-----------+----------+----------+-----------------+------+----+----------+---------+-------+----------+----------+--------+-------+--------------------+-----+--------------------+--------------------+----------+
    |creditability|balance|duration|history|purpose|amount|savings|employment|instPercent|sexMarried|guarantors|residenceDuration|assets| age|concCredit|apartment|credits|occupation|dependents|hasPhone|foreign|            features|label|       rawPrediction|         probability|prediction|
    +-------------+-------+--------+-------+-------+------+-------+----------+-----------+----------+----------+-----------------+------+----+----------+---------+-------+----------+----------+--------+-------+--------------------+-----+--------------------+--------------------+----------+
    |          0.0| 0.0| 12.0| 0.0| 5.0|1108.0| 0.0| 3.0| 4.0| 2.0| 0.0| 2.0| 0.0|28.0| 2.0| 1.0| 1.0| 2.0| 0.0| 0.0| 0.0|(20,[1,3,4,6,7,8,...| 1.0|[14.1964586927573...|[0.70982293463786...| 0.0|

而后,咱们用BinaryClassificationEvaluator评估预测的效果,它将预测结果与样本的实际标签相比较,返回一个准确度指标(ROC曲线所覆盖的面积)。本例子中,AUC达到78%。

// create an Evaluator for binary classification, which expects two input columns: rawPrediction and label. val evaluator = new BinaryClassificationEvaluator().setLabelCol("label") // Evaluates predictions and returns a scalar metric areaUnderROC(larger is better). val accuracy = evaluator.evaluate(predictions) accuracy: Double = 0.7824906081835722

使用机器学习管道

咱们接着用管道来训练模型,可能会取得更好的效果。管道采起了一种简单的方式来比较各类不一样组合的参数的效果,这个方法称为网格搜索法(grid search),你先设置好待测试的参数,MLLib就会自动完成这些参数的不一样组合。管道搭建了一条工做流,一次性完成了整个模型的调优,而不是独立对每一个参数进行调优。

下面咱们就用ParamGridBuilder工具来构建参数网格。

// We use a ParamGridBuilder to construct a grid of parameters to search over val paramGrid = new ParamGridBuilder() .addGrid(classifier.maxBins, Array(25, 28, 31)) .addGrid(classifier.maxDepth, Array(4, 6, 8)) .addGrid(classifier.impurity, Array("entropy", "gini")) .build()

建立并完成一条管道。一条管道由一系列stage组成,每一个stage至关于一个Estimator或是Transformer。

val steps: Array[PipelineStage] = Array(classifier) val pipeline = new Pipeline().setStages(steps)

咱们用CrossValidator类来完成模型筛选。CrossValidator类使用一个Estimator类,一组ParamMaps类和一个Evaluator类。注意,使用CrossValidator类的开销很大。

// Evaluate model on test instances and compute test error val evaluator = new BinaryClassificationEvaluator() .setLabelCol("label") val cv = new CrossValidator() .setEstimator(pipeline) .setEvaluator(evaluator) .setEstimatorParamMaps(paramGrid) .setNumFolds(10)

管道在参数网格上不断地爬行,自动完成了模型优化的过程:对于每一个ParamMap类,CrossValidator训练获得一个Estimator,而后用Evaluator来评价结果,而后用最好的ParamMap和整个数据集来训练最优的Estimator。

图片描述

// When fit is called, the stages are executed in order. // Fit will run cross-validation, and choose the best set of parameters //The fitted model from a Pipeline is an PipelineModel, which consists of fitted models and transformers val pipelineFittedModel = cv.fit(trainingData)

如今,咱们能够用管道训练获得的最优模型进行预测,将预测结果与标签作比较。预测结果取得了82%的准确率,相比以前78%的准确率有提升。

//  call tranform to make predictions on test data. The fitted model will use the best model found val predictions = pipelineFittedModel.transform(testData) val accuracy = evaluator.evaluate(predictions) Double = 0.8204386232104784 val rm2 = new RegressionMetrics( predictions.select("prediction", "label").rdd.map(x => (x(0).asInstanceOf[Double], x(1).asInstanceOf[Double]))) println("MSE: " + rm2.meanSquaredError) println("MAE: " + rm2.meanAbsoluteError) println("RMSE Squared: " + rm2.rootMeanSquaredError) println("R Squared: " + rm2.r2) println("Explained Variance: " + rm2.explainedVariance + "\n") MSE: 0.2575250836120402 MAE: 0.25752508361204013 RMSE Squared: 0.5074692932700856 R Squared: -0.1687988628287138 Explained Variance: 0.15466269952237702
相关文章
相关标签/搜索