Offline ALS and Real-Time Streaming Recommendation with Spark

Introduction to ALS

ALS stands for alternating least squares. ALS-WR stands for alternating-least-squares with weighted-λ-regularization, i.e. the weighted-regularized alternating least squares method. Both are commonly used in recommender systems based on matrix factorization. For example, the user-item rating matrix is factored into two matrices: a user matrix describing each user's preference for a set of latent (hidden) features, and an item matrix describing how strongly each item exhibits those latent features. Multiplying the two factor matrices back together produces a value for every user-item pair, so the missing ratings are filled in, and we can recommend items to a user based on these filled-in ratings.
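To make the "missing ratings get filled in" idea concrete, here is a minimal plain-Java sketch (no Spark) that multiplies a user-factor matrix by an item-factor matrix. The class name `FactorizationFillDemo` and all numbers are hypothetical; the only point is that U·Vᵀ yields a score for every user-item cell, including cells that had no observed rating.

```java
import java.util.Arrays;

/**
 * Illustration only: hypothetical factor matrices for 3 users and 4 items,
 * each with 2 latent features. U * V^T gives a score for every (user, item) cell.
 */
public class FactorizationFillDemo {

    /** Predicted rating = dot product of a user's and an item's latent-feature vectors. */
    public static double predict(double[] userFactors, double[] itemFactors) {
        double sum = 0.0;
        for (int f = 0; f < userFactors.length; f++) {
            sum += userFactors[f] * itemFactors[f];
        }
        return sum;
    }

    /** Rebuilds the full user x item matrix U * V^T from the two factor matrices. */
    public static double[][] reconstruct(double[][] userFactors, double[][] itemFactors) {
        double[][] filled = new double[userFactors.length][itemFactors.length];
        for (int u = 0; u < userFactors.length; u++) {
            for (int i = 0; i < itemFactors.length; i++) {
                filled[u][i] = predict(userFactors[u], itemFactors[i]);
            }
        }
        return filled;
    }

    public static void main(String[] args) {
        double[][] userFactors = {{1.0, 0.5}, {0.5, 1.5}, {1.5, 0.0}};             // hypothetical
        double[][] itemFactors = {{4.0, 1.0}, {1.0, 3.0}, {2.0, 2.0}, {0.5, 2.5}}; // hypothetical
        // Every row now has a value for each item, including items the user never rated
        for (double[] row : reconstruct(userFactors, itemFactors)) {
            System.out.println(Arrays.toString(row));
        }
    }
}
```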

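In short, the ALS-WR algorithm works as follows:
(data format: userId, itemId, rating, timestamp)
1 Randomly initialize N (here 10) latent factors for each userId; these values determine the user's weights. N is the rank of the factorization.
2 Likewise, randomly initialize N (10) factors for each itemId.
3 Fix the user factors and solve for the item factors from the userFactors matrix and the rating matrix. The factor matrices are not square, so this is a regularized least-squares solve rather than a true inverse: for each item, [Item Factors] = ([User Factors]^T [User Factors] + λ·n·I)^-1 [User Factors]^T [Ratings of that item], using only the users who rated the item, where n is the number of those ratings (the "weighted-λ" of ALS-WR).
4 Fix the item factors and solve for the user factors in the same way: for each user, [User Factors] = ([Item Factors]^T [Item Factors] + λ·n·I)^-1 [Item Factors]^T [Ratings by that user], where n is the number of ratings the user gave.
5 Repeat steps 3 and 4 until userFactors and itemFactors converge to stable values.
6 To predict a user's rating for an item, take the dot product of that user's factor vector and that item's factor vector: rating ≈ userFactors · itemFactors.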
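The steps above can be sketched as a toy, single-machine ALS-WR in plain Java. This is not Spark's implementation (Spark's is distributed and blocked); the class `ToyAlsWr`, the use of 0 as the "missing rating" marker, the rank, λ and the random seed are all assumptions for illustration. Each half-step solves the regularized normal equations of steps 3 and 4, with λ scaled by the number of ratings.

```java
import java.util.Random;

/** Toy ALS-WR on a small dense matrix where 0 marks a missing rating. Not Spark's implementation. */
public class ToyAlsWr {

    /** Solves the small k x k system a * x = b (Gaussian elimination, partial pivoting; mutates a and b). */
    static double[] solve(double[][] a, double[] b) {
        int n = b.length;
        for (int col = 0; col < n; col++) {
            int pivot = col;
            for (int row = col + 1; row < n; row++) {
                if (Math.abs(a[row][col]) > Math.abs(a[pivot][col])) pivot = row;
            }
            double[] rowTmp = a[col]; a[col] = a[pivot]; a[pivot] = rowTmp;
            double bTmp = b[col]; b[col] = b[pivot]; b[pivot] = bTmp;
            for (int row = col + 1; row < n; row++) {
                double factor = a[row][col] / a[col][col];
                b[row] -= factor * b[col];
                for (int c = col; c < n; c++) a[row][c] -= factor * a[col][c];
            }
        }
        double[] x = new double[n];
        for (int row = n - 1; row >= 0; row--) {
            double sum = b[row];
            for (int c = row + 1; c < n; c++) sum -= a[row][c] * x[c];
            x[row] = sum / a[row][row];
        }
        return x;
    }

    /** Step 6: predicted rating = dot product of a user vector and an item vector. */
    public static double predict(double[] user, double[] item) {
        double s = 0;
        for (int f = 0; f < user.length; f++) s += user[f] * item[f];
        return s;
    }

    /** Steps 3/4: re-solve every row of target while fixed stays constant. */
    static void update(double[][] ratings, double[][] fixed, double[][] target, double lambda, boolean targetIsUser) {
        int k = fixed[0].length;
        for (int t = 0; t < target.length; t++) {
            double[][] a = new double[k][k];
            double[] b = new double[k];
            int n = 0; // number of observed ratings for this user or item
            for (int f = 0; f < fixed.length; f++) {
                double r = targetIsUser ? ratings[t][f] : ratings[f][t];
                if (r == 0) continue; // skip missing ratings
                n++;
                for (int i = 0; i < k; i++) {
                    b[i] += r * fixed[f][i];
                    for (int j = 0; j < k; j++) a[i][j] += fixed[f][i] * fixed[f][j];
                }
            }
            if (n == 0) continue; // nothing observed: keep the current vector
            for (int i = 0; i < k; i++) a[i][i] += lambda * n; // weighted-lambda regularization
            target[t] = solve(a, b);
        }
    }

    /** ALS-WR objective: squared error on observed cells + lambda * n * ||vector||^2. */
    public static double loss(double[][] ratings, double[][] users, double[][] items, double lambda) {
        int[] nUser = new int[users.length];
        int[] nItem = new int[items.length];
        double total = 0;
        for (int u = 0; u < users.length; u++) {
            for (int i = 0; i < items.length; i++) {
                if (ratings[u][i] == 0) continue;
                nUser[u]++;
                nItem[i]++;
                double err = ratings[u][i] - predict(users[u], items[i]);
                total += err * err;
            }
        }
        for (int u = 0; u < users.length; u++) total += lambda * nUser[u] * predict(users[u], users[u]);
        for (int i = 0; i < items.length; i++) total += lambda * nItem[i] * predict(items[i], items[i]);
        return total;
    }

    /** Steps 1-5: random init, then alternate item and user solves. Returns {userFactors, itemFactors}. */
    public static double[][][] train(double[][] ratings, int rank, int iterations, double lambda, long seed) {
        Random rnd = new Random(seed);
        double[][] users = new double[ratings.length][rank];
        double[][] items = new double[ratings[0].length][rank];
        for (double[] u : users) for (int f = 0; f < rank; f++) u[f] = rnd.nextDouble();
        for (double[] v : items) for (int f = 0; f < rank; f++) v[f] = rnd.nextDouble();
        for (int it = 0; it < iterations; it++) {
            update(ratings, users, items, lambda, false); // step 3: fix users, solve items
            update(ratings, items, users, lambda, true);  // step 4: fix items, solve users
        }
        return new double[][][]{users, items};
    }

    public static void main(String[] args) {
        double[][] ratings = {{5, 3, 0, 1}, {4, 0, 0, 1}, {1, 1, 0, 5}, {1, 0, 0, 4}, {0, 1, 5, 4}};
        double[][][] model = train(ratings, 2, 20, 0.01, 42L);
        System.out.printf("loss = %.4f%n", loss(ratings, model[0], model[1], 0.01));
        // Filled-in value for a cell that was missing (user 1, item 1)
        System.out.printf("user 1, item 1 -> %.2f%n", predict(model[0][1], model[1][1]));
    }
}
```

Because each half-step exactly minimizes the objective over the factors it updates, `loss` should never increase from one iteration to the next, which makes a handy sanity check when experimenting with rank and λ.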

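Spark offers two machine-learning libraries: ML (spark.ml, DataFrame-based) and MLlib (spark.mllib, RDD-based). The official recommendation is ML, because it is more complete and more flexible and future development will focus on it; the RDD-based MLlib API is in maintenance mode.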

 

ALS recommendation with ML:

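import java.io.Serializable;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.ml.evaluation.RegressionEvaluator;
import org.apache.spark.ml.recommendation.ALS;
import org.apache.spark.ml.recommendation.ALSModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.types.DataTypes;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * @author huangyueran
 * @category ALS-WR
 */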
public class JavaALSExampleByMl {

    private static final Logger log = LoggerFactory.getLogger(JavaALSExampleByMl.class);

    public static class Rating implements Serializable {
        // 0::2::3::1424380312
        private int userId; // 0
        private int movieId; // 2
        private float rating; // 3
        private long timestamp; // 1424380312

        public Rating() {
        }

        public Rating(int userId, int movieId, float rating, long timestamp) {
            this.userId = userId;
            this.movieId = movieId;
            this.rating = rating;
            this.timestamp = timestamp;
        }

        public int getUserId() {
            return userId;
        }

        public int getMovieId() {
            return movieId;
        }

        public float getRating() {
            return rating;
        }

        public long getTimestamp() {
            return timestamp;
        }

        public static Rating parseRating(String str) {
            String[] fields = str.split("::");
            if (fields.length != 4) {
                throw new IllegalArgumentException("Each line must contain 4 fields");
            }
            int userId = Integer.parseInt(fields[0]);
            int movieId = Integer.parseInt(fields[1]);
            float rating = Float.parseFloat(fields[2]);
            long timestamp = Long.parseLong(fields[3]);
            return new Rating(userId, movieId, rating, timestamp);
        }
    }

    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setAppName("JavaALSExample").setMaster("local");
        JavaSparkContext jsc = new JavaSparkContext(conf);
        SQLContext sqlContext = new SQLContext(jsc);

        JavaRDD<Rating> ratingsRDD = jsc.textFile("data/sample_movielens_ratings.txt")
                .map(new Function<String, Rating>() {
                    public Rating call(String str) {
                        return Rating.parseRating(str);
                    }
                });
        Dataset<Row> ratings = sqlContext.createDataFrame(ratingsRDD, Rating.class);
        Dataset<Row>[] splits = ratings.randomSplit(new double[]{0.8, 0.2}); // split the data: 80% for training, the remaining 20% for testing
        Dataset<Row> training = splits[0];
        Dataset<Row> test = splits[1];

        // Build the recommendation model using ALS on the training data
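        ALS als = new ALS().setMaxIter(5) // maximum number of iterations
                .setRegParam(0.01) // regularization parameter (penalizes large factors to limit overfitting); on this dataset 0.1 seemed to give a slightly lower error
                .setUserCol("userId").setItemCol("movieId")
                .setRatingCol("rating");
        ALSModel model = als.fit(training); // train the model
        // Drop test rows whose user or item never appeared in training; otherwise their prediction is NaN and so is the RMSE (Spark 2.2+)
        model.setColdStartStrategy("drop");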


        Dataset<Row> itemFactors = model.itemFactors();
        itemFactors.show(1500);
        Dataset<Row> userFactors = model.userFactors();
        userFactors.show();

        // Evaluate the model by computing the RMSE on the test data
        Dataset<Row> rawPredictions = model.transform(test); // predict on the test data
        Dataset<Row> predictions = rawPredictions
                .withColumn("rating", rawPredictions.col("rating").cast(DataTypes.DoubleType))
                .withColumn("prediction", rawPredictions.col("prediction").cast(DataTypes.DoubleType));

        RegressionEvaluator evaluator = new RegressionEvaluator().setMetricName("rmse").setLabelCol("rating")
                .setPredictionCol("prediction");
        Double rmse = evaluator.evaluate(predictions);
        log.info("Root-mean-square error = {} ", rmse);

        jsc.stop();
    }
}
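`RegressionEvaluator` with metric `rmse` computes the root-mean-square error, sqrt(mean((rating - prediction)²)). With a random 80/20 split, some test users or items may be missing from the training set; by default Spark's `ALSModel` predicts NaN for them, which turns the whole RMSE into NaN unless `coldStartStrategy` is set to `"drop"`. A plain-Java sketch of the metric that skips NaN predictions the way `"drop"` would (the class `RmseSketch` and the sample numbers are hypothetical):

```java
/** Illustrative RMSE over (label, prediction) pairs; not Spark's RegressionEvaluator. */
public class RmseSketch {

    /** NaN predictions (unseen user/item) are skipped, mimicking coldStartStrategy "drop". */
    public static double rmse(double[] labels, double[] predictions) {
        double sumSq = 0.0;
        int n = 0;
        for (int i = 0; i < labels.length; i++) {
            if (Double.isNaN(predictions[i])) continue; // cold-start row dropped
            double err = labels[i] - predictions[i];
            sumSq += err * err;
            n++;
        }
        if (n == 0) {
            throw new IllegalArgumentException("No rows left to evaluate");
        }
        return Math.sqrt(sumSq / n);
    }

    public static void main(String[] args) {
        double[] labels = {3.0, 1.0, 5.0, 2.0};
        double[] predictions = {2.5, 1.5, Double.NaN, 2.0}; // hypothetical; NaN = user/item unseen in training
        System.out.println("RMSE = " + rmse(labels, predictions));
    }
}
```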

ALS recommendation with MLlib:

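import java.io.File;
import java.io.IOException;
import java.util.Arrays;

// Assumed: commons-io and commons-lang3 are available (e.g. transitively via Spark/Hadoop);
// the original does not show which StringUtils/FileUtils it imports.
import org.apache.commons.io.FileUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaDoubleRDD;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.recommendation.ALS;
import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
import org.apache.spark.mllib.recommendation.Rating;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import scala.Tuple2;

/**
 * @category ALS
 * @author huangyueran
 *
 */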
public class JavaALSExampleByMlLib {

	private static final Logger log = LoggerFactory.getLogger(JavaALSExampleByMlLib.class);

	public static void main(String[] args) {
		SparkConf conf = new SparkConf().setAppName("JavaALSExample").setMaster("local[4]");
		JavaSparkContext jsc = new JavaSparkContext(conf);

		JavaRDD<String> data = jsc.textFile("data/sample_movielens_ratings.txt");

		JavaRDD<Rating> ratings = data.map(new Function<String, Rating>() {
			public Rating call(String s) {
				String[] sarray = StringUtils.split(StringUtils.trim(s), "::");
				return new Rating(Integer.parseInt(sarray[0]), Integer.parseInt(sarray[1]),
						Double.parseDouble(sarray[2]));
			}
		});

		// Build the recommendation model using ALS
		int rank = 10;
		int numIterations = 6;
		MatrixFactorizationModel model = ALS.train(JavaRDD.toRDD(ratings), rank, numIterations, 0.01);

		// Evaluate the model on the same ratings it was trained on (this measures training error, not held-out error)
		JavaRDD<Tuple2<Object, Object>> userProducts = ratings.map(new Function<Rating, Tuple2<Object, Object>>() {
			public Tuple2<Object, Object> call(Rating r) {
				return new Tuple2<Object, Object>(r.user(), r.product());
			}
		});

		// Predicted ratings
		JavaPairRDD<Tuple2<Integer, Integer>, Double> predictions = JavaPairRDD
				.fromJavaRDD(model.predict(JavaRDD.toRDD(userProducts)).toJavaRDD()
						.map(new Function<Rating, Tuple2<Tuple2<Integer, Integer>, Double>>() {
							public Tuple2<Tuple2<Integer, Integer>, Double> call(Rating r) {
								return new Tuple2<Tuple2<Integer, Integer>, Double>(
										new Tuple2<Integer, Integer>(r.user(), r.product()), r.rating());
							}
						}));

		JavaPairRDD<Tuple2<Integer, Integer>, Tuple2<Double, Double>> ratesAndPreds = JavaPairRDD
				.fromJavaRDD(ratings.map(new Function<Rating, Tuple2<Tuple2<Integer, Integer>, Double>>() {
					public Tuple2<Tuple2<Integer, Integer>, Double> call(Rating r) {
						return new Tuple2<Tuple2<Integer, Integer>, Double>(
								new Tuple2<Integer, Integer>(r.user(), r.product()), r.rating());
					}
				})).join(predictions);

		// Key: user ID; value: (product ID, predicted rating), sorted by user ID in descending order (sortByKey(false))
		JavaPairRDD<Integer, Tuple2<Integer, Double>> fromJavaRDD = JavaPairRDD.fromJavaRDD(ratesAndPreds.map(
				new Function<Tuple2<Tuple2<Integer, Integer>, Tuple2<Double, Double>>, Tuple2<Integer, Tuple2<Integer, Double>>>() {

					public Tuple2<Integer, Tuple2<Integer, Double>> call(
							Tuple2<Tuple2<Integer, Integer>, Tuple2<Double, Double>> t) throws Exception {
						return new Tuple2<Integer, Tuple2<Integer, Double>>(t._1._1,
								new Tuple2<Integer, Double>(t._1._2, t._2._2));
					}
				})).sortByKey(false);
		
//		List<Tuple2<Integer,Tuple2<Integer,Double>>> list = fromJavaRDD.collect();
//		for(Tuple2<Integer,Tuple2<Integer,Double>> t:list){
//			System.out.println(t._1+":"+t._2._1+"===="+t._2._2);
//		}

		JavaRDD<Tuple2<Double, Double>> ratesAndPredsValues = ratesAndPreds.values();

		double MSE = JavaDoubleRDD.fromRDD(ratesAndPredsValues.map(new Function<Tuple2<Double, Double>, Object>() {
			public Object call(Tuple2<Double, Double> pair) {
				Double err = pair._1() - pair._2();
				return err * err;
			}
		}).rdd()).mean();

		try {
			FileUtils.deleteDirectory(new File("result"));
		} catch (IOException e) {
			e.printStackTrace();
		}

		ratesAndPreds.repartition(1).saveAsTextFile("result/ratesAndPreds");

		// Recommend 10 products (movies) for user 2
		Rating[] recommendProducts = model.recommendProducts(2, 10);
		log.info("get recommend result:{}",Arrays.toString(recommendProducts));

		// Recommend the top N users for every product
		//model.recommendUsersForProducts(10);
		
		// Recommend the top N products for every user
		//model.recommendProductsForUsers(10);
		
		model.userFeatures().saveAsTextFile("result/userFea");
		model.productFeatures().saveAsTextFile("result/productFea");
		log.info("Mean Squared Error = {}" , MSE);

		jsc.stop();
	}

}
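In essence, `recommendProducts(user, num)` scores every product by the dot product of the user's feature vector and that product's feature vector and returns the `num` highest-scoring products; it does not filter out products the user has already rated. The plain-Java sketch below shows that ranking step on hypothetical in-memory features (the class `TopNSketch` is made up; Spark performs the scoring over the distributed `productFeatures` RDD).

```java
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

/** Illustrative top-N ranking by dot product; not Spark's code. */
public class TopNSketch {

    /** Returns the IDs of the num products with the highest predicted rating for this user. */
    public static List<Integer> recommend(double[] userFeatures, Map<Integer, double[]> productFeatures, int num) {
        List<Map.Entry<Integer, Double>> scored = new ArrayList<>();
        for (Map.Entry<Integer, double[]> e : productFeatures.entrySet()) {
            double score = 0.0;
            double[] p = e.getValue();
            for (int f = 0; f < p.length; f++) {
                score += userFeatures[f] * p[f]; // predicted rating = dot product
            }
            scored.add(new AbstractMap.SimpleEntry<Integer, Double>(e.getKey(), score));
        }
        // Highest predicted rating first; already-rated products are NOT excluded
        scored.sort(Map.Entry.<Integer, Double>comparingByValue().reversed());
        List<Integer> top = new ArrayList<>();
        for (int i = 0; i < Math.min(num, scored.size()); i++) {
            top.add(scored.get(i).getKey());
        }
        return top;
    }

    public static void main(String[] args) {
        double[] user = {1.0, 0.5}; // hypothetical user features
        Map<Integer, double[]> products = new TreeMap<>(); // hypothetical product features
        products.put(10, new double[]{4.0, 1.0});
        products.put(20, new double[]{1.0, 3.0});
        products.put(30, new double[]{2.0, 2.0});
        products.put(40, new double[]{0.5, 2.5});
        System.out.println("top 2 products: " + recommend(user, products, 2));
    }
}
```

If already-purchased or already-rated items should be hidden, one option is to request more than `num` results and filter them in application code.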

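Both examples above perform offline ALS recommendation with Spark. Another approach uses Spark Streaming to compute online (real-time) recommendations from data buffered in a message queue such as Kafka.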

 

Real-time ALS recommendation with Spark Streaming:

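ALS recommendation via Spark Streaming is only one part of the pipeline; a real project involves many other technologies.

For example: instrumenting the application to emit user-behavior logs, collecting those logs with Flume and storing them in HDFS, and using Kafka to consume and buffer large volumes of behavior data.

There is also offline storage of the models trained with Spark machine learning, a web tier that pulls and caches the models to serve recommendations to users, and so on.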

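import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;

import kafka.serializer.StringDecoder;
import org.apache.commons.lang3.StringUtils; // assumed, as in the MLlib example
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.VoidFunction;
import org.apache.spark.mllib.recommendation.ALS;
import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
import org.apache.spark.mllib.recommendation.Rating;
import org.apache.spark.rdd.RDD;
import org.apache.spark.streaming.Durations;
import org.apache.spark.streaming.api.java.JavaDStream;
import org.apache.spark.streaming.api.java.JavaPairInputDStream;
import org.apache.spark.streaming.api.java.JavaStreamingContext;
import org.apache.spark.streaming.kafka.KafkaUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import scala.Tuple2;

/**
 * @author huangyueran
 * @category Real-time recommendation template DEMO based on Spark Streaming and Kafka. The original system also includes a mall (e-commerce) project, logback, Flume and Hadoop.
 */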
public final class SparkALSByStreaming {

    private static final Logger log = LoggerFactory.getLogger(SparkALSByStreaming.class);

    private static final String KAFKA_ADDR = "middleware:9092";
    private static final String TOPIC = "RECOMMEND_TOPIC";
    private static final String HDFS_ADDR = "hdfs://middleware:9000";

    private static final String MODEL_PATH = "/spark-als/model";


    //	Real-time recommendation system DEMO based on Hadoop, Flume, Kafka, Spark Streaming, logback and a mall system
    //	Data format collected by the mall system:
    //	UserID,ItemId,Rating,TimeStamp (user ID, item ID, user-behavior rating, timestamp)
    //	53,1286513,9,1508221762
    //	53,1172348420,9,1508221762
    //	53,1179495514,12,1508221762
    //	53,1184890730,3,1508221762
    //	53,1210793742,159,1508221762
    //	53,1215837445,9,1508221762

    public static void main(String[] args) {
        System.setProperty("HADOOP_USER_NAME", "root"); // user name used for HDFS permissions

        SparkConf sparkConf = new SparkConf().setAppName("SparkALSByStreaming").setMaster("local[1]");

        final JavaStreamingContext jssc = new JavaStreamingContext(sparkConf, Durations.seconds(6));

        Map<String, String> kafkaParams = new HashMap<String, String>(); // Kafka consumer configuration parameters
        kafkaParams.put("metadata.broker.list", KAFKA_ADDR); // address of the Kafka broker(s)
        HashSet<String> topicsSet = new HashSet<String>();
        topicsSet.add(TOPIC); // topic to consume

        // Create direct kafka stream with brokers and topics
        // createDirectStream()
        JavaPairInputDStream<String, String> messages = KafkaUtils.createDirectStream(jssc, String.class, String.class,
                StringDecoder.class, StringDecoder.class, kafkaParams, topicsSet);

        JavaDStream<String> lines = messages.map(new Function<Tuple2<String, String>, String>() {
            public String call(Tuple2<String, String> tuple2) {
                return tuple2._2();
            }
        });

        JavaDStream<Rating> ratingsStream = lines.map(new Function<String, Rating>() {
            public Rating call(String s) {
                String[] sarray = StringUtils.split(StringUtils.trim(s), ",");
                return new Rating(Integer.parseInt(sarray[0]), Integer.parseInt(sarray[1]),
                        Double.parseDouble(sarray[2]));
            }
        });

        // Run the recommendation computation on each micro-batch
        ratingsStream.foreachRDD(new VoidFunction<JavaRDD<Rating>>() {

            public void call(JavaRDD<Rating> ratings) throws Exception {
                // Load the original (historical) dataset
                SparkContext sc = ratings.context();

                RDD<String> textFileRDD = sc.textFile(HDFS_ADDR + "/flume/logs", 3); // read the original dataset files
                JavaRDD<String> originalTextFile = textFileRDD.toJavaRDD();

                final JavaRDD<Rating> originaldatas = originalTextFile.map(new Function<String, Rating>() {
                    public Rating call(String s) {
                        String[] sarray = StringUtils.split(StringUtils.trim(s), ",");
                        return new Rating(Integer.parseInt(sarray[0]), Integer.parseInt(sarray[1]),
                                Double.parseDouble(sarray[2]));
                    }
                });
                log.info("========================================");
                log.info("Original TextFile Count:{}", originalTextFile.count()); // raw user-behavior log data already stored in HDFS
                log.info("========================================");

                // Merge the original dataset with the new user-behavior data
                JavaRDD<Rating> calculations = originaldatas.union(ratings);

                log.info("Calc Count:{}", calculations.count());

                // Build the recommendation model using ALS
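                int rank = 10; // number of latent factors in the model
                int numIterations = 6; // number of training iterations

                // Train the model
                if (!ratings.isEmpty()) { // only if this batch contains new user-behavior data
                    MatrixFactorizationModel model = ALS.train(JavaRDD.toRDD(calculations), rank, numIterations, 0.01);
                    // If the model directory already exists, delete it (saving to an existing path fails)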
                    Configuration hadoopConfiguration = sc.hadoopConfiguration();
                    hadoopConfiguration.set("fs.defaultFS", HDFS_ADDR);
                    FileSystem fs = FileSystem.get(hadoopConfiguration);
                    Path outpath = new Path(MODEL_PATH);
                    if (fs.exists(outpath)) {
                        log.info("########### Deleting " + outpath.getName() + " ###########");
                        fs.delete(outpath, true);
                    }

                    // Save the model
                    model.save(sc, HDFS_ADDR + MODEL_PATH);

                    // Load the model back
                    MatrixFactorizationModel modelLoad = MatrixFactorizationModel.load(sc, HDFS_ADDR + MODEL_PATH);
                    // Recommend 10 products (movies) for each of the given users.
                    // User IDs 0-29 match streaming_sample_movielens_ratings.txt; recommendProducts
                    // throws if the user ID is not in the model, so adjust the range to your data.
                    for (int userId = 0; userId < 30; userId++) {
                        Rating[] recommendProducts = modelLoad.recommendProducts(userId, 10);
                        log.info("get recommend result:{}", Arrays.toString(recommendProducts));
                    }
                }

            }
        });

        // ==========================================================================================

        jssc.start();
        try {
            // Block until the streaming context is stopped
            jssc.awaitTermination();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        } finally {
            jssc.stop();
        }
    }

}

User behavior dataset

Data format collected by the mall system:
UserID,ItemId,Rating,TimeStamp (user ID, item ID, user-behavior rating, timestamp)
53,1286513,9,1508221762
53,1172348420,9,1508221762
53,1179495514,12,1508221762
53,1184890730,3,1508221762
53,1210793742,159,1508221762
53,1215837445,9,1508221762
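A minimal plain-Java sketch of parsing one line of this format with basic validation (the class name `BehaviorRecord` is hypothetical). Both the streaming job above and MLlib's `Rating` store user and item IDs as `int`, so IDs must stay at or below 2147483647; the sample item IDs such as 1215837445 fit, but a longer ID would make `Integer.parseInt` throw.

```java
/** Hypothetical value class for one user-behavior line: UserID,ItemId,Rating,TimeStamp. */
public class BehaviorRecord {
    public final int userId;
    public final int itemId;
    public final double rating;
    public final long timestamp;

    public BehaviorRecord(int userId, int itemId, double rating, long timestamp) {
        this.userId = userId;
        this.itemId = itemId;
        this.rating = rating;
        this.timestamp = timestamp;
    }

    /** Parses one CSV line; throws IllegalArgumentException for a wrong field count or a bad number. */
    public static BehaviorRecord parse(String line) {
        String[] fields = line.trim().split(",");
        if (fields.length != 4) {
            throw new IllegalArgumentException("Expected 4 comma-separated fields: " + line);
        }
        try {
            return new BehaviorRecord(
                    Integer.parseInt(fields[0].trim()),
                    Integer.parseInt(fields[1].trim()), // must fit in an int, like MLlib's Rating
                    Double.parseDouble(fields[2].trim()),
                    Long.parseLong(fields[3].trim()));
        } catch (NumberFormatException e) {
            throw new IllegalArgumentException("Non-numeric or out-of-range field: " + line, e);
        }
    }

    public static void main(String[] args) {
        BehaviorRecord r = parse("53,1210793742,159,1508221762");
        System.out.println(r.userId + " " + r.itemId + " " + r.rating + " " + r.timestamp);
    }
}
```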

 

Maven dependencies

<!-- https://mvnrepository.com/artifact/org.apache.spark/spark-core_2.10 -->
		<dependency>
			<groupId>org.apache.spark</groupId>
			<artifactId>spark-core_2.10</artifactId>
			<version>2.2.0</version>
		</dependency>
		<!-- https://mvnrepository.com/artifact/org.apache.spark/spark-mllib_2.10 -->
		<dependency>
			<groupId>org.apache.spark</groupId>
			<artifactId>spark-mllib_2.10</artifactId>
			<version>2.2.0</version>
		</dependency>
		<!-- https://mvnrepository.com/artifact/org.apache.spark/spark-sql_2.10 -->
		<dependency>
			<groupId>org.apache.spark</groupId>
			<artifactId>spark-sql_2.10</artifactId>
			<version>2.2.0</version>
		</dependency>
		<!-- https://mvnrepository.com/artifact/org.apache.spark/spark-streaming_2.10 -->
		<dependency>
			<groupId>org.apache.spark</groupId>
			<artifactId>spark-streaming_2.10</artifactId>
			<version>2.2.0</version>
		</dependency>
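		<!-- Kafka 0.8 direct-stream integration built for Spark 2.x (same KafkaUtils API); the 1.6.x spark-streaming-kafka artifact targets Spark 1.x -->
		<!-- https://mvnrepository.com/artifact/org.apache.spark/spark-streaming-kafka-0-8_2.10 -->
		<dependency>
			<groupId>org.apache.spark</groupId>
			<artifactId>spark-streaming-kafka-0-8_2.10</artifactId>
			<version>2.2.0</version>
		</dependency>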
		<!-- https://mvnrepository.com/artifact/log4j/log4j -->
		<dependency>
			<groupId>log4j</groupId>
			<artifactId>log4j</artifactId>
			<version>1.2.17</version>
		</dependency>
		<dependency>
			<groupId>org.slf4j</groupId>
			<artifactId>slf4j-api</artifactId>
			<version>1.7.12</version>
		</dependency>
		<dependency>
			<groupId>org.slf4j</groupId>
			<artifactId>slf4j-log4j12</artifactId>
			<version>1.7.12</version>
		</dependency>

The code and datasets above can be found in the GitHub project:

https://github.com/huangyueranbbc/Spark_ALS
