在咱们在MXnet中定义好symbol、写好dataiter而且准备好data以后,就能够开开心的去训练了。通常训练一个网络有两种经常使用的策略,基于model的和基于module的。今天,我想谈一谈他们的使用。html
1、Modelpython
按照老规矩,直接从官方文档里面拿出来的代码看一下:api
# configure a two layer neuralnetwork data = mx.symbol.Variable('data') fc1 = mx.symbol.FullyConnected(data, name='fc1', num_hidden=128) act1 = mx.symbol.Activation(fc1, name='relu1', act_type='relu') fc2 = mx.symbol.FullyConnected(act1, name='fc2', num_hidden=64) softmax = mx.symbol.SoftmaxOutput(fc2, name='sm') # create a model using sklearn-style two-step way #建立一个model model = mx.model.FeedForward( softmax, num_epoch=num_epoch, learning_rate=0.01) #开始训练 model.fit(X=data_set)
具体的API参照http://mxnet.io/api/python/model.html。网络
而后呢,model这部分就说完了。。。之因此这么快主要有两个缘由:框架
1.确实东西很少,通常都是查一查文档就能够了。ide
2.model的可定制性不强,通常咱们是不多使用的,经常使用的仍是module。函数
2、Module设计
Module真的是一个很棒的东西,虽然深刻了解后,你会以为“哇,好厉害,可是感受没什么鸟用呢”这种想法。。实际上我就有过,如今回想起来,从代码的设计和使用的角度来说,Module确实是一个很是好的东西,它能够为咱们的网络计算提升了中级、高级的接口,这样一来,就能够有不少的个性化配置让咱们本身来作了。htm
Module有四种状态:blog
1.初始化状态,就是显存尚未被分配,基本上啥都没作的状态。
2.binded,在把data和label的shape传到Bind函数里而且执行以后,显存就分配好了,能够准备好计算能力。
3.参数初始化。就是初始化参数
3.Optimizer installed 。就是传入SGD,Adam这种optimuzer中去进行训练
先上一个简单的代码:
import mxnet as mx # construct a simple MLP data = mx.symbol.Variable('data') fc1 = mx.symbol.FullyConnected(data, name='fc1', num_hidden=128) act1 = mx.symbol.Activation(fc1, name='relu1', act_type="relu") fc2 = mx.symbol.FullyConnected(act1, name = 'fc2', num_hidden = 64) act2 = mx.symbol.Activation(fc2, name='relu2', act_type="relu") fc3 = mx.symbol.FullyConnected(act2, name='fc3', num_hidden=10) out = mx.symbol.SoftmaxOutput(fc3, name = 'softmax') # construct the module mod = mx.mod.Module(out) mod.bind(data_shapes=train_dataiter.provide_data, label_shapes=train_dataiter.provide_label) mod.init_params() mod.fit(train_dataiter, eval_data=eval_dataiter, optimizer_params={'learning_rate':0.01, 'momentum': 0.9}, num_epoch=n_epoch)
分析一下:首先是定义了一个简单的MLP,symbol的名字就叫作out,而后能够直接用mx.mod.Module来建立一个mod。以后mod.bind的操做是在显卡上分配所需的显存,因此咱们须要把data_shapehe label_shape传递给他,而后初始化网络的参数,再而后就是mod.fit开始训练了。这里补充一下。fit这个函数咱们已经看见两次了,实际上它是一个集成的功能,mod.fit()实际上它内部的核心代码是这样的:
for epoch in range(begin_epoch, num_epoch): tic = time.time() eval_metric.reset() for nbatch, data_batch in enumerate(train_data): if monitor is not None: monitor.tic() self.forward_backward(data_batch) #网络进行一次前向传播和后向传播 self.update() #更新参数 self.update_metric(eval_metric, data_batch.label) #更新metric if monitor is not None: monitor.toc_print() if batch_end_callback is not None: batch_end_params = BatchEndParam(epoch=epoch, nbatch=nbatch, eval_metric=eval_metric, locals=locals()) for callback in _as_list(batch_end_callback): callback(batch_end_params)
正是由于module里面咱们能够使用不少intermediate的interface,因此能够作出不少改进,举个最简单的例子:若是咱们的训练网络是大小可变怎么办? 咱们能够实现一个mutumodule,基本上就是,每次data的shape变了的时候,咱们就从新bind一下symbol,这样训练就能够照常进行了。
总结:实际上学一个框架的关键仍是使用它,要说诀窍的话也就是多看看源码和文档了,我写这些博客的目的,一是为了记录一些东西,二是让后来者少走一些弯路。因此有些东西不会说的很全。。