命令

Command

class Command(config, device=None)[源代码]

基类:object

命令行交互指令

train(epoch=None)[源代码]

进行训练操作

参数

epoch – 训练的轮数,默认10轮或者从配置文件中指定

eval(file=None, checkpoint=None, task=None)[源代码]

进行 Evaluation 操作

参数
  • task – 任务,默认为default

  • checkpoint – 使用的 checkpoint,默认是采用Best ckpt(需要启用Checkpoint Mannger best plugin)

  • file – 要进行 evaluation 的文件,默认是test数据集

predict(input, output, checkpoint=None, task='default')[源代码]

进行预测操作

参数
  • task – 任务,默认为default

  • input – 输入文件

  • output – 输出文件

  • checkpoint – 使用的 checkpoint,默认是采用Best ckpt(需要启用Checkpoint Mannger best plugin)

deploy(path='deploy.model', vocab=None)[源代码]

将模型中不需要的部分都清理掉,仅仅保留需要预测的部分 :param path: 最终保存的模型的路径 :param vocab: 词典名字,从 huggingface 加载

test()[源代码]

测试配置文件是否正确,目前只对Dataset进行验证。

Executer

class Executor(config)[源代码]

基类:object

实际训练器的一个简单包装

train_wrapper(dataloaders, epochs=30, tau=1.0)[源代码]

通过给定的 data loader 进行训练,训练会进行到epochs或者 stop condition = True

参数
  • tau – 放大指数

  • dataloaders – PyTorch DataLoader

  • epochs – 训练的最大轮数

evaluate_(data_loader, metrics, task)[源代码]

评估一个模型 :param task: 任务信息 :param data_loader: PyTorch DataLoader :param metrics: 进行评估的 Metrics :return: 计算得到的各项指标

predict_(dataloader, outputs, task)[源代码]

进行预测操作

参数
  • dataloader – PyTorch DataLoader

  • outputs – 输出的文件名