跳到主要内容

迁移学习与特征提取

什么是迁移学习

在传统的机器学习的框架下,学习的任务就是在给定充分训练数据的基础上来学习一个分类模型;然后利用这个学习到的模型来对测试文档进行分类与预测。然而,我们看到机器学习算法在当前的Web挖掘研究中存在着一个关键的问题:一些新出现的领域中的大量训练数据非常难得到。我们看到Web应用领域的发展非常快速。大量新的领域不断涌现,从传统的新闻,到网页,到图片,再到博客、播客等等。传统的机器学习需要对每个领域都标定大量训练数据,这将会耗费大量的人力与物力。而没有大量的标注数据,会使得很多与学习相关研究与应用无法开展。其次,传统的机器学习假设训练数据与测试数据服从相同的数据分布。然而,在许多情况下,这种同分布假设并不满足。通常可能发生的情况如训练数据过期。这往往需要我们去重新标注大量的训练数据以满足我们训练的需要,但标注新数据是非常昂贵的,需要大量的人力与物力。从另外一个角度上看,如果我们有了大量的、在不同分布下的训练数据,完全丢弃这些数据也是非常浪费的。如何合理的利用这些数据就是迁移学习主要解决的问题。迁移学习可以从现有的数据中迁移知识,用来帮助将来的学习。迁移学习(Transfer Learning)的目标是将从一个环境中学到的知识用来帮助新环境中的学习任务。因此,迁移学习不会像传统机器学习那样作同分布假设。
迁移学习(Transfer learning) 顾名思义就是把已训练好的模型参数迁移到新的模型来帮助新模型训练。考虑到大部分数据或任务都是存在相关性的,所以通过迁移学习我们可以将已经学到的模型参数(也可理解为模型学到的知识)通过某种方式来分享给新模型从而加快并优化模型的学习效率不用像大多数网络那样从零学习。
模型的训练与预测: 
深度学习的模型可以划分为 训练 和 预测 两个阶段。  
训练 分为两种策略:一种是白手起家从头搭建模型进行训练,一种是通过预训练模型进行训练。  
预测 相对简单,直接用已经训练好的模型对数据集进行预测即可。  为什么要迁移学习?

  1. 站在巨人的肩膀上:前人花很大精力训练出来的模型在大概率上会比你自己从零开始搭的模型要强悍,没有必要重复造轮子。  
  2. 训练成本可以很低:如果采用导出特征向量的方法进行迁移学习,后期的训练成本非常低,用CPU都完全无压力,没有深度学习机器也可以做。  
  3. 适用于小数据集:对于数据集本身很小(几千张图片)的情况,从头开始训练具有几千万参数的大型神经网络是不现实的,因为越大的模型对数据量的要求越大,过拟合无法避免。这时候如果还想用上大型神经网络的超强特征提取能力,只能靠迁移学习。

迁移学习有几种方式

  1. Transfer Learning:冻结预训练模型的全部卷积层,只训练自己定制的全连接层。  
  2. Extract Feature Vector:先计算出预训练模型的卷积层对所有训练和测试数据的特征向量,然后抛开预训练模型,只训练自己定制的简配版全连接网络。  
  3. Fine-tune:冻结预训练模型的部分卷积层(通常是靠近输入的多数卷积层),训练剩下的卷积层(通常是靠近输出的部分卷积层)和全连接层。 

案例

要解决的问题

依然使用Cat vs Dog数据集,数据集详情请参考https://www.yuque.com/suanpan_doc/public/dx8z5f。

解决方案

在算盘的项目模板中,已经创建了一个简单的流程可以参考。
在项目模板中,双击keras教材案例中的迁移学习与特征提取: image.png创建出以下模板,模板中的组件可以在深度学习的pytorch的分类中找到:image.png

模板

可以看出整个模板分了五块:

  • 模型建立——需要有一个输入层,注意填写接收图片的大小与通道数,后面连接VGG16做特征提取,经过数据扁平层后,使用全连接层将模型输出为两个类别(猫狗)等。注意预训练模型需要被设置为特征提取器。
  • 数据载入——可以使用已经制作好的数据集组件Cats-vs-Dogs数据集组件,组件会直接将数据分为训练集,验证集和测试集,后面可以直接连接图片文件夹数据转换组件。
  • 模型训练——数据载入设置Batch Size,选择RMSprop作为参数寻优器,模型训练节点设置epoch数与Loss Function,最终生成训练好的模型文件
  • 预测评估——预测节点接收测试集数据与训练好的模型进行训练,生成预测图片文件夹与CSV文件(包含每个图片的index,真实值与预测值),CSV文件可以作为多分类评估的输入进行模型评估
  • 图片可视化——由于每张图片的大小不等,需要先将图片比例缩放之后再进行剪裁,案例中将图片变换为1501503的数据格式。

这样就实现了迁移学习的神经网络模型,你可以对其中的网络结构进行重新编辑,验证你的算法!