Attention Cluster 模型html
视频分类问题在视频标签、监控、自动驾驶等领域有着普遍的应用,但它同时也是计算机视觉领域面临的一项重要挑战之一。网络
目前的视频分类问题大可能是基于 CNN 或者 RNN 网络实现的。众所周知,CNN 在图像领域已经发挥了重大做用。它具备很好的特征提取能力,经过卷积层和池化层,能够在图像的不一样区域提取特征。RNN 则在获取时间相关的特征方面有很强的能力。学习
Attention Cluster 在设计上仅利用了 CNN 模型,而没有使用 RNN,主要是基于视频的如下几个特色考虑:测试
图 1 视频帧的分析优化
首先,一段视频的连续帧经常有必定的类似性。在图 1(上)能够看到,除了击球的动做之外,不一样帧几乎是同样的。所以,对于分类,可能从总体上关注这些类似的特征就足够了,而没有必要去特地观察它们随着时间的细节变化。ui
其次,视频帧中的局部特征有时就足够表达出视频的类别。好比图 1(中),经过一些局部特征,如牙刷、水池,就可以分辨出『刷牙』这个动做。所以,对于分类问题,关键在于找到帧中的关键的局部特征,而非去找时间上的线索。google
最后,在一些视频的分类中,帧的时间顺序对于分类不必定是重要的。好比图 1(下),能够看到,虽然帧顺序被打乱,依然可以看出这属于『撑杆跳』这个类别。设计
基于以上考虑,该模型没有考虑时间相关的线索,而是使用了 Attention 机制。它有如下几点好处:3d
1. Attention 的输出本质上是加权平均,这能够避免一些重复特征形成的冗余。orm
2. 对于一些局部的关键特征,Attention 可以赋予其更高的权重。这样就可以经过这些关键的特征,提升分类能力。
3. Attention 的输入是任意大小的无序集合。无序这点知足咱们上面的观察,而任意大小的输入又可以提升模型的泛化能力。
固然,一些视频的局部特征还有一个特色,那就是它可能会由多个部分组成。好比图 1(下)的『撑杆跳』,跳、跑和着陆同时对这个分类起到做用。所以,若是只用单一的 Attention 单元,只能获取视频的单一关键信息。而若是使用多个 Attention 单元,就可以提取更多的有用信息。因而,Attention Cluster 就应运而生了!在实现过程当中,百度计算机视觉团队还发现,将不一样的 Attention 单元进行一次简单有效的『位移操做』(shifting operation),能够增长不一样单元的多样性,从而提升准确率。
接下来咱们看一下整个 Attention Cluster 的结构。
整个模型能够分为三个部分:
1. 局部特征提取。经过 CNN 模型抽取视频的特征。提取后的特征用 X 表示,如公式(1)所示:
(1)。X 的维度为 L,表明 L 个不一样的特征。
2. 局部特征集成。基于 Attention 来获取全局特征。Attention 的输出本质上至关于作了加权平均。如公式(2)所示,v 是一个 Attention 单元输出的全局特征,a 是权重向量,由两层全链接层组成,如公式(3)所示。实际实现中,v 的产生使用了 Shifting operation,如公式(4)所示,其中α和β是可学习的标量。它经过对每个 Attention 单元的输出添加一个独立可学习的线性变换处理后进行 L2-normalization,使得各 Attention 单元倾向于学习特征的不一样成分,从而让 Attention Cluster 能更好地学习不一样分布的数据,提升整个网络的学习表征能力。因为采用了 Attention clusters,这里会将各个 Attention 单元的输出组合起来,获得多个全局特征 g,如公式(5)所示。N 表明的是 clusters 的数量。
3. 全局特征分类。将多个全局特征拼接之后,再经过常规的全链接层和 Softmax 或 Sigmoid 进行最后的单标签或多标签分类。
用 PaddlePaddle 训练 Attention Cluster
PaddlePaddle 开源的 Attention Cluster 模型,使用了 2nd-Youtube-8M 数据集。该数据集已经使用了在 ImageNet 训练集上 InceptionV3 模型对特征进行了抽取。
若是运行该模型的样例代码,要求使用 PaddlePaddle Fluid V1.2.0 或以上的版本。
数据准备:首先请使用 Youtube-8M 官方提供的连接下载训练集和测试集,或者使用官方脚本下载。数据下载完成后,将会获得 3844 个训练数据文件和 3844 个验证数据文件(TFRecord 格式)。为了适用于 PaddlePaddle 训练,须要将下载好的 TFRecord 文件格式转成了 pickle 格式,转换脚本请使用 PaddlePaddle 提供的脚本 dataset/youtube8m/tf2pkl.py。
训练集:http://us.data.yt8m.org/2/frame/train/index.html
测试集:http://us.data.yt8m.org/2/frame/validate/index.html
官方脚本:https://research.google.com/youtube8m/download.html
模型训练:数据准备完毕后,经过如下方式启动训练(方法 1),同时咱们也提供快速启动脚本 (方法 2)
# 方法 1
# 方法 2
用户也可下载 Paddle Github 上已发布模型经过--resume 指定权重存放路径进行 finetune 等开发。
数据预处理说明: 模型读取 Youtube-8M 数据集中已抽取好的 rgb 和 audio 数据,对于每一个视频的数据,均匀采样 100 帧,该值由配置文件中的 seg_num 参数指定。
模型设置: 模型主要可配置参数为 cluster_nums 和 seg_num 参数。其中 cluster_nums 是 attention 单元的数量。当配置 cluster_nums 为 32, seg_num 为 100 时,在 Nvidia Tesla P40 上单卡可跑 batch_size=256。
训练策略:
采用 Adam 优化器,初始 learning_rate=0.001
训练过程当中不使用权重衰减
参数主要使用 MSRA 初始化
模型评估:可经过如下方式(方法 1)进行模型评估,一样咱们也提供了快速启动的脚本(方法 2):
# 方法 1
# 方法 2
使用 scripts/test/test_attention_cluster.sh 进行评估时,须要修改脚本中的--weights 参数指定须要评估的权重。
若未指定--weights 参数,脚本会下载已发布模型进行评估
模型推断:可经过以下命令进行模型推断:
模型推断结果存储于 AttentionCluster_infer_result 中,经过 pickle 格式存储。
若未指定--weights 参数,脚本会下载已发布模型 model 进行推断
模型精度:当模型取以下参数时,在 Youtube-8M 数据集上的指标为:
参数取值:
评估精度: