Preface: <a href='https://www.cnblogs.com/BeanHsiang/category/1218714.html' target='_blank'>ML.NET series</a>
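ML.NET has released v0.2, which adds a clustering trainer and further improves performance. This article introduces a special kind of regression, Poisson regression, and walks through it with an example that predicts NBA game scores.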
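As mentioned in earlier articles, regression predicts continuous values. Poisson regression is one type of regression. What sets it apart is that it predicts only non-negative integers, which are typically counts. The Poisson distribution is discrete, so the data (features and labels) should describe independent random events, each counted over the same (or nearly the same) time interval.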
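The minimal C# sketch below makes the distribution concrete. It is written as a C# 9+ top-level program and shows two things:
- the Poisson probability mass function P(Y = k) = λ^k e^(-λ) / k!
- the log link of standard Poisson regression, which keeps the predicted mean λ positive.

The weight, intercept and feature value in the usage lines are hypothetical and only for illustration.

```csharp
using System;

// Poisson probability mass function P(Y = k) = λ^k e^(-λ) / k!, evaluated in log space
// so that k! never overflows. Assumes λ > 0, which the log link below guarantees.
static double PoissonPmf(int k, double lambda)
{
    double logP = k * Math.Log(lambda) - lambda;
    for (int i = 2; i <= k; i++)
        logP -= Math.Log(i); // subtract log(k!) one factor at a time
    return Math.Exp(logP);
}

// Log link: log(λ) = w·x + b, so the predicted mean λ = exp(w·x + b) is always positive.
static double PredictMean(double[] w, double b, double[] x)
{
    double eta = b;
    for (int j = 0; j < w.Length; j++)
        eta += w[j] * x[j];
    return Math.Exp(eta);
}

// Hypothetical weight, intercept and feature value, for illustration only.
double mean = PredictMean(new[] { 0.5 }, 0.1, new[] { 2.0 }); // exp(0.1 + 0.5 * 2.0) = exp(1.1)
for (int k = 0; k <= 6; k++)
    Console.WriteLine($"P(Y = {k}) = {PoissonPmf(k, mean):F4}");
```

The model predicts a mean count λ rather than a single integer. The probabilities of each possible count then follow from the Poisson distribution with that mean.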
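So which scenarios involve counts and suit Poisson regression? Here are a few examples. In bike sharing, each district hub counts the bikes borrowed and returned every hour. From these counts we can predict how many bikes need to be moved to that hub in the next hour to meet demand. Similarly, a company loses some employees every month. HR can count monthly attrition and use Poisson regression to predict next month's figure, so it can act early and plan recruiting ahead of time.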
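The C# sketch below (a C# 9+ top-level program) shows how such a model is fit. It trains a one-feature Poisson regression by plain gradient descent on the average Poisson negative log-likelihood. The data are hypothetical hourly bike-rental counts with an assumed binary peak-hour feature.

ML.NET's PoissonRegressor uses its own L-BFGS-based optimizer and regularization settings. Treat this as an illustration of the model, not of that trainer.

```csharp
using System;

// Hypothetical toy data: hourly bike rentals at one hub, with one feature
// x = 1 for a peak hour and 0 otherwise. Off-peak hours average 2 rentals, peak hours 6.
double[] x = { 0, 0, 1, 1 };
double[] y = { 1, 3, 5, 7 };

// Model: λ = exp(b + w·x). Minimize the average Poisson negative log-likelihood
// (dropping the constant log(y!) term): mean(λ_i - y_i · (b + w·x_i)).
// Its gradient is mean(λ_i - y_i) for b and mean((λ_i - y_i) · x_i) for w.
double w = 0, b = 0;
const double learningRate = 0.05;
for (int step = 0; step < 20000; step++)
{
    double gradB = 0, gradW = 0;
    for (int i = 0; i < x.Length; i++)
    {
        double residual = Math.Exp(b + w * x[i]) - y[i];
        gradB += residual;
        gradW += residual * x[i];
    }
    b -= learningRate * gradB / x.Length;
    w -= learningRate * gradW / x.Length;
}

Console.WriteLine($"b = {b:F4}, w = {w:F4}");
Console.WriteLine($"off-peak λ = {Math.Exp(b):F3}, peak λ = {Math.Exp(b + w):F3}");
```

With a single binary feature, the maximum-likelihood fit reproduces each group's mean count. So exp(b) approaches 2 and exp(b + w) approaches 6, which makes the sketch easy to check.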
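Getting the idea? This time we'll use NBA game scores for the walkthrough, since most of us enjoy the NBA. Poisson regression fits well here for three reasons:
- a score is also a count;
- games last roughly the same amount of time;
- outcomes are uncertain.

To keep it easy to follow, I'll predict the points scored by the team listed in each row's Team column.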
The data comes from Kaggle.com: <a href='https://www.kaggle.com/ionaskel/nba-games-stats-from-2014-to-2018' target='_blank'>NBA Team Game Stats from 2014 to 2018</a>. The dataset covers NBA games from the four seasons between 2014 and 2018, in a format like the following:

```
"","Team","Game","Date","Home","Opponent","WINorLOSS","TeamPoints","OpponentPoints","FieldGoals","FieldGoalsAttempted","FieldGoals.","X3PointShots","X3PointShotsAttempted","X3PointShots.","FreeThrows","FreeThrowsAttempted","FreeThrows.","OffRebounds","TotalRebounds","Assists","Steals","Blocks","Turnovers","TotalFouls","Opp.FieldGoals","Opp.FieldGoalsAttempted","Opp.FieldGoals.","Opp.3PointShots","Opp.3PointShotsAttempted","Opp.3PointShots.","Opp.FreeThrows","Opp.FreeThrowsAttempted","Opp.FreeThrows.","Opp.OffRebounds","Opp.TotalRebounds","Opp.Assists","Opp.Steals","Opp.Blocks","Opp.Turnovers","Opp.TotalFouls"
"1","ATL","1",2014-10-29,"Away","TOR","L","102","109","40","80",".500","13","22",".591","9","17",".529","10","42","26","6","8","17","24","37","90",".411","8","26",".308","27","33",".818","16","48","26","13","9","9","22"
"2","ATL","2",2014-11-01,"Home","IND","W","102","92","35","69",".507","7","20",".350","25","33",".758","3","37","26","10","6","12","20","31","81",".383","12","32",".375","18","21",".857","11","44","25","5","5","18","26"
"3","ATL","3",2014-11-05,"Away","SAS","L","92","94","38","92",".413","8","25",".320","8","11",".727","10","37","26","14","5","13","25","31","69",".449","5","17",".294","27","38",".711","11","50","25","7","9","19","15"
"4","ATL","4",2014-11-07,"Away","CHO","L","119","122","43","93",".462","13","33",".394","20","26",".769","7","38","28","8","3","19","33","48","97",".495","6","21",".286","20","27",".741","11","51","31","6","7","19","30"
"5","ATL","5",2014-11-08,"Home","NYK","W","103","96","33","81",".407","9","22",".409","28","36",".778","12","41","18","10","5","8","17","40","84",".476","8","21",".381","8","11",".727","13","44","26","2","6","15","29"
"6","ATL","6",2014-11-10,"Away","NYK","W","91","85","27","71",".380","10","27",".370","27","28",".964","9","38","20","7","3","15","16","36","83",".434","6","26",".231","7","12",".583","11","40","23","4","2","15","26"
"7","ATL","7",2014-11-12,"Home","UTA","W","100","97","39","76",".513","9","20",".450","13","18",".722","13","46","23","8","4","18","12","43","86",".500","5","23",".217","6","12",".500","8","30","28","12","8","11","17"
"8","ATL","8",2014-11-14,"Home","MIA","W","114","103","42","75",".560","11","28",".393","19","23",".826","3","36","33","10","5","13","20","35","74",".473","10","21",".476","23","25",".920","5","32","27","10","3","14","20"
```
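The fields are as follows. Basic game information:
- Team: the team whose stats the row records.
- Game: that team's game number.
- Date: the game date.
- Home: whether Team played at home ("Home") or away ("Away").
- Opponent: the opposing team.
- WINorLOSS: "W" or "L" for Team.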
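Box-score stats for both sides: Team Points, Field Goals, Field Goals Attempted, Field Goals Percentage, 3 Point Shots, 3 Point Shots Attempted, 3 Point Shots Percentage, Free Throws, Free Throws Attempted, Free Throws Percentage, Offensive Rebounds, Total Rebounds, Assists, Steals, Blocks, Turnovers, Total Fouls. The opponent's stats use the same names with an Opp. prefix, and the opponent's points appear as OpponentPoints.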
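These stats cover field goals (attempted, made and percentage), three-pointers (attempted, made and percentage) and free throws (attempted, made and percentage). They also include rebounds, assists, steals, blocks, turnovers and fouls, which are all familiar to anyone who watches the NBA.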
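The box-score columns are tied together by simple identities. The C# sketch below (a C# 9+ top-level program) checks them against the first four sample rows above:
- Each percentage column is made / attempted.
- FieldGoals includes three-pointers, so TeamPoints = 2 × FieldGoals + X3PointShots + FreeThrows.

All three of those columns are among the model's input features in the training code. As a result, the features carry almost all of the information in the label.

```csharp
using System;

// Values copied from the first four sample rows above:
// (TeamPoints, FieldGoals, FieldGoalsAttempted, FieldGoals., X3PointShots, FreeThrows)
var rows = new (int Points, int FG, int FGA, double FGPct, int ThreeP, int FT)[]
{
    (102, 40, 80, 0.500, 13, 9),
    (102, 35, 69, 0.507, 7, 25),
    (92, 38, 92, 0.413, 8, 8),
    (119, 43, 93, 0.462, 13, 20),
};

// FieldGoals counts every made shot from the floor, three-pointers included,
// so each three adds one extra point on top of the 2 already counted.
static int PointsFromBoxScore(int fg, int threeP, int ft) => 2 * fg + threeP + ft;

foreach (var r in rows)
{
    int computed = PointsFromBoxScore(r.FG, r.ThreeP, r.FT);
    double pct = Math.Round((double)r.FG / r.FGA, 3); // percentage columns are made / attempted
    Console.WriteLine($"TeamPoints {r.Points}, computed {computed}; FG% {r.FGPct:F3}, computed {pct:F3}");
}
```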
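Since there is only this one file, I split it 7:2:1 into training, evaluation and prediction sets. The split simply follows row order in the file; it is neither shuffled nor split by date.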
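The 7:2:1 split can be sketched in C# as follows, as a C# 9+ top-level program. The helper names SplitInOrder and Shuffle are hypothetical:
- SplitInOrder mirrors what Main does later, taking contiguous blocks in file order.
- Shuffle is an optional Fisher-Yates shuffle with a fixed seed. It is a variation not used in this article, and it would mix teams and seasons across the three parts.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Split 70/20/10 in the existing row order, the same arithmetic Main uses.
static (List<T> Train, List<T> Evaluate, List<T> Predict) SplitInOrder<T>(IReadOnlyCollection<T> data)
{
    int trainCount = Convert.ToInt32(data.Count * 0.7);
    int evaluateCount = Convert.ToInt32(data.Count * 0.2);
    return (data.Take(trainCount).ToList(),
            data.Skip(trainCount).Take(evaluateCount).ToList(),
            data.Skip(trainCount + evaluateCount).ToList());
}

// Optional variant (not used in the article): Fisher-Yates shuffle with a fixed seed,
// so the split is reproducible but no longer takes contiguous blocks of the file.
static List<T> Shuffle<T>(IEnumerable<T> items, int seed)
{
    var list = items.ToList();
    var random = new Random(seed);
    for (int i = list.Count - 1; i > 0; i--)
    {
        int j = random.Next(i + 1); // 0 <= j <= i
        (list[i], list[j]) = (list[j], list[i]);
    }
    return list;
}

var rowIds = Enumerable.Range(1, 10).ToList(); // stand-in for 10 data rows
var (train, evaluate, predict) = SplitInOrder(rowIds);
Console.WriteLine($"{train.Count}/{evaluate.Count}/{predict.Count}");

var (shuffledTrain, _, _) = SplitInOrder(Shuffle(rowIds, seed: 42));
Console.WriteLine(string.Join(", ", shuffledTrain));
```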
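Next, define the raw data class and the prediction class. TeamPoints, the points scored by Team, is the target this example predicts, so it is mapped to the Label column.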
```csharp
// One row of nba.games.stats.csv. The Column ordinals follow the CSV column order;
// TeamPoints is exposed under the name "Label", the column the trainer learns to predict.
public class Match
{
    [Column(ordinal: "0")] public string Id;
    [Column(ordinal: "1")] public string Team;
    [Column(ordinal: "2")] public string Game;
    [Column(ordinal: "3")] public string Date;
    [Column(ordinal: "4")] public string Home;
    [Column(ordinal: "5")] public string Opponent;
    [Column(ordinal: "6")] public string WINorLOSS;
    [Column(ordinal: "7", name: "Label")] public float TeamPoints;
    [Column(ordinal: "8")] public float OpponentPoints;
    [Column(ordinal: "9")] public float FieldGoals;
    [Column(ordinal: "10")] public float FieldGoalsAttempted;
    [Column(ordinal: "11")] public float FieldGoals_;
    [Column(ordinal: "12")] public float X3PointShots;
    [Column(ordinal: "13")] public float X3PointShotsAttempted;
    [Column(ordinal: "14")] public float X3PointShots_;
    [Column(ordinal: "15")] public float FreeThrows;
    [Column(ordinal: "16")] public float FreeThrowsAttempted;
    [Column(ordinal: "17")] public float FreeThrows_;
    [Column(ordinal: "18")] public float OffRebounds;
    [Column(ordinal: "19")] public float TotalRebounds;
    [Column(ordinal: "20")] public float Assists;
    [Column(ordinal: "21")] public float Steals;
    [Column(ordinal: "22")] public float Blocks;
    [Column(ordinal: "23")] public float Turnovers;
    [Column(ordinal: "24")] public float TotalFouls;
    [Column(ordinal: "25")] public float Opp_FieldGoals;
    [Column(ordinal: "26")] public float Opp_FieldGoalsAttempted;
    [Column(ordinal: "27")] public float Opp_FieldGoals_;
    [Column(ordinal: "28")] public float Opp_3PointShots;
    [Column(ordinal: "29")] public float Opp_3PointShotsAttempted;
    [Column(ordinal: "30")] public float Opp_3PointShots_;
    [Column(ordinal: "31")] public float Opp_FreeThrows;
    [Column(ordinal: "32")] public float Opp_FreeThrowsAttempted;
    [Column(ordinal: "33")] public float Opp_FreeThrows_;
    [Column(ordinal: "34")] public float Opp_OffRebounds;
    [Column(ordinal: "35")] public float Opp_TotalRebounds;
    [Column(ordinal: "36")] public float Opp_Assists;
    [Column(ordinal: "37")] public float Opp_Steals;
    [Column(ordinal: "38")] public float Opp_Blocks;
    [Column(ordinal: "39")] public float Opp_Turnovers;
    [Column(ordinal: "40")] public float Opp_TotalFouls;
}

// Prediction output: the regressor writes its prediction to the "Score" column.
public class MatchPrediction
{
    [ColumnName("Score")] public float TeamPoints;
}
```
Loading the data:
```csharp
// Requires: using System.Globalization; (for CultureInfo)
const string DATA_PATH = "data/nba.games.stats.csv";

static ICollection<Match> LoadData()
{
    var matches = new List<Match>();
    using (var sr = new StreamReader(File.OpenRead(DATA_PATH)))
    {
        sr.ReadLine(); // skip the header row
        while (!sr.EndOfStream)
        {
            var line = sr.ReadLine();
            // Fields in this file are not expected to contain commas, so a plain split works; then strip the quotes.
            var values = line.Split(',').Select(v => v.Trim('"')).ToArray();
            // Parse with the invariant culture so values such as ".500" read correctly under any system locale.
            float F(int i) => Convert.ToSingle(values[i], CultureInfo.InvariantCulture);

            matches.Add(new Match
            {
                Id = values[0], Team = values[1], Game = values[2], Date = values[3],
                Home = values[4], Opponent = values[5], WINorLOSS = values[6],
                TeamPoints = F(7), OpponentPoints = F(8),
                FieldGoals = F(9), FieldGoalsAttempted = F(10), FieldGoals_ = F(11),
                X3PointShots = F(12), X3PointShotsAttempted = F(13), X3PointShots_ = F(14),
                FreeThrows = F(15), FreeThrowsAttempted = F(16), FreeThrows_ = F(17),
                OffRebounds = F(18), TotalRebounds = F(19), Assists = F(20), Steals = F(21),
                Blocks = F(22), Turnovers = F(23), TotalFouls = F(24),
                Opp_FieldGoals = F(25), Opp_FieldGoalsAttempted = F(26), Opp_FieldGoals_ = F(27),
                Opp_3PointShots = F(28), Opp_3PointShotsAttempted = F(29), Opp_3PointShots_ = F(30),
                Opp_FreeThrows = F(31), Opp_FreeThrowsAttempted = F(32), Opp_FreeThrows_ = F(33),
                Opp_OffRebounds = F(34), Opp_TotalRebounds = F(35), Opp_Assists = F(36),
                Opp_Steals = F(37), Opp_Blocks = F(38), Opp_Turnovers = F(39), Opp_TotalFouls = F(40)
            });
        }
    }
    return matches;
}
```
Training, evaluation and prediction:
```csharp
static PredictionModel<Match, MatchPrediction> Train(IEnumerable<Match> trainData)
{
    var pipeline = new LearningPipeline();
    pipeline.Add(CollectionDataSource.Create(trainData));
    // Id is only a row index, so drop it.
    pipeline.Add(new ColumnDropper() { Column = new[] { "Id" } });
    // Encode the string columns as one-hot vectors.
    pipeline.Add(new CategoricalOneHotVectorizer("Team", "Game", "Date", "Home", "Opponent", "WINorLOSS"));
    // Combine every remaining column except the label (TeamPoints) into one "Features" vector.
    pipeline.Add(new ColumnConcatenator("Features",
        "Team", "Game", "Date", "Home", "Opponent", "WINorLOSS", "OpponentPoints",
        "FieldGoals", "FieldGoalsAttempted", "FieldGoals_",
        "X3PointShots", "X3PointShotsAttempted", "X3PointShots_",
        "FreeThrows", "FreeThrowsAttempted", "FreeThrows_",
        "OffRebounds", "TotalRebounds", "Assists", "Steals", "Blocks", "Turnovers", "TotalFouls",
        "Opp_FieldGoals", "Opp_FieldGoalsAttempted", "Opp_FieldGoals_",
        "Opp_3PointShots", "Opp_3PointShotsAttempted", "Opp_3PointShots_",
        "Opp_FreeThrows", "Opp_FreeThrowsAttempted", "Opp_FreeThrows_",
        "Opp_OffRebounds", "Opp_TotalRebounds", "Opp_Assists", "Opp_Steals",
        "Opp_Blocks", "Opp_Turnovers", "Opp_TotalFouls"));
    // Poisson regression learns to predict Label from Features.
    pipeline.Add(new PoissonRegressor());
    return pipeline.Train<Match, MatchPrediction>();
}

static void Evaluate(PredictionModel<Match, MatchPrediction> model, IEnumerable<Match> evaluateData)
{
    var evaluator = new RegressionEvaluator();
    var metric = evaluator.Evaluate(model, CollectionDataSource.Create(evaluateData));
    Console.WriteLine("LossFn: {0}", metric.LossFn);
    Console.WriteLine("RSquared: {0}", metric.RSquared);
    Console.WriteLine("Rms: {0}", metric.Rms);
}

static void Predict(PredictionModel<Match, MatchPrediction> model, IEnumerable<Match> predictData)
{
    var predicts = model.Predict(predictData);
    // Pair each input row with its prediction; explicit tuple names avoid relying on C# 7.1 name inference.
    var results = predictData.Zip(predicts, (d, p) => (Actual: d, Predicted: p));
    foreach (var result in results)
    {
        Console.WriteLine("Date: {0}, Team: {1}, Opponent: {2}, Score: {3}-{4}, Predicted Team Score: {5}",
            result.Actual.Date, result.Actual.Team, result.Actual.Opponent,
            result.Actual.TeamPoints, result.Actual.OpponentPoints, result.Predicted.TeamPoints);
    }
}
```
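Conceptually, CategoricalOneHotVectorizer and ColumnConcatenator turn each row into one numeric Features vector. The C# sketch below (a C# 9+ top-level program) illustrates the idea with hypothetical helpers BuildVocabulary and OneHot. It is not ML.NET's implementation, and details such as how unseen categories are handled are assumptions of the sketch.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Assign each distinct category a slot index in first-seen order.
static Dictionary<string, int> BuildVocabulary(IEnumerable<string> values)
{
    var vocabulary = new Dictionary<string, int>();
    foreach (var v in values)
        if (!vocabulary.ContainsKey(v)) vocabulary[v] = vocabulary.Count; // next free slot
    return vocabulary;
}

// One-hot vector: 1 in the category's slot, 0 elsewhere.
static float[] OneHot(string value, Dictionary<string, int> vocabulary)
{
    var vector = new float[vocabulary.Count];
    // In this sketch, a category never seen in training maps to an all-zero vector.
    if (vocabulary.TryGetValue(value, out int index)) vector[index] = 1f;
    return vector;
}

var homeVocabulary = BuildVocabulary(new[] { "Away", "Home", "Away", "Home" }); // Away -> 0, Home -> 1
var opponentVocabulary = BuildVocabulary(new[] { "TOR", "IND", "SAS", "CHO" });

// Concatenate encoded categorical columns with numeric columns into one Features vector,
// in the spirit of ColumnConcatenator("Features", ...). Numbers are from sample row 2.
float[] features = OneHot("Home", homeVocabulary)
    .Concat(OneHot("IND", opponentVocabulary))
    .Concat(new[] { 35f, 69f, 0.507f }) // FieldGoals, FieldGoalsAttempted, FieldGoals_
    .ToArray();

Console.WriteLine(string.Join(", ", features));
```

Note that each distinct Date or Game value would get its own slot in the same way. Those columns therefore add many mostly-zero dimensions to the Features vector.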
Finally, Main ties everything together:
```csharp
static void Main(string[] args)
{
    var data = LoadData();

    // Split 7:2:1 in file order (no shuffling).
    var trainCount = Convert.ToInt32(data.Count * 0.7);
    var evaluateCount = Convert.ToInt32(data.Count * 0.2);
    var trainData = data.Take(trainCount);
    var evaluateData = data.Skip(trainCount).Take(evaluateCount);
    var predictData = data.Skip(trainCount + evaluateCount);

    var model = Train(trainData);
    Evaluate(model, evaluateData);
    Predict(model, predictData);
}
```
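Evaluate prints LossFn, RSquared and Rms for the evaluation set. For reference, the C# sketch below (a C# 9+ top-level program) computes RMS and R² from their standard definitions. It assumes ML.NET's RegressionEvaluator reports these same quantities. The actual values are the TeamPoints of the first four sample rows, and the predicted values are hypothetical.

```csharp
using System;
using System.Linq;

// RMS = sqrt(mean((y - ŷ)^2)), in the same units as the label (points).
static double Rms(double[] actual, double[] predicted) =>
    Math.Sqrt(actual.Zip(predicted, (y, p) => (y - p) * (y - p)).Average());

// R² = 1 - SS_res / SS_tot: the share of the label's variance the model explains.
static double RSquared(double[] actual, double[] predicted)
{
    double mean = actual.Average();
    double ssRes = actual.Zip(predicted, (y, p) => (y - p) * (y - p)).Sum();
    double ssTot = actual.Sum(y => (y - mean) * (y - mean));
    return 1 - ssRes / ssTot;
}

double[] actualPoints = { 102, 102, 92, 119 };     // TeamPoints from sample rows 1-4
double[] predictedPoints = { 100, 104, 95, 116 };  // hypothetical predictions
Console.WriteLine($"Rms: {Rms(actualPoints, predictedPoints):F3}, RSquared: {RSquared(actualPoints, predictedPoints):F3}");
```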
Execution results:
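As you can see, for the games in the prediction set the predicted scores are already quite close to the actual ones. The features are all in-game box-score stats. When the model is used on a future game, they can be fed in and refreshed as the game unfolds, bringing the prediction ever closer to the final result. For fans, that's quite a handy tool. With the 2018 World Cup about to start, Paul the Octopus, Achilles the cat and the other animal oracles look weak by comparison. I bet you want to try the ML.NET approach yourself too; just remember to collect complete data from past years!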
The complete code:
```csharp
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Models;
using Microsoft.ML.Runtime.Api;
using Microsoft.ML.Runtime.Learners;
using Microsoft.ML.Trainers;
using Microsoft.ML.Transforms;
using System;
using System.Collections.Generic;
using System.Globalization;
using System.IO;
using System.Linq;

namespace NBAPrediction
{
    class Program
    {
        const string DATA_PATH = "data/nba.games.stats.csv";

        static ICollection<Match> LoadData()
        {
            var matches = new List<Match>();
            using (var sr = new StreamReader(File.OpenRead(DATA_PATH)))
            {
                sr.ReadLine(); // skip the header row
                while (!sr.EndOfStream)
                {
                    var line = sr.ReadLine();
                    var values = line.Split(',').Select(v => v.Trim('"')).ToArray();
                    // Invariant culture so ".500" parses the same under any system locale.
                    float F(int i) => Convert.ToSingle(values[i], CultureInfo.InvariantCulture);

                    matches.Add(new Match
                    {
                        Id = values[0], Team = values[1], Game = values[2], Date = values[3],
                        Home = values[4], Opponent = values[5], WINorLOSS = values[6],
                        TeamPoints = F(7), OpponentPoints = F(8),
                        FieldGoals = F(9), FieldGoalsAttempted = F(10), FieldGoals_ = F(11),
                        X3PointShots = F(12), X3PointShotsAttempted = F(13), X3PointShots_ = F(14),
                        FreeThrows = F(15), FreeThrowsAttempted = F(16), FreeThrows_ = F(17),
                        OffRebounds = F(18), TotalRebounds = F(19), Assists = F(20), Steals = F(21),
                        Blocks = F(22), Turnovers = F(23), TotalFouls = F(24),
                        Opp_FieldGoals = F(25), Opp_FieldGoalsAttempted = F(26), Opp_FieldGoals_ = F(27),
                        Opp_3PointShots = F(28), Opp_3PointShotsAttempted = F(29), Opp_3PointShots_ = F(30),
                        Opp_FreeThrows = F(31), Opp_FreeThrowsAttempted = F(32), Opp_FreeThrows_ = F(33),
                        Opp_OffRebounds = F(34), Opp_TotalRebounds = F(35), Opp_Assists = F(36),
                        Opp_Steals = F(37), Opp_Blocks = F(38), Opp_Turnovers = F(39), Opp_TotalFouls = F(40)
                    });
                }
            }
            return matches;
        }

        static PredictionModel<Match, MatchPrediction> Train(IEnumerable<Match> trainData)
        {
            var pipeline = new LearningPipeline();
            pipeline.Add(CollectionDataSource.Create(trainData));
            pipeline.Add(new ColumnDropper() { Column = new[] { "Id" } });
            pipeline.Add(new CategoricalOneHotVectorizer("Team", "Game", "Date", "Home", "Opponent", "WINorLOSS"));
            pipeline.Add(new ColumnConcatenator("Features",
                "Team", "Game", "Date", "Home", "Opponent", "WINorLOSS", "OpponentPoints",
                "FieldGoals", "FieldGoalsAttempted", "FieldGoals_",
                "X3PointShots", "X3PointShotsAttempted", "X3PointShots_",
                "FreeThrows", "FreeThrowsAttempted", "FreeThrows_",
                "OffRebounds", "TotalRebounds", "Assists", "Steals", "Blocks", "Turnovers", "TotalFouls",
                "Opp_FieldGoals", "Opp_FieldGoalsAttempted", "Opp_FieldGoals_",
                "Opp_3PointShots", "Opp_3PointShotsAttempted", "Opp_3PointShots_",
                "Opp_FreeThrows", "Opp_FreeThrowsAttempted", "Opp_FreeThrows_",
                "Opp_OffRebounds", "Opp_TotalRebounds", "Opp_Assists", "Opp_Steals",
                "Opp_Blocks", "Opp_Turnovers", "Opp_TotalFouls"));
            pipeline.Add(new PoissonRegressor());
            return pipeline.Train<Match, MatchPrediction>();
        }

        static void Evaluate(PredictionModel<Match, MatchPrediction> model, IEnumerable<Match> evaluateData)
        {
            var evaluator = new RegressionEvaluator();
            var metric = evaluator.Evaluate(model, CollectionDataSource.Create(evaluateData));
            Console.WriteLine("LossFn: {0}", metric.LossFn);
            Console.WriteLine("RSquared: {0}", metric.RSquared);
            Console.WriteLine("Rms: {0}", metric.Rms);
        }

        static void Predict(PredictionModel<Match, MatchPrediction> model, IEnumerable<Match> predictData)
        {
            var predicts = model.Predict(predictData);
            var results = predictData.Zip(predicts, (d, p) => (Actual: d, Predicted: p));
            foreach (var result in results)
            {
                Console.WriteLine("Date: {0}, Team: {1}, Opponent: {2}, Score: {3}-{4}, Predicted Team Score: {5}",
                    result.Actual.Date, result.Actual.Team, result.Actual.Opponent,
                    result.Actual.TeamPoints, result.Actual.OpponentPoints, result.Predicted.TeamPoints);
            }
        }

        static void Main(string[] args)
        {
            var data = LoadData();
            // Split 7:2:1 in file order (no shuffling).
            var trainCount = Convert.ToInt32(data.Count * 0.7);
            var evaluateCount = Convert.ToInt32(data.Count * 0.2);
            var trainData = data.Take(trainCount);
            var evaluateData = data.Skip(trainCount).Take(evaluateCount);
            var predictData = data.Skip(trainCount + evaluateCount);

            var model = Train(trainData);
            Evaluate(model, evaluateData);
            Predict(model, predictData);
        }
    }

    public class Match
    {
        [Column(ordinal: "0")] public string Id;
        [Column(ordinal: "1")] public string Team;
        [Column(ordinal: "2")] public string Game;
        [Column(ordinal: "3")] public string Date;
        [Column(ordinal: "4")] public string Home;
        [Column(ordinal: "5")] public string Opponent;
        [Column(ordinal: "6")] public string WINorLOSS;
        [Column(ordinal: "7", name: "Label")] public float TeamPoints;
        [Column(ordinal: "8")] public float OpponentPoints;
        [Column(ordinal: "9")] public float FieldGoals;
        [Column(ordinal: "10")] public float FieldGoalsAttempted;
        [Column(ordinal: "11")] public float FieldGoals_;
        [Column(ordinal: "12")] public float X3PointShots;
        [Column(ordinal: "13")] public float X3PointShotsAttempted;
        [Column(ordinal: "14")] public float X3PointShots_;
        [Column(ordinal: "15")] public float FreeThrows;
        [Column(ordinal: "16")] public float FreeThrowsAttempted;
        [Column(ordinal: "17")] public float FreeThrows_;
        [Column(ordinal: "18")] public float OffRebounds;
        [Column(ordinal: "19")] public float TotalRebounds;
        [Column(ordinal: "20")] public float Assists;
        [Column(ordinal: "21")] public float Steals;
        [Column(ordinal: "22")] public float Blocks;
        [Column(ordinal: "23")] public float Turnovers;
        [Column(ordinal: "24")] public float TotalFouls;
        [Column(ordinal: "25")] public float Opp_FieldGoals;
        [Column(ordinal: "26")] public float Opp_FieldGoalsAttempted;
        [Column(ordinal: "27")] public float Opp_FieldGoals_;
        [Column(ordinal: "28")] public float Opp_3PointShots;
        [Column(ordinal: "29")] public float Opp_3PointShotsAttempted;
        [Column(ordinal: "30")] public float Opp_3PointShots_;
        [Column(ordinal: "31")] public float Opp_FreeThrows;
        [Column(ordinal: "32")] public float Opp_FreeThrowsAttempted;
        [Column(ordinal: "33")] public float Opp_FreeThrows_;
        [Column(ordinal: "34")] public float Opp_OffRebounds;
        [Column(ordinal: "35")] public float Opp_TotalRebounds;
        [Column(ordinal: "36")] public float Opp_Assists;
        [Column(ordinal: "37")] public float Opp_Steals;
        [Column(ordinal: "38")] public float Opp_Blocks;
        [Column(ordinal: "39")] public float Opp_Turnovers;
        [Column(ordinal: "40")] public float Opp_TotalFouls;
    }

    public class MatchPrediction
    {
        [ColumnName("Score")] public float TeamPoints;
    }
}
```