什么是迁移学习?迁移学习的策略、步骤、区别和概念

发布:2022-12-06 15:05:35
阅读:5383
作者:网络整理
分享:复制链接

迁移学习帮助我们从老的机器学习任务中获得可以被复用的训练模型,只需要少量数据就能重新应用于新的训练任务。如今,自然语言处理和图像识别等领域被认为是迁移学习研究的热点领域。本文就来详细了解什么是迁移学习。

经典迁移学习策略

根据任务的领域、数据的可用性,应用不同的迁移学习策略和技术。

1、归纳迁移学习

归纳迁移学习要求源域和目标域相同,尽管模型处理的具体任务不同。这些算法尝试使用来自源模型的知识并将其应用于改进目标任务。预训练模型已经具有领域特征方面的专业知识,并且比从头开始训练会处于更好的起点。

根据源域是否包含标记数据,归纳迁移学习进一步分为两个子类。这些分别包括多任务学习和自学学习。

2、转导迁移学习

源任务和目标任务的领域不完全相同但相互关联的场景可使用转导迁移学习策略。人们可以得出源任务和目标任务之间的相似性。这些场景通常在源域中有大量标记数据,而目标域中只有未标记数据。

3、无监督迁移学习

无监督迁移学习类似于归纳迁移学习。唯一的区别是算法侧重于无监督任务,并且在源任务和目标任务中都涉及未标记的数据集。

4、基于域的相似性并独立于训练的数据样本类型的策略

  • 同构迁移学习

开发并提出了同构迁移学习方法来处理域具有相同特征空间的情况。在同构迁移学习中,域在边际分布上只有微小的差异。这些方法通过纠正样本选择偏差或协变量偏移来调整域。

  • 异构迁移学习

异构迁移学习方法旨在解决具有不同特征空间的源域和目标域的问题以及不同数据分布和标签空间等其他问题。异构迁移学习应用于跨领域任务,例如跨语言文本分类、文本到图像分类等。

迁移学习6个步骤

1.获取预训练模型

第一步是根据任务选择我们希望保留的预训练模型作为我们训练的基础。迁移学习需要预训练源模型的知识与目标任务域之间的强相关性才能兼容。

2.创建基础模型

基础模型是在第一步中选择与任务密切相关的架构,可能存在这样一种情况,基础模型在最终输出层中的神经元数量超过用例中所需的数量。在这种情况下,需要移除最终输出层并进行相应更改。

3.冻结起始层

冻结预训练模型的起始层对于避免使模型学习基本特征的至关重要。如果不冻结初始层,将失去所有已经发生的学习。这与从头开始训练模型没有什么不同,会导致浪费时间、资源等。

4.添加新的可训练层

从基础模型中重用的唯一知识是特征提取层。需要在特征提取层之上添加额外的层来预测模型的特殊任务。这些通常是最终的输出层。

5.训练新层

预训练模型的最终输出很可能与我们想要的模型输出不同,在这种情况下,必须使用新的输出层来训练模型。

6.微调模型

为了提高模型的性能。微调涉及解冻基础模型的某些部分,并以非常低的学习率在整个数据集上再次训练整个模型。低学习率将提高模型在新数据集上的性能,同时防止过度拟合。

传统机器学习与迁移学习的区别

1.传统机器学习模型需要从头开始训练,计算量大,需要大量数据才能达到高性能。另一方面,迁移学习计算效率高,有助于使用小数据集获得更好的结果。

2.传统机器学习采用孤立的训练方法,每个模型都针对特定目的进行独立训练,不依赖于过去的知识。与此相反,迁移学习使用从预训练模型中获取的知识来处理任务。

3.迁移学习模型比传统的ML模型更快地达到最佳性能。这是因为利用来自先前训练的模型的知识(特征、权重等)的模型已经理解了这些特征。它比从头开始训练神经网络更快。

深度迁移学习的概念

许多模型预训练的神经网络和模型构成了深度学习背景下迁移学习的基础,这被称为深度迁移学习。

要了解深度学习模型的流程,必须了解它们的组成部分。深度学习系统是分层架构,可以在不同层学习不同的特征。初始层编译更高级别的功能,随着我们深入网络,这些功能会缩小到细粒度的功能。

这些层最终连接到最后一层以获得最终输出。这打开了使用流行的预训练网络的限制,无需将其最后一层作为其他任务的固定特征提取器。其关键思想是利用预训练模型的加权层来提取特征,但在使用新任务的新数据训练期间不更新模型的权重。

深度神经网络是分层结构,具有许多可调的超参数。初始层的作用是捕获通用特征,而后面的层更侧重于手头的明确任务。微调基础模型中的高阶特征表示以使其与特定任务更相关是有意义的。我们可以重新训练模型的某些层,同时在训练中保持一些冻结。

进一步提高模型性能的方法是重新训练或微调预训练模型顶层的权重,同时训练分类器。这将强制从模型源任务中学习到的通用特征图中更新权重。微调将允许模型在目标域中应用过去的知识并重新学习一些东西。

此外,应该尝试微调少数顶层而不是整个模型。前几层学习基本的通用的特征,这些特征可以泛化到几乎所有类型的数据。微调的目的是使这些专门的特征适应新的数据集,而不是覆盖通用的学习。

扫码进群
微信群
免费体验AI服务