#! /usr/bin/env python
# -*- coding: utf-8 -*_
# Author: Yunlong Feng <ylfeng@ir.hit.edu.cn>
import signal
import fire
from ltp.exe import Config, Executor
[文档]class Command(object):
"""
命令行交互指令
"""
config: Config
executor: Executor
def __init__(self, config, device=None):
self.config = Config(config, device)
if self.config.torch_seed is not None:
self.setup_seed(self.config.torch_seed)
self.executor = Executor(config=self.config)
signal.signal(signal.SIGINT, self.__graceful_exit)
def setup_seed(self, seed: int):
import torch
import numpy as np
import random
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
if torch.cuda.is_available():
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
[文档] def train(self, epoch: int = None):
"""
进行训练操作
:param epoch: 训练的轮数,默认10轮或者从配置文件中指定
"""
self.executor.train(epoch)
[文档] def eval(self, file: str = None, checkpoint: str = None, task: str = None):
"""
进行 Evaluation 操作
:param task: 任务,默认为default
:param checkpoint: 使用的 checkpoint,默认是采用Best ckpt(需要启用Checkpoint Mannger best plugin)
:param file: 要进行 evaluation 的文件,默认是test数据集
"""
self.executor.evaluate(file, checkpoint, task=task)
[文档] def predict(self, input: str, output: str, checkpoint: str = None, task: str = 'default'):
"""
进行预测操作
:param task: 任务,默认为default
:param input: 输入文件
:param output: 输出文件
:param checkpoint: 使用的 checkpoint,默认是采用Best ckpt(需要启用Checkpoint Mannger best plugin)
"""
self.executor.predict(input, output, checkpoint, task=task)
[文档] def deploy(self, path: str = 'deploy.model', vocab: str = None):
"""
将模型中不需要的部分都清理掉,仅仅保留需要预测的部分
:param path: 最终保存的模型的路径
:param vocab: 词典名字,从 huggingface 加载
"""
self.executor.deploy(path, vocab)
[文档] def test(self):
"""
测试配置文件是否正确,目前只对Dataset进行验证。
"""
self.executor.test()
def __graceful_exit(self, signum, frame):
print("Sig %s caught. Graceful exit has been called. Currently running epoch will be finished." % signum)
self.executor.stop_condition = lambda state: True
if __name__ == '__main__':
fire.Fire(Command)