C# WPF调用python-tensorflow2深度学习模型
python在研究深度学习人工智能领域十分强大,但在工业项目开发中仍经常使用C#和C++来作软件,C++有Caffe深度学习框架,但C#尚且没有成熟的深度学习框架(有个tensroflow.net尚在开发中,有兴趣能够去研究研究)。如今实验室项目开发又要用C#,通过实践最终决定在C#端利用OpencvSharp4的DNN模块加载python端tensorflow2训练的模型进行预测,其速度还能够。
一 环境介绍
python:Python3.7 tensorflow2.1前端
c#: vs2017 .net framework 4.6.1python
二 tensorflow模型的训练和生成
1 加载数据训练模型
1.1 数据集采用猫狗二分类数据。
数据集网盘连接:连接:https://pan.baidu.com/s/15LR7-tgvglzwW9n4eFsFgg
提取码:iz64
1.2 建立图片数据输入管道
代码实现:
express
import tensorflow as tf import numpy as np import matplotlib.pyplot as plt import glob gpu = tf.config.experimental.list_physical_devices(device_type='GPU') tf.config.experimental.set_virtual_device_configuration( gpu[0], [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=5120)]) image_path = glob.glob('./datasets/dc/train/*.jpg') image_label = [int(path.split('\\')[1].split('.')[0]=='cat') for path in image_path] def get_image_data(path, label): image = tf.io.read_file(path) image = tf.image.decode_jpeg(image, channels=3) image = tf.image.resize(image, (224, 224)) image = tf.cast(image, tf.float32)/255 label = tf.reshape(label, [1]) return image, label dataset = tf.data.Dataset.from_tensor_slices((image_path, image_label)) dataset = dataset.map(get_image_data) train_count = int(len(image_path)*0.8) test_count = len(image_path)-train_count train_dataset = dataset.skip(test_count) test_dataset = dataset.take(test_count) train_dataset = train_dataset.shuffle(len(image_path)).repeat().batch(BATCH_SIZE) test_dataset = test_dataset.batch(BATCH_SIZE)
1.3 搭建并训练模型最后保存模型及参数,保存的格式的.h5文件,最终的准确度基本在99%以上。
代码实现:
c#
MobileNet = tf.keras.applications.MobileNetV2(include_top=False, weights='imagenet', input_shape=(224, 224, 3)) model = tf.keras.Sequential() model.add(MobileNet) model.add(tf.keras.layers.GlobalAveragePooling2D()) model.add(tf.keras.layers.Dense(256, activation='relu')) model.add(tf.keras.layers.Dense(1, activation='sigmoid')) model.compile(optimizer='adam', loss=tf.keras.losses.binary_crossentropy, metrics=['acc']) model.fit(train_dataset, epochs=10, steps_per_epoch=train_count//BATCH_SIZE, validation_data=test_dataset, validation_steps=test_count//BATCH_SIZE) model.save('./model_h5/mobilenet.h5')
2 h5文件转pb
Opencv的DNN模块接收tensorflow模型文件为pb文件,先将h5文件转换成pb文件,在tensorflow2.0端完成文件类型的转换。
转换代码:
后端
#参数1为h5文件的路径,参数2为要将pb文件保存到那个文件夹的路径,最后一个参数为pb文件的名称 def convert_h5to_pb(h5_path, pb_path, pb_name): model = tf.keras.models.load_model(h5_path, compile=False) model.summary() full_model = tf.function(lambda Input: model(Input)) full_model = full_model.get_concrete_function(tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype)) frozen_func = convert_variables_to_constants_v2(full_model) frozen_func.graph.as_graph_def() layers = [op.name for op in frozen_func.graph.get_operations()] print("-" * 50) print("Frozen model layers: ") for layer in layers: print(layer) print("-" * 50) print("Frozen model inputs: ") print(frozen_func.inputs) print("Frozen model outputs: ") print(frozen_func.outputs) tf.io.write_graph(graph_or_graph_def=frozen_func.graph, logdir=pb_path, name=pb_name, as_text=False
二 C#加载模型并预测
1 vs2017环境搭建
在项目属性中设置平台目标为x64,
目标框架选择.net framework 4.6.1,没有该框架的可去官网下载安装。
进入NuGet程序包管理界面,搜索并下载以下三个包,有可能因为网络问题没法下载,可根据提示网站进入下载。
网络
2 调用模型
1 xml前端界面app
<Window xmlns="http://schemas.microsoft.com/winfx/2006/xaml/presentation" xmlns:x="http://schemas.microsoft.com/winfx/2006/xaml" xmlns:d="http://schemas.microsoft.com/expression/blend/2008" xmlns:mc="http://schemas.openxmlformats.org/markup-compatibility/2006" xmlns:telerik="http://schemas.telerik.com/2008/xaml/presentation" x:Class="WpfApp1.MainWindow" mc:Ignorable="d" Title="猫狗分类" Height="300" Width="500" WindowStartupLocation="CenterScreen"> <Grid> <Grid.ColumnDefinitions> <ColumnDefinition Width="200"/> <ColumnDefinition/> </Grid.ColumnDefinitions> <Grid Grid.Column="0"> <Grid.RowDefinitions> <RowDefinition Height="1*"/> <RowDefinition Height="3*"/> </Grid.RowDefinitions> <telerik:RadButton x:Name="read_image" Content="读取图片" Click="Read_image_Click" Margin="70,15,50,15"/> <Grid Grid.Row="1"> <Grid.ColumnDefinitions> <ColumnDefinition Width="70"/> <ColumnDefinition/> </Grid.ColumnDefinitions> <Grid.RowDefinitions> <RowDefinition/> <RowDefinition/> <RowDefinition/> <RowDefinition Height="20"/> </Grid.RowDefinitions> <Label Content="得分:" HorizontalAlignment="Center" VerticalAlignment="Center"/> <TextBox x:Name="score" HorizontalAlignment="Left" VerticalAlignment="Center" Grid.Row="0" Grid.Column="1" Width="120"/> <Label Content="类别:" HorizontalAlignment="Center" VerticalAlignment="Center" Grid.Row="1"/> <TextBox x:Name="classes" HorizontalAlignment="Left" VerticalAlignment="Center" Grid.Row="1" Grid.Column="1" Width="120"/> <Label Content="时间:" HorizontalAlignment="Center" VerticalAlignment="Center" Grid.Row="2"/> <TextBox x:Name="time" HorizontalAlignment="Left" VerticalAlignment="Center" Grid.Row="2" Grid.Column="1" Width="120"/> </Grid> </Grid> <Border BorderBrush="Black" BorderThickness="1" Grid.Column="1" HorizontalAlignment="Center" Height="214" VerticalAlignment="Center" Width="265"> <Image x:Name="img"/> </Border> </Grid> </Window>
2 C#后端实现预测框架
//引入OpencvSharp和Dnn模块 using System; using System.Windows; using System.Windows.Media.Imaging; using OpenCvSharp.Dnn; using OpenCvSharp; using Microsoft.Win32; namespace WpfApp1 { /// <summary> /// MainWindow.xaml 的交互逻辑 /// </summary> public partial class MainWindow : System.Windows.Window { public MainWindow() { InitializeComponent(); } public void Dnn_Classification(Mat image) { String model_path = ".//mobilenet.pb";//模型路径 Net net = CvDnn.ReadNetFromTensorflow(model_path);//加载模型 if (net.Empty()) { MessageBox.Show("pd文件错误"); return; } Mat input_image = CvDnn.BlobFromImage(image, 1 / 255.0, new OpenCvSharp.Size(224, 224)); //图片归一化和resize net.SetInput(input_image); Mat result = net.Forward();//载入图片并前向计算 float result_score = result.Get<float>(0, 0);//得到计算结果 score.Text = result_score.ToString(); if (result_score >= 0.5) { classes.Text = "Cat"; } else { classes.Text = "Dog"; } } private void Read_image_Click(object sender, RoutedEventArgs e) { OpenFileDialog ofd = new OpenFileDialog(); ofd.InitialDirectory = @"C:\Users\LemonQiu\Desktop"; ofd.Filter = "JPG图片|*.jpg|PNG图片|*.png"; if (ofd.ShowDialog() == true) { img.Source = new BitmapImage(new Uri(ofd.FileName)); Mat image = Cv2.ImRead(ofd.FileName); System.Diagnostics.Stopwatch watch = new System.Diagnostics.Stopwatch(); watch.Start(); Dnn_Classification(image); watch.Stop(); TimeSpan timespan = watch.Elapsed; time.Text = (timespan.TotalMilliseconds).ToString() + "ms"; } else { MessageBox.Show("没有选择图片"); } } } }
三 最终效果
检测时间基本稳定在100ms每张。
学习