首页 文章详情

知乎 | 写深度学习代码需要遵守哪些顺序?

机器学习实验室 | 678 2021-12-09 19:50 0 0 0
UniSMS (合一短信)

来源 | 知乎问答

地址 | https://www.zhihu.com/question/498167513

本文仅作学术分享,若侵权请联系后台删文处理


01

回答一:作者-三思但不犹豫

前段时间刚重写了一个 dl 任务,在此说下心得体会:

  1. 顺序上,先 dataset,检查基本的 transform,再搭 model,构建 head 和 loss,就可以把一个基础的、可以跑的网络就能跑起来了(这点很重要);
  2. 可视化很重要,如果是本地开发机,善用 cv.imshow 直观、便捷地可视化处理的结果;
  3. 一个基础的 train/inference 流程跑通后,分别构建 1 张、10 张的数据用于 debug,确保任意改动后,可以 overfit;
  4. 调试代码阶段避免随机性、避免数据增强,一定用 tensorboard 之类的工具观察 loss 下降是否合理;
  5. 一般数据集最好处理成 coco 的格式,我的任务跟传统任务不太一样,但也尽量仿照 coco 来设计,写 dataset 的时候可以参考开源实现;
  6. 善用开源框架,比如 Open-MMLab,Detectron2 之类的,好处是方便实验,在框架里写不容易出现难以察觉的 bug,坏处是开源框架为了适配各种网络,代码复杂程度会高一点,建议从第一版入手了解框架,然后基于最新的一边阅读一边开发。

最后,想要更稳健的开发流程,参考 Karpathy 大神的:

http://karpathy.github.io/2019/04/25/recipe/

02

回答二:作者-捡到一束光

先给结论:以我写了两三年pytorch代码的经验而言,比较好的顺序是先写model,再写dataset,最后写train

在讨论码组件的具体顺序前,我们先分析每一个组件背后的目的和逻辑。

model构成了整个深度学习训练与推断系统骨架,也确定了整个AI模型的输入和输出格式。对于视觉任务,模型架构多为卷积神经网络或是最新的ViT模型;对于NLP任务,模型架构多为Transformer以及Bert;对于时间序列预测,模型架构多为RNNLSTM。不同的model对应了不同的数据输入格式,如ResNet一般是输入多通道二维矩阵,而ViT则需要输入带有位置信息的图像patchs。确定了用什么样的model后,数据的输入格式也就确定下来。根据确定的输入格式,我们才能构建对应的dataset

dataset构建了整个AI模型的输入与输出格式。在写作dataset组件时,我们需要考虑数据的存储位置与存储方式,如数据是否是分布式存储的,模型是否要在多机多卡的情况下运行,读写速度是否存在瓶颈,如果机械硬盘带来了读写瓶颈则需要将数据预加载进内存等。在写dataset组件时,我们也要反向微调model组件。例如,确定了分布式训练的数据读写后,需要用nn.DataParallel或者nn.DistributedDataParallel等模块包裹model,使模型能够在多机多卡上运行。此外,dataset组件的写作也会影响训练策略,这也为构建train组件做了铺垫。比如根据显存大小,我们需要确定相应的BatchSize,而BatchSize则直接影响学习率的大小。再比如根据数据的分布情况,我们需要选择不同的采样策略进行Feature Balance,而这也会体现在训练策略中。

train构建了模型的训练策略以及评估方法,它是最重要也是最复杂的组件。先构建modeldataset可以添加限制,减少train组件的复杂度。在train组件中,我们需要根据训练环境(单机多卡,多机多卡或是联邦学习)确定模型更新的策略,以及确定训练总时长epochs优化器的类型,学习率的大小与衰减策略,参数的初始化方法,模型损失函数。此外,为了对抗过拟合,提升泛化性,还需要引入合适的正则化方法,如DropoutBatchNormL2-RegularizationData Augmentation等。有些提升泛化性能的方法可以直接在train组件中实现(如添加L2-RegMixup),有些则需要添加进model中(如DropoutBatchNorm),还有些需要添加进dataset中(如Data Augmentation)。此处安利一下我们的专栏教程:数据增广的方法与代码实现(https://zhuanlan.zhihu.com/p/439206910)。
此外,train还需要记录训练过程的一些重要信息,并将这些信息可视化出来,比如在每个epoch上记录训练集的平均损失以及测试集精度,并将这些信息写入tensorboard,然后在网页端实时监控。在构建train组件中,我们需要随时根据模型表现进行参数微调,并根据结果改进modeldataset两个组件。

最后,我想分享两个我们组自己编写的,给新同学上手使用的深度学习Project。它们都采用model-dataset-train的顺序进行构建,实现了单机多卡,联邦学习等训练环境:
  • 在Cifar10与Cifar100上采用各种ResNet,以Mixup作为数据增广策略,实现监督分类与无监督学习(https://github.com/FengHZ/mixupfamily)。关于数据增广策略Mixup的科普也可以移步我们的专栏Mixup的一个综述(https://zhuanlan.zhihu.com/p/439205252)。
  • 在5种Bencnmark数据集上实现联邦迁移学习(https://github.com/FengHZ/KD3A)。
03

回答三:作者-芙兰朵露

作为data driven的学科,不同的AI model适合不同的数据类型,选择用哪个模型是基于你的数据长什么样来决定的。初学者知道用CNN处理图片,用RNN处理时间序列/语言,但这些都是最基础的工作,真正体现水平的是根据数据的性质来选择合适的细分模型。比如稀疏图像需要用Sparse CNN,语言Transformer效果比较好,但对某些特殊的时间序列RNN也有奇效。
接下来还有很多技术细节,比如需不需要数据增强?需不需要标签平滑?需不需要残差链接?需不需要多loss,如果需要如何平衡?需不需要解释模型?我甚至没有提到超参数,因为超参数是锦上添花而不是雪中送炭。只要没有明确的信息瓶颈,超参数对模型的影响是很小的。
上面提到的这些问题不需要全想明白,但心里要大致有个谱,至少也要知道这些问题是可能影响你的训练结果的,这其实需要相当的阅读和积累。这样之后出了问题才知道去哪里debug。
然后你就可以开始写了。这些问题想明白之后,其实先写哪个part已经不重要了,因为你的心中已经有了一个picture,先把这个picture给sketch下来,然后开始跑,第一遍效果肯定不好,但你要根据输出的结果大致判断哪个部分出了问题,然后针对性地去改进。这一步真的没什么好办法,很多时候其实是直觉,做多了自然就知道了。训练模型-发现问题-修改模型-再训练,就像炼丹一样,经过无数遍的抟炼,才能得到最后的金丹。
其实我洋洋洒洒说了这么多,本质不过是几个字:解决问题的能力。making things to work几乎是机器学习中最重要的能力了,而这种能力就是在日常的积累和训练中反复磨练出来的,成功的路上没有捷径,加油吧。

往期精彩:

 时隔一年!深度学习语义分割理论与代码实践指南.pdf第二版来了!

 基于 docker 和 Flask 的深度学习模型部署!

 新书预告 | 《机器学习公式推导与代码实现》出版在即!

good-icon 0
favorite-icon 0
收藏
回复数量: 0
    暂无评论~~
    Ctrl+Enter