Spark数据挖掘-数据标准化

Spark数据挖掘-数据标准化

1 前言

特征数据标准化指的是对训练样本经过利用每一列的统计量将特征列转换为0均值单位方差的数据。 这是很是通用的数据预处理步骤。
例如:RBF核的支持向量机或者基于L1和L2正则化的线性模型在数据标准化以后效果会更好。
数据标准化可以改进优化过程当中数据收敛的速度,也能防止一些方差过大的变量特征对模型训练 产生过大的影响。
如何对数据标准化呢?公式也很是简单:新的列 = (老的列每个值 - 老的列平均值) / (老的列标准差)apache

2 数据准备

在标准化以前,Spark必须知道每一列的平均值,方差,具体怎么知道呢?
想法很简单,首先给 Spark的 StandardScaler 一批数据,这批数据以 org.apache.spark.mllib.feature.Vector 的形式提供给 StandardScaler。StandardScaler 对输入的数据进行 fit 即计算每一列的平均值,方差。 调度代码以下:微信

import org.apache.spark.SparkContext._
import org.apache.spark.mllib.feature.StandardScaler
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLUtils

val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")

val scaler1 = new StandardScaler().fit(data.map(x => x.features))
val scaler2 = new StandardScaler(withMean = true, withStd = true).fit(data.map(x => x.features))

上面代码的本质是生成一个包含每一列均值和方差的 StandardScalarModel,具体解释一下 withMean 和 withStd 的含义:机器学习

  • withMean 若是值为true,那么将会对列中每一个元素减去均值(不然不会减)
  • withStd 若是值为true,那么将会对列中每一个元素除以标准差(不然不会除,这个值通常为 true,不然没有标准化没有意义) 因此上面两个参数都为 false 是没有意义的,模型什么都不会干,返回原来的值,这些将会在下面的代码中获得验证。

下面给出上面 fit 函数的源代码:ide

/**
  * 计算数据每一列的平均值标准差,将会用于以后的标准化.
  *
  * @param data The data used to compute the mean and variance to build the transformation model.
  * @return a StandardScalarModel
  */
 @Since("1.1.0")
 def fit(data: RDD[Vector]): StandardScalerModel = {
   // TODO: 若是 withMean 和 withStd 都为false,什么都不用干
   //计算基本统计
   val summary = data.treeAggregate(new MultivariateOnlineSummarizer)(
     (aggregator, data) => aggregator.add(data),
     (aggregator1, aggregator2) => aggregator1.merge(aggregator2))
   //经过标准差,平均值获得模型
   new StandardScalerModel(
     Vectors.dense(summary.variance.toArray.map(v => math.sqrt(v))),
     summary.mean,
     withStd,
     withMean)
 }

从这里能够发现,若是你知道每一列的平均值和方差,直接经过 StandardScalarModel 构建模型就能够了,以下代码:函数

val scaler3 = new StandardScalerModel(scaler2.std, scaler2.mean)

3 数据标准化

准备工做作好了,下面真正标准化,调用代码也很是简单:学习

al data1 = data.map(x => (x.label, scaler1.transform(x.features)))

用模型对每一行 transform 就能够了,背后的原理也很是简单,代码以下:大数据

// 由于 `shift` 只是在 `withMean` 为真的分支中才使用, 因此使用了
 // `lazy val`. 注意:这里不想在每一次 `transform` 都计算一遍 shift.
 private lazy val shift: Array[Double] = mean.toArray

 /**
  * Applies standardization transformation on a vector.
  *
  * @param vector Vector to be standardized.
  * @return Standardized vector. If the std of a column is zero, it will return default `0.0`
  *         for the column with zero std.
  */
 @Since("1.1.0")
 override def transform(vector: Vector): Vector = {
   require(mean.size == vector.size)
   if (withMean) {
     // By default, Scala generates Java methods for member variables. So every time when
     // the member variables are accessed, `invokespecial` will be called which is expensive.
     // This can be avoid by having a local reference of `shift`.
     val localShift = shift
     vector match {
       case DenseVector(vs) =>
         val values = vs.clone()
         val size = values.size
         if (withStd) {
           var i = 0
           while (i < size) {
             values(i) = if (std(i) != 0.0) (values(i) - localShift(i)) * (1.0 / std(i)) else 0.0
             i += 1
           }
         } else {
           var i = 0
           while (i < size) {
             values(i) -= localShift(i)
             i += 1
           }
         }
         Vectors.dense(values)
       case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass)
     }
   } else if (withStd) {
     vector match {
       case DenseVector(vs) =>
         val values = vs.clone()
         val size = values.size
         var i = 0
         while(i < size) {
           values(i) *= (if (std(i) != 0.0) 1.0 / std(i) else 0.0)
           i += 1
         }
         Vectors.dense(values)
       case SparseVector(size, indices, vs) =>
         // For sparse vector, the `index` array inside sparse vector object will not be changed,
         // so we can re-use it to save memory.
         val values = vs.clone()
         val nnz = values.size
         var i = 0
         while (i < nnz) {
           values(i) *= (if (std(indices(i)) != 0.0) 1.0 / std(indices(i)) else 0.0)
           i += 1
         }
         Vectors.sparse(size, indices, values)
       case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass)
     }
   } else {
     // Note that it's safe since we always assume that the data in RDD should be immutable.
     vector
   }
 }

标准化原理简单,代码也简单,可是做用不能小看。优化

我的微信公众号

欢迎关注本人微信公众号,会定时发送关于大数据、机器学习、Java、Linux 等技术的学习文章,并且是一个系列一个系列的发布,无任何广告,纯属我的兴趣。
Clebeg能量集结号ui

相关文章
相关标签/搜索