机器学习框架ML.NET学习笔记【4】多元分类之手写数字识别

1、问题与解决方案html

经过多元分类算法进行手写数字识别,手写数字的图片分辨率为8*8的灰度图片、已经预先进行过处理,读取了各像素点的灰度值,并进行了标记。git

其中第0列是序号(不参与运算)、1-64列是像素值、65列是结果。github

咱们以64位像素值为特征进行多元分类,算法采用SDCA最大熵分类算法。算法

 

2、源码app

 先贴出所有代码:框架

namespace MulticlassClassification_Mnist
{
    class Program
    {
        static readonly string TrainDataPath = Path.Combine(Environment.CurrentDirectory, "Data", "optdigits-full.csv");
        static readonly string ModelPath = Path.Combine(Environment.CurrentDirectory, "Data", "SDCA-Model.zip");

        static void Main(string[] args)
        {
            MLContext mlContext = new MLContext(seed: 1);
          
            TrainAndSaveModel(mlContext);
            TestSomePredictions(mlContext);

            Console.WriteLine("Hit any key to finish the app");
            Console.ReadKey();
        }
              

        public static void TrainAndSaveModel(MLContext mlContext)
        {
            // STEP 1: 准备数据
            var fulldata = mlContext.Data.LoadFromTextFile(path: TrainDataPath,
                    columns: new[]
                    {
                        new TextLoader.Column("Serial", DataKind.Single, 0),
                        new TextLoader.Column("PixelValues", DataKind.Single, 1, 64),
                        new TextLoader.Column("Number", DataKind.Single, 65)
                    },
                    hasHeader: true,
                    separatorChar: ','
                    );

            var trainTestData = mlContext.Data.TrainTestSplit(fulldata, testFraction: 0.2);
            var trainData = trainTestData.TrainSet;
            var testData = trainTestData.TestSet;

            // STEP 2: 配置数据处理管道        
            var dataProcessPipeline = mlContext.Transforms.Conversion.MapValueToKey("Label", "Number", keyOrdinality: ValueToKeyMappingEstimator.KeyOrdinality.ByValue);

            // STEP 3: 配置训练算法
            var trainer = mlContext.MulticlassClassification.Trainers.SdcaMaximumEntropy(labelColumnName: "Label", featureColumnName: "PixelValues");
            var trainingPipeline = dataProcessPipeline.Append(trainer)
              .Append(mlContext.Transforms.Conversion.MapKeyToValue("Number", "Label"));
            
            // STEP 4: 训练模型使其与数据集拟合
            Console.WriteLine("=============== Train the model fitting to the DataSet ===============");           

            ITransformer trainedModel = trainingPipeline.Fit(trainData);         


            // STEP 5:评估模型的准确性
            Console.WriteLine("===== Evaluating Model's accuracy with Test data =====");
            var predictions = trainedModel.Transform(testData);
            var metrics = mlContext.MulticlassClassification.Evaluate(data: predictions, labelColumnName: "Number", scoreColumnName: "Score");
            PrintMultiClassClassificationMetrics(trainer.ToString(), metrics);
         
            // STEP 6:保存模型              
            mlContext.ComponentCatalog.RegisterAssembly(typeof(DebugConversion).Assembly);
            mlContext.Model.Save(trainedModel, trainData.Schema, ModelPath);
            Console.WriteLine("The model is saved to {0}", ModelPath);
        }

        private static void TestSomePredictions(MLContext mlContext)
        {
            // Load Model           
            ITransformer trainedModel = mlContext.Model.Load(ModelPath, out var modelInputSchema);

            // Create prediction engine 
            var predEngine = mlContext.Model.CreatePredictionEngine<InputData, OutPutData>(trainedModel);

            //num 1
            InputData MNIST1 = new InputData()
            {               
                PixelValues = new float[] { 0, 0, 0, 0, 14, 13, 1, 0, 0, 0, 0, 5, 16, 16, 2, 0, 0, 0, 0, 14, 16, 12, 0, 0, 0, 1, 10, 16, 16, 12, 0, 0, 0, 3, 12, 14, 16, 9, 0, 0, 0, 0, 0, 5, 16, 15, 0, 0, 0, 0, 0, 4, 16, 14, 0, 0, 0, 0, 0, 1, 13, 16, 1, 0 }
            }; 
            var resultprediction1 = predEngine.Predict(MNIST1);
            resultprediction1.PrintToConsole();           
        }      
    }

    class InputData
    {
        public float Serial;
        [VectorType(64)]
        public float[] PixelValues;               
        public float Number;       
    }

    class OutPutData : InputData
    {  
        public float[] Score;  
    }   
}
View Code

  

3、分析机器学习

 总体流程和二元分类没有什么区别,下面解释一下有差别的两个地方。ide

 一、加载数据学习

      // STEP 1: 准备数据
            var fulldata = mlContext.Data.LoadFromTextFile(path: TrainDataPath,
                    columns: new[]
                    {
                        new TextLoader.Column("Serial", DataKind.Single, 0),
                        new TextLoader.Column("PixelValues", DataKind.Single, 1, 64),
                        new TextLoader.Column("Number", DataKind.Single, 65)
                    },
                    hasHeader: true,
                    separatorChar: ','
                    );

  此次咱们不是经过实体对象来加载数据,而是经过列信息来进行加载,其中PixelValues是特征值,Number是标签值。测试

 

二、训练通道

            // STEP 2: 配置数据处理管道        
            var dataProcessPipeline = mlContext.Transforms.Conversion.MapValueToKey("Label", "Number", keyOrdinality: ValueToKeyMappingEstimator.KeyOrdinality.ByValue)

// STEP 3: 配置训练算法 var trainer = mlContext.MulticlassClassification.Trainers.SdcaMaximumEntropy(labelColumnName: "Label", featureColumnName: "PixelValues");
var trainingPipeline = dataProcessPipeline.Append(trainer)
.Append(mlContext.Transforms.Conversion.MapKeyToValue(
"Number", "Label"));

// STEP 4: 训练模型使其与数据集拟合
ITransformer trainedModel
= trainingPipeline.Fit(trainData);

 首先经过MapValueToKey方法将Number值转换为Key类型,多元分类算法要求标签值必须是这种类型(相似枚举类型,二元分类要求标签为BOOL类型)。关于这个转换的缘由及编码方式,下面详细介绍。

 

4、键值类型编码与独热编码

 MapValueToKey功能是将(字符串)值类型转换为KeyTpye类型。

有时候某些输入字段用来表示类型(类别特征),但自己并无特别的含义,好比编号、电话号码、行政区域名称或编码等,这里须要把这些类型转换为1到一个整数如1-300来进行从新编号。

举个简单的例子,咱们进行图片识别的时候,目标结果多是“猫咪”、“小狗”、“人物”这些分类,须要把这些分类转换为一、二、3这样的整数。但本文的标签值自己就是一、二、3,为何还要转换呢?由于咱们这里的一二三其实不是数学意义上的数字,而是一种标志,能够理解为壹、贰、叁,因此要进行编码。

 MapKeyToValue和MapValueToKey相反,它把将键类型转换回其原始值(字符串)。就是说标签是文本格式,在运算前已经被转换为数字枚举类型了,此时预测结果为数字,经过MapKeyToValue将其结果转换为对应文本。

MapValueToKey通常是对标签值进行编码,通常不用于特征值,若是是特征值为字符串类型的,建议采用独热编码。独热编码即 One-Hot 编码,又称一位有效编码,其方法是使用N位状态寄存器来对N个状态进行编码,每一个状态都由他独立的寄存器位,而且在任意时候,其中只有一位有效。例如:

天然状态码为:0,1,2,3,4,5
独热编码为:000001,000010,000100,001000,010000,100000

怎么理解这个事情呢?举个例子,假如咱们要进行人的身材的分析,但咱们但愿加入地域特征,好比:“黑龙江”、“山东”、“湖南”、“广东”这种特征,但这种字符串机器学习是不认识的,必须转换为浮点数,刚才提到MapKeyToValue能够把字符串转换为数字,为何这里要采用独热编码呢?简单来讲,假设把地域名称转换为1到10几个数字,在欧氏几何中1到3的欧拉距离和1到9的欧拉距离是不等的,但通过独热编码后,任意两点间的欧拉距离都是相等的,而咱们这里的地域特征仅仅是想表达分类关系,彼此之间没有其余逻辑关系,因此应该采用独热编码。

 

5、进度调试

通常机器算法的数据拟合过程时间都比较长,有时程序跑了两个小时还没结束,也不知道还须要多长时间,着实让人着急,因此及时了解学习进度,是颇有必要的。

因为机器学习算法通常都有“递归直到收敛”这种操做,因此咱们是没有办法预先知道最终运算次数的,能作到的只能打印一些过程信息,看到程序在动,内心也有点底,当系统跑过一次以后,基本就大体知道须要多少次拟合了,后面再调试就能够大体了解进度了。补充一句,可不能够在测试阶段先减小样本数据进行快速调试,调试经过后再切换到全样本进行训练?其实不行,有时候样本数量小,可能会引发指标震荡,时间反而长了。

以前在Githube上看到有人经过MLContext.LOG事件来打印调试信息,我试了一下,发现无法控制筛选内容,不太方便,后来想到一个方法,就是新增一个自定义数据处理通道,这个通道不作具体事情,就打印调试信息。

类定义:

namespace MulticlassClassification_Mnist
{
    public class DebugConversionInput
    {
        public float Serial { get; set; }
    }
 
    public class DebugConversionOutput
    {
        public float DebugFeature { get; set; }
    }

    [CustomMappingFactoryAttribute("DebugConversionAction")]
    public class DebugConversion : CustomMappingFactory<DebugConversionInput, DebugConversionOutput>
    {       static long TotalCount = 0;

        public void CustomAction(DebugConversionInput input, DebugConversionOutput output)
        {
            output.DebugFeature = 1.0f;  
TotalCount++;
Console.WriteLine($"DebugConversion.CustomAction's debug info.TotalCount={TotalCount} "); } public override Action<DebugConversionInput, DebugConversionOutput> GetMapping() => CustomAction; } }

 使用方法:

 var dataProcessPipeline = mlContext.Transforms.CustomMapping(new DebugConversion().GetMapping(), contractName: "DebugConversionAction")
       .Append(...)
       .Append(mlContext.Transforms.Concatenate("Features", new string[] { "RealFeatures", "DebugFeature" }));

 经过CustomMapping加载咱们自定义的数据处理通道,因为数据集是懒加载(Lazy)的,因此必须把咱们自定义数据处理通道的输出加入为特征值,才能参与运算,而后算法在操做每一条数据时都会调用到CustomAction方法,这样就能够打印进度信息了。为了避免影响运算结果,咱们把这个数据处理通道的输出值固定为1.0f 。

 

6、资源获取

源码下载地址:https://github.com/seabluescn/Study_ML.NET

工程名称:MulticlassClassification_Mnist

点击查看机器学习框架ML.NET学习笔记系列文章目录

相关文章
相关标签/搜索