开辟新视野之高层训练框架 PyTorch-Ignite

机器学习算法工程师

共 15611字,需浏览 32分钟

 · 2021-12-01

1 前言

在 CV 领域,一提到高层次封装的训练框架,大家可能马上会想到 MMCV,其是 OpenMMLab 中的基础库,用于提供所有上层库例如 MMDetection、MMSegmantation 等的一致性支持,功能强大,新 Feature 也非常多,读到这里不妨休息下点个 Star 吧

https://github.com/open-mmlab/mmcv

本文介绍另一个高层封装训练框架 Ignite, 其官方介绍是:PyTorch-Ignite 是一个可帮助在 PyTorch 中灵活透明地训练和评估神经网络的高级库可以发现 Ignite 对标的是 MMCV 和 Pytorch-Lighting,但是相比 Pytorch-Lighting 更加简单。本文对 Ignite 进行整体性分析希望大家能够开辟新视野,换个姿势了解其他训练框架的封装方式,而不要拘泥于某一种固定的开发模式,阻碍自身成长

由于 Ignite 内容比较多,本文分析会有侧重于整体分析,无法顾全所有内容。如果想了解的非常透彻,建议和我交流或者留言。

Github 地址:

https://github.com/pytorch/ignite

官方地址:

https://pytorch-ignite.ai/

Docs 地址:

Ignite Your Networks! — PyTorch-Ignite v0.4.7 Documentation

Guides 地址: 

https://pytorch-ignite.ai/how-to-guides/


本文代码比较多,手机端阅读体验不佳,建议采用电脑端查看或者后续移步知乎社区,知乎 ID: 深度眸


2 Ignite 特性分析

当前分析版本是 V0.4.7。

2.1 核心特性

其核心特性是:


  1. 比纯 PyTorch 更少的代码,同时确保最大程度的控制和简单性

  2. 库方法和没有程序控制反转 - 在需要的地方和时间使用 Ignite

  3. 用于指标、实验管理器和其他组件的可扩展 API


这里解释下控制反转 (Inversion of control) 。IoC 是一种设计思想,用于解决对象与对象实例化耦合问题,在 Spring 等大型应用程序框架中有着非常多的应用。控制反转是指把传统模式中需要自己通过实例化构造函数,或者通过工厂模式实例化的任务交给容器来避免强耦合。这种做法其实非常常见,和我们常说的依赖抽象而不是依赖实体非常类似。举个最简单的例子:

# 沙丁鱼class SardineFish:   pass # 石斑鱼class GrouperFish:    pass# 吃饭class Dining:   def __init__(self):       self.fish=SardineFish()   def eat(self):       return self.fish   dining=Dining()dining.eat()

今天想吃沙丁鱼,因此直接在 Dining 类中实例化了 SardineFish 类,这是一种非常强的耦合关系,一旦我明天想吃石斑鱼就麻烦了。控制反转可以解决上述问题,将鱼类具体实例化交给容器实例化,Dining 内部只是被动的获取类对象即可,不负责创建实例,而是交给容器类。自己需要主动实例化对象变为被动获取,依赖对象控制权被反转,不需要再考虑如何实例化其他依赖的类。

class SardineFish:   def name(self):       return 'SardineFish'
class GrouperFish: def name(self): return 'GrouperFish'
class Container: def __init__(self): self.fish_dict={} def bind(self,fish): self.fish_dict[fish.name]=fish def get(self,name): return self.fish_dict[name]
class Dining: def __init__(self,container): self.container= container
def eat(self,name): return self.container.get(name) container= Container()container.bind(SardineFish()) container.bind(GrouperFish())
dining= Dining(container)dining.eat('GrouperFish')

Ignite 没有程序控制反转,是因为他都是基于方法或者函数进行扩展开发,不存在对象和对象自己的实例化耦合问题。


作者指出其功能上主要特点是:


  1. 极其简单的训练引擎和事件系统

  2. 开箱即用的指标,可轻松评估模型

  3. 用于组成训练 pipeline、保存以及记录参数和指标的内置处理程序


事件是什么?可以简单理解为一个动作 action,例如保存权重就是一个事件, 通过丰富的事件系统可以实现灵活的无侵入的扩展功能。


需要指出的是,Ignite 文档非常多也非常全面,包括 Getting Started、Documentation、Additional Materials、Examples、Tutorials 和 Projects using Ignite 等等,如果你非常有兴趣,可以阅读相关 project,有非常多的实例。

2.2 主要特性

2.2.1 简化训练和验证循环

不再需要为 epoch 和 iterations 手动设置 for/while 循环,用户初始化的实例化引擎会自动处理和运行。这算是作为一个高层训练框架包装器的最基本要求了吧。


2.2.2 强大的 Event 事件和 Handler 处理器

Ignite 处理程序很酷的地方在于它们提供了无与伦比的灵活性(例如与回调相比)。处理程序可以是任何函数:例如 lambda、简单函数、类方法等。因此我们不需要从接口继承并覆盖其抽象方法,这可能会不必要地增加您的代码及其复杂性。

作为一个框架,最需要考虑的是扩展性,MMCV 和 Pytorch-Lighting 都提出了自己的扩展方式,ignite 扩展方式非常简洁,不需要继承并覆写某些抽象方法,而是可以传入任意函数。这也是 ignite 不同于其他两个框架的特点,后面会重点介绍。



(1) 随时执行任意数量的你想要的扩展功能

# 1 注入自定义的事件处理器trainer.add_event_handler(Events.STARTED, lambda _: print("Start training"))
# attach handler with args, kwargsmydata = [1, 2, 3, 4]logger = ...def on_training_ended(data): print(f"Training is ended. mydata={data}") # User can use variables from another scope logger.info("Training is ended") # 2 注入自定义的事件处理器trainer.add_event_handler(Events.COMPLETED, on_training_ended, mydata)# call any number of functions on a single event# 3 注入自定义的事件处理器trainer.add_event_handler(Events.COMPLETED, lambda engine: print(engine.state.times))
# 4 注入自定义的事件处理器@trainer.on(Events.ITERATION_COMPLETED)def log_something(engine): print(engine.state.output)

上面例子写了几种 Ignite 支持的扩展开发方式,如果你已经熟悉了 MMCV 的 Hook 开发模式,那么上面例子含义非常容易理解。



(2) 内置事件过滤器

# run the validation every 5 epochs@trainer.on(Events.EPOCH_COMPLETED(every=5))def run_validation():    # run validation    # change some training variable once on 20th epoch@trainer.on(Events.EPOCH_STARTED(once=20))def change_training_variable():    # ...    # Trigger handler with customly defined frequency@trainer.on(Events.ITERATION_COMPLETED(event_filter=first_x_iters))def log_gradients():    # ...


事件过滤器是指基于过滤规则运行指定事件,例如每隔 20 个 epoch 验证一次,跳过前 n 次迭代等等。


(3) 一个事件并集操作共享多个 action

@trainer.on(Events.COMPLETED | Events.EPOCH_COMPLETED(every=10))def run_validation():    # ...

这是一种非常好的特性。


(4) 支持标准事件外的自定义事件


from ignite.engine import EventEnum
class BackpropEvents(EventEnum): BACKWARD_STARTED = 'backward_started' BACKWARD_COMPLETED = 'backward_completed' OPTIM_STEP_COMPLETED = 'optim_step_completed' def update(engine, batch): # ... loss = criterion(y_pred, y) engine.fire_event(BackpropEvents.BACKWARD_STARTED) loss.backward() engine.fire_event(BackpropEvents.BACKWARD_COMPLETED) optimizer.step() engine.fire_event(BackpropEvents.OPTIM_STEP_COMPLETED) # ... trainer = Engine(update)trainer.register_events(*BackpropEvents)
@trainer.on(BackpropEvents.BACKWARD_STARTED)def function_before_backprop(engine): # ...


2.2.3 开箱即用的评估指标


假设 Ignite 内置的事件无法满足我的需求则可以自定义事件,如上所示用户自定义了反向传播相关的事件,然后可以通过 register_events 注册从而生效。


目前已经支持了非常多评估指标,例如 Precision, Recall, Accuracy, Confusion Matrix, IoU 等等,当然用户也可以组合或者自定义新的评估指标

precision = Precision(average=False)recall = Recall(average=False)F1_per_class = (precision * recall * 2 / (precision + recall))F1_mean = F1_per_class.mean()  # torch mean methodF1_mean.attach(engine, "F1")



3 从一个典型而简单的例子说起


如果直接讲 Ignite 整体设计原则,可能很多人依然觉得难以理解,故先以一个非常简单的分类任务例子说明常用用法,从这个用法中可以说明 Ignite 的大部分特性和设计巧妙之处。


完整代码来自:https://pytorch-ignite.ai/tutorials/beginner/01-getting-started/#complete-code

3.1 初始化必备对象实例

model = Net().to(device)
data_transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])
train_loader = DataLoader( MNIST(download=True, root=".", transform=data_transform, train=True), batch_size=128, shuffle=True)
val_loader = DataLoader( MNIST(download=True, root=".", transform=data_transform, train=False), batch_size=256, shuffle=False)
optimizer = torch.optim.RMSprop(model.parameters(), lr=0.005)criterion = nn.CrossEntropyLoss()

初始化模型、train loader、val loader、optimizer 和 loss 计算类等。

3.2 初始化训练引擎

trainer = create_supervised_trainer(model, optimizer, criterion, device)

create_supervised_trainer 只是只是一个简单的帮助函数,内部实际上是初始化了一个训练 Engine,核心代码为:

def _update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[torch.Tensor]]:    model.train()    x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)    y_pred = model(x)    loss = loss_fn(y_pred, y)    if gradient_accumulation_steps > 1:        loss = loss / gradient_accumulation_steps    loss.backward()    if engine.state.iteration % gradient_accumulation_steps == 0:        optimizer.step()        optimizer.zero_grad()    return output_transform(x, y, y_pred, loss)    trainer = Engine(_update)

训练引擎在每一 step 时候都会调用 _update 函数进行推理、loss 计算和参数优化

3.3 定义评估流程

 1 定义评估指标val_metrics = {    "accuracy": Accuracy(),    "loss": Loss(criterion)}
# 2. 实例化两个新的 engine# 一个负责训练过程中的评估,一个负责验证过程中的评估train_evaluator = create_supervised_evaluator(model, metrics=val_metrics, device=device)val_evaluator = create_supervised_evaluator(model, metrics=val_metrics, device=device)
log_interval = 100
# 3 自定义 handler,并通过 on 装饰器注入到 engine@trainer.on(Events.ITERATION_COMPLETED(every=log_interval))def log_training_loss(engine): print(f"Epoch[{engine.state.epoch}], Iter[{engine.state.iteration}] Loss: {engine.state.output:.2f}")
@trainer.on(Events.EPOCH_COMPLETED)def log_training_results(trainer): train_evaluator.run(train_loader) metrics = train_evaluator.state.metrics print(f"Training Results - Epoch[{trainer.state.epoch}] Avg accuracy: {metrics['accuracy']:.2f} Avg loss: {metrics['loss']:.2f}")
@trainer.on(Events.EPOCH_COMPLETED)def log_validation_results(trainer): val_evaluator.run(val_loader) metrics = val_evaluator.state.metrics print(f"Validation Results - Epoch[{trainer.state.epoch}] Avg accuracy: {metrics['accuracy']:.2f} Avg loss: {metrics['loss']:.2f}")
def score_function(engine): return engine.state.metrics["accuracy"] # 4 模型保存 handler model_checkpoint = ModelCheckpoint( "checkpoint", n_saved=2, filename_prefix="best", score_function=score_function, score_name="accuracy", global_step_transform=global_step_from_engine(trainer),) # 将模型保存 handler 注入到验证评估引擎中val_evaluator.add_event_handler(Events.COMPLETED, model_checkpoint, {"model": model})

不同于我们常规的理解,其 Engine 不是一个,而是 3 个,每个 Engine 负责一个流程,分别是训练 Engine、训练中评估的 Engine 和验证中评估 Engine。训练 Engine 是用于控制训练的循环过程,train_evaluator 是对 train dataloader 进行评估,评估的指标就是前面定义的 val_metrics,val_evaluator 是对 val dataloader 进行评估。三个 engine 相互独立,但是实际上是通过 train engine 组织起来的。

3.4 插入 logger handler

tb_logger = TensorboardLogger(log_dir="tb-logger")
tb_logger.attach_output_handler( trainer, event_name=Events.ITERATION_COMPLETED(every=100), tag="training", output_transform=lambda loss: {"batch_loss": loss},)
for tag, evaluator in [("training", train_evaluator), ("validation", val_evaluator)]: tb_logger.attach_output_handler( evaluator, event_name=Events.EPOCH_COMPLETED, tag=tag, metric_names="all", global_step_transform=global_step_from_engine(trainer), )

 tensorboard 非常重要,可以将 tb_logger 插入到任意一个或者多个 Engine 中,例如上面代码是插入到了每个 Engine,插入过程是通过 attach_output_handler 实现的,而 event_name 表示触发时机。

3.5 开启训练

trainer.run(train_loader, max_epochs=5)
tb_logger.close()

开启训练后,在特定时候会触发注入的事件 Handler,例如在每个 epoch 完成后,会进行训练集的评估和验证集的评估,并将所有评估指标保存到 Tensorboard 中。

通过上述的完整例子,大家应该有了第一直观感受,其将函数作为一等公民这一宗旨发挥到了最大化,除了 Engine 类外,其他功能都可以通过函数形式注册进去,实现丰富的扩展功能。


4 Ignite 整体分析

Ignite 主要要理解 Engine、State、Event 和 Handler 这 4 个概念,核心代码位于 ignite/engine/engine.py、ignite/engine/events.py,其关系如下所示:



Engine 负责一个完整的循环流程,可以是一个训练流程,也可以是一个验证流程,整个流程的状态都是通过 State 对象统一维护,而 Events 管理了所有支持的触发事件,如果有自定义事件可以通过 register_events 接口实现,事件触发后具体的任务执行是通过 Handler 对象负责,Handler 可以认为是 Hook 的升级版本,其更加灵活好用,各种扩展功能都可以 Handler 实现,例如 logger、checkpoint、metric 等等。Engine 运行流程本质就是 for 循环,然后在特定点位触发事件,执行 Handler 任务。

4.1 Engine

Engine 是运行流程的核心,但是又非常简单,其核心流程如下:

while epoch < max_epochs:    # run an epoch on data    data_iter = iter(data)    while True:        try:            batch = next(data_iter)            output = process_function(batch)            iter_counter += 1        except StopIteration:            data_iter = iter(data)
if iter_counter == epoch_length: break

可以看出就是一个典型的 for 循环,process_function 负责处理一个 epoch 的数据。一个典型的例子是:

def train_step(trainer, batch):    model.train()    optimizer.zero_grad()    x, y = prepare_batch(batch)    y_pred = model(x)    loss = loss_fn(y_pred, y)    loss.backward()    optimizer.step()    return loss.item()    trainer = Engine(train_step) # process_functiontrainer.run(data, max_epochs=100)

用户自定义 train_step 函数,返回啥无所谓,都会直接存储到 trainer.state.output中,后续自己可以针对性处理,这也体现去其灵活的地方了。

def update(engine, batch):    x, y = batch    y_pred = model(inputs)    loss = loss_fn(y_pred, y)    optimizer.zero_grad()    loss.backward()    optimizer.step()    return {'loss': loss.item(),            'y_pred': y_pred,            'y': y}            trainer = Engine(update)
@trainer.on(Events.EPOCH_COMPLETED)def print_loss(engine): epoch = engine.state.epoch loss = engine.state.output['loss'] print (f'Epoch {epoch}: train_loss = {loss}') accuracy = Accuracy(output_transform=lambda x: [x['y_pred'], x['y']])accuracy.attach(trainer, 'acc')
trainer.run(data, max_epochs=10)

同时允许用户手动设置 max_epoch 和 epoch_length,epoch_length 的应用场景可以是  debug 阶段,数据集过大,可以设置 epoch_length 来取其中一小部分,也可用于 dataset 是无限长的场景。

trainer.run(data, max_epochs=100, epoch_length=200)

如果是 GAN 这种复杂场景,支持也非常容易:

model_1 = ...model_2 = ...# ...optimizer_1 = ...optimizer_2 = ...# ...criterion_1 = ...criterion_2 = ...# ...def train_step(trainer, batch):
data_1 = batch["data_1"] data_2 = batch["data_2"] # ...
model_1.train() optimizer_1.zero_grad() loss_1 = forward_pass(data_1, model_1, criterion_1) loss_1.backward() optimizer_1.step() # ...
model_2.train() optimizer_2.zero_grad() loss_2 = forward_pass(data_2, model_2, criterion_2) loss_2.backward() optimizer_2.step() # ...
# User can return any type of structure. return { "loss_1": loss_1, "loss_2": loss_2, # ... }trainer = Engine(train_step)trainer.run(data, max_epochs=100)

如果对 MMCV 比较了解,你可以认为一个 Engine 就是对应一个 Runner。

4.2 State

State 对象比较好理解,专门用于存储训练中所需的所有状态,实现训练过程和训练状态分离,便于管理。默认情况下有如下状态:

def __init__(self, **kwargs: Any) -> None:    self.iteration = 0    self.epoch = 0    self.epoch_length = None  # type: Optional[int]    self.max_epochs = None  # type: Optional[int]    self.max_iters = None  # type: Optional[int]    self.output = None  # type: Optional[int]    self.batch = None  # type: Optional[int]    self.metrics = {}  # type: Dict[str, Any]    self.dataloader = None  # type: Optional[Union[DataLoader, Iterable[Any]]]    self.seed = None  # type: Optional[int]    self.times = {        Events.EPOCH_COMPLETED.name: None,        Events.COMPLETED.name: None,    }  # type: Dict[str, Optional[float]]
for k, v in kwargs.items(): setattr(self, k, v)

如果你自定义的事件中也有状态要保存,也可以通过 event_to_attr 实现。self.output 保存的就是 update 函数返回值。

4.3 Event

Event 和 Handler 是 Ignite 的核心,要掌握这个框架就必须理解这两个对象。Event 用于记录事件的触发时机,例如每个 epoch 后,每隔 2 个 epoch等等,Handler 是在事件触发后具体的执行器。


因为事件也分成很多类型,故作者也进行了区分:


  1. Events,这个是最基本的事件记录器,典型的是 STARTED、EPOCH_STARTED、ITERATION_STARTED、ITERATION_COMPLETED 等

  2. EventsList,这个是并集操作事件记录器,用于将多个事件堆叠

  3. CallableEventWithFilter,这个是基类,用于提供基于过滤规则触发的事件


其触发的核心伪代码为:

fire_event(Events.STARTED)while epoch < max_epochs:    fire_event(Events.EPOCH_STARTED)    # run once on data    for batch in data:        fire_event(Events.ITERATION_STARTED)
output = process_function(batch)
fire_event(Events.ITERATION_COMPLETED) fire_event(Events.EPOCH_COMPLETED)fire_event(Events.COMPLETED)

虽然有三种类型,但是对用户而言不要操心,因为内部会基于事件类型来确定应该用哪个。

# epoch 完成事件-对应 Events@engine.on(Events.EPOCH_COMPLETED)
# 组合事件-对应 EventsListevents = Events.STARTED | Events.COMPLETED@engine.on(events)
# 具体过滤规则的 Events-对应类型为 CallableEventWithFilter@engine.on(Events.ITERATION_COMPLETED(every=log_interval))


4.4 Handler


任何一个处理函数都必然要和对应的 Event 事件绑定,不然不知道何时触发。当然用户也可以自定义事件。


Handler 对应一个具体处理事件的 API,其可以是一个函数,可以是一个类方法,可以通过 on 装饰器或者 add_evevnt_hander 注册到引擎中,也可以自身的通过 attach 接口接入 Engine 中。Handler非常类似 Hook,但是这里更加宽泛,不像 Hook 必须是指定的方法名和固定的输入参数,其可以随意设置


下面是一个最简单的 Handler 函数 log_training_loss

# 每隔 log_interval 且是迭代完成后触发,打印训练 loss@trainer.on(Events.ITERATION_COMPLETED(every=log_interval))def log_training_loss(engine):    print(f"Epoch[{engine.state.epoch}], Iter[{engine.state.iteration}] Loss: {engine.state.output:.2f}")

这个 Handler 通过装饰器 on 函数注入到 trainer 中。

# 保存模型val_evaluator.add_event_handler(Events.COMPLETED, model_checkpoint)

这个 Handler 是通过 Engine 自身的 add_event_handler 函数注入到 Engine 中。

# tb_logger 通过 attach 方法注入到 trainer 中tb_logger = TensorboardLogger(log_dir="tb-logger")
tb_logger.attach_output_handler( trainer, event_name=Events.ITERATION_COMPLETED(every=100), tag="training", output_transform=lambda loss: {"batch_loss": loss},)

engine.add_event_handler(name, log_handler, self, name) 方法的。这个 Handler 是通过类本身的 attach 方法注入到 Engine 中,实际上 attach 方法内部也是调用了


除了上面这些比较简单的 Handler,作者还实现了各种各样的 Handler,都可以通过 on、add_event_handler 或者 attach 方式注入到 Engine 中,这种设计解耦性很好,容易维护和扩展。


下面是一个典型的带有 EarlyStopping 和 reduce_lr_plateau handler 的训练流程图:


5 总结

整体来说,Ignite 属于一个小而精的框架,代码量非常少,实现优雅,学习起来非常顺畅。总之其优点可以归纳为:


  1. 代码简洁优雅,容易理解

  2. 设计了一套独特的扩展模式,解耦合性非常好


但是从目前来看缺点也比较明显:


  1. Engine 做的事情过少,导致很大一部分功能都需要自定义 Handler 实现,自身要维护的代码比较多

  2. 整体功能过弱,暂时还无法和 Pytorch-Iighting 这种量级的功能相比,社区活跃度也相差很大,后面可以对 Pytorch-Iighting  这种大型复杂的框架进行整体性分析


写这篇文章的目的正如题目所言,主要是开阔下大家的视野,在设计自身代码的时候可以参考人家开源的高质量代码。当然如果你是重头写一个训练任务,那么尝试 Ignite 也是一种不错的选择。由于内容较多,时间匆忙,如果有不对的地方,可以联系我。


最后,如果你觉得本文对你有帮助,请给 MMCV 点赞


https://github.com/open-mmlab/mmcv

 

如果有任何疑问,可以直接知乎联系,知乎账号:深度眸



推荐阅读


超实用半监督目标检测 Soft Teacher 及 MMDetection 最强代码实践


浏览 21
点赞
评论
收藏
分享

手机扫一扫分享

举报
评论
图片
表情
推荐
点赞
评论
收藏
分享

手机扫一扫分享

举报