PyTorch可复现/重复实验的相关设置

机器学习与生成对抗网络

共 2818字,需浏览 6分钟

 · 2022-11-21

作者丨Alxander@知乎 编辑丨极市平台
来源丨https://zhuanlan.zhihu.com/p/448284000

确定性设置

1 随机种子设置

随机函数是最大的不确定性来源,包含了模型参数的随机初始化,样本的shuffle。

  • PyTorch 随机种子
  • python 随机种子
  • numpy 随机种子
# PyTorch
import torch
torch.manual_seed(0)

# python
import random
random.seed(0)

# Third part libraries
import numpy as np
np.random.seed(0)

CPU版本下,上述随机种子设置完成之后,基本就可实现实验的可复现了。

对于GPU版本,存在大量算法实现为不确定结果的算法,这种算法实现效率很高,但是每次返回的值会不完全一样。主要是由于浮点精度舍弃,不同浮点数以不同顺序相加,值可能会有很小的差异(小数点最末位)。

2 GPU算法确定性实现

GPU算法的不确定来源有两个

  • CUDA convolution benchmarking
  • nondeterministic algorithms

CUDA convolution benchmarking 是为了提升运行效率,对模型参数试运行后,选取最优实现。不同硬件以及benchmarking本身存在噪音,导致不确定性

nondeterministic algorithms:GPU最大优势就是并行计算,如果能够忽略顺序,就避免了同步要求,能够大大提升运行效率,所以很多算法都有非确定性结果的算法实现。通过设置use_deterministic_algorithms,就可以使得pytorch选择确定性算法。

# 不需要benchmarking
torch.backends.cudnn.benchmark=False

# 选择确定性算法
torch.use_deterministic_algorithms()

RUNTIME ERROR

对于一个PyTorch 的函数接口,没有确定性算法实现,只有非确定性算法实现,同时设置了use_deterministic_algorithms(),那么会导致运行时错误。比如:

>>> import torch
>>> torch.use_deterministic_algorithms(True)
>>> torch.randn(2, 2).cuda().index_add_(0, torch.tensor([0, 1]), torch.randn(2, 2))
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: index_add_cuda_ does not have a deterministic implementation, but you set
'torch.use_deterministic_algorithms(True)'. ...

错误原因:

index_add没有确定性的实现,出现这种错误,一般都是因为调用了torch.index_select 这个api接口,或者直接调用tensor.index_add_。

解决方案:

自己定义一个确定性的实现,替换调用的接口。对于torch.index_select 这个接口,可以有如下的实现。

def deterministic_index_select(input_tensor, dim, indices):
"""
input_tensor: Tensor
dim: dim
indices: 1D tensor
"""
tensor_transpose = torch.transpose(x, 0, dim)
return tensor_transpose[indices].transpose(dim, 0)

样本读取随机

  1. 多线程情况下,设置每个线程读取的随机种子
  2. 设置样本generator
# 设置每个读取线程的随机种子
def seed_worker(worker_id):
worker_seed = torch.initial_seed() % 2**32
numpy.random.seed(worker_seed)
random.seed(worker_seed)

g = torch.Generator()
# 设置样本shuffle随机种子,作为DataLoader的参数
g.manual_seed(0)

DataLoader(
train_dataset,
batch_size=batch_size,
num_workers=num_workers,
worker_init_fn=seed_worker,
generator=g,
)

参考文献

Reproducibility - PyTorch 1.10.1 documentation

torch.index_select - PyTorch 1.10.1 documentation



猜您喜欢:

深入浅出stable diffusion:AI作画技术背后的潜在扩散模型论文解读

 戳我,查看GAN的系列专辑~!
一顿午饭外卖,成为CV视觉的前沿弄潮儿!
最新最全100篇汇总!生成扩散模型Diffusion Models
ECCV2022 | 生成对抗网络GAN部分论文汇总
CVPR 2022 | 25+方向、最新50篇GAN论文
 ICCV 2021 | 35个主题GAN论文汇总
超110篇!CVPR 2021最全GAN论文梳理
超100篇!CVPR 2020最全GAN论文梳理

拆解组新的GAN:解耦表征MixNMatch

StarGAN第2版:多域多样性图像生成

附下载 | 《可解释的机器学习》中文版

附下载 |《TensorFlow 2.0 深度学习算法实战》

附下载 |《计算机视觉中的数学方法》分享

《基于深度学习的表面缺陷检测方法综述》

《零样本图像分类综述: 十年进展》

《基于深度神经网络的少样本学习综述》


《礼记·学记》有云:独学而无友,则孤陋而寡闻

欢迎加入 GAN/扩散模型 —交流微信群 !

扫描下面二维码,添加运营小妹好友,拉你进群。发送申请时,请备注,格式为:研究方向+地区+学校/公司+姓名如 扩散模型+北京+北航+吴彦祖



请备注格式:研究方向+地区+学校/公司+姓名



点击 一顿午饭外卖,成为CV视觉的前沿弄潮儿!,领取优惠券,加入 AI生成创作与计算机视觉 知识星球!

浏览 23
点赞
评论
收藏
分享

手机扫一扫分享

举报
评论
图片
表情
推荐
点赞
评论
收藏
分享

手机扫一扫分享

举报