商品需求(y, 吨),价格(x1, 元),消费者收入(x2, 元)apache
y | x1 | x2 |
5 | 1 | 1 |
8 | 1 | 2 |
7 | 2 | 1 |
13 | 2 | 3 |
18 | 3 | 4 |
创建需求函数: y = ax1+bx2函数
package spark.regressionAnalysis /** * 线性回归, 创建商品价格与消费者输入之间的关系, * 预测价格 */ import org.apache.log4j.{Level, Logger} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.{LabeledPoint, LinearRegressionWithSGD} import org.apache.spark.{SparkConf, SparkContext} object LinearRegression { val conf = new SparkConf() //建立环境变量 .setMaster("local") //设置本地化处理 .setAppName("LinearRegression")//设定名称 val sc = new SparkContext(conf) //建立环境变量实例 def main(args: Array[String]) { val data = sc.textFile("./src/main/spark/regressionAnalysis/lr.txt")//获取数据集路径 val parsedData = data.map { line => //开始对数据集处理 val parts = line.split('|') //根据逗号进行分区 LabeledPoint(parts(0).toDouble, Vectors.dense(parts(1).split(',').map(_.toDouble))) }.cache() //转化数据格式 //LabeledPoint, numIterations, stepSize val model = LinearRegressionWithSGD.train(parsedData, 2, 0.1) //创建模型 val result = model.predict(Vectors.dense(1, 3))//经过模型预测模型 println(model.weights) println(model.weights.size) println(result) //打印预测结果 } }
lr.txtspa
5|1,1 8|1,2 7|2,1 13|2,3 18|3,4