Source code for ltp.exe.executor

#! /usr/bin/env python
# -*- coding: utf-8 -*-
# Author: Yunlong Feng <ylfeng@ir.hit.edu.cn>
import sys
from time import time
from typing import Iterable, Union, List, Set, Dict

import numpy as np
import torch
import torch.utils.data

from tqdm import tqdm
from ltp.exe.config import Config
from ltp.train.callback import Callback, ValidationCallback
from ltp.train.stop_condition import NoStopping, EarlyStopping
from ltp.eval.metrics import Metric
from ltp.train import Trainer
from ltp.utils import cycle


class Executor(object):
    """A thin wrapper around the actual trainer.

    Wires together the Config, the Trainer, the per-task datasets/metrics,
    the validation/user callbacks and the stop condition, and exposes the
    high-level entry points: ``train``, ``evaluate``, ``predict``, ``deploy``.
    """
    config: Config
    trainer: Trainer
    __callbacks: List          # registered Callback instances (name-mangled, private)
    progressbar_metrics: Set   # trainer-state attribute names echoed on the tqdm bar
    stop_condition: Union[EarlyStopping, NoStopping]

    def __init__(self, config: Config):
        self.config = config
        # Early stopping only if the config section is present; otherwise never stop early.
        self.stop_condition = EarlyStopping(
            **self.config.early_stopping
        ) if self.config.early_stopping else NoStopping()
        self.__callbacks = []
        self.progressbar_metrics = set()
        self.trainer = Trainer.from_params(self.config.trainer, config=self.config)
        self.tasks = self.config.tasks
        self.epoch_size = self.config.epoch_size

    def train(self, epoch: int = 30):
        """Train all configured tasks, then run a final evaluation.

        :param epoch: maximum number of epochs; ``config.epoch`` wins when
            the caller passes ``None``.
        """
        epoch = self.config.epoch if epoch is None else epoch
        # ------------------------------- task setup -------------------------------
        multi_task_dataset = {}
        for name, task in self.tasks.items():
            if not task.config.dataset:
                continue
            # 'default' is loaded lazily elsewhere; all other tasks load here.
            if name != 'default':
                task.load()
            train, valid = task.train_dataset(self.config.batch_size)
            valid_callback = ValidationCallback(data_loader=valid, metrics=task.metrics, task=name)
            self.register_callback(valid_callback)
            multi_task_dataset[name] = train
        # ------------------------------ callback setup ----------------------------
        for callback in self.config.callbacks:
            self.register_callback(Callback.from_params(callback))
        # ------------------------------ start training ----------------------------
        self.train_wrapper(multi_task_dataset, epoch, tau=self.config.tau)
        # ------------------------------ final metrics -----------------------------
        self.evaluate()

    def evaluate(self, file: str = None, checkpoint: str = None, task: str = None):
        """Evaluate one task (or all tasks when ``task`` is None) on a dataset file.

        :param file: dataset file passed to ``load_dataset`` (None -> task default)
        :param checkpoint: checkpoint handed to the default task's ``load``
        :param task: task name; None means evaluate every task with a dataset
        """
        print("========================= Evaluate ==========================", file=sys.stderr)
        self.tasks['default'].load(self.trainer.state, checkpoint)
        if task is None:
            for name, task_obj in self.tasks.items():
                if not task_obj.config.dataset:
                    continue
                if name != 'default':
                    task_obj.load()
                test = task_obj.load_dataset(file, self.config.batch_size)
                self.evaluate_(test, task_obj.metrics, name)
        else:
            if task != 'default':
                self.tasks[task].load()
            test = self.tasks[task].load_dataset(file, self.config.batch_size)
            self.evaluate_(test, self.tasks[task].metrics, task)

    def predict(self, inputs: str, outputs: str, checkpoint: str = None, task: str = None):
        """Run prediction for ``task`` on the ``inputs`` file, writing to ``outputs``."""
        self.tasks['default'].load(self.trainer.state, checkpoint)
        if task != 'default':
            self.tasks[task].load()
        test = self.tasks[task].load_dataset(inputs, self.config.batch_size)
        self.predict_(test, outputs, task)

    def deploy(self, path: str = 'deploy.model', vocab: str = None):
        """Serialize a deployable bundle: model weights, configs and target vocabs.

        :param path: output file for ``torch.save``
        :param vocab: external vocabulary reference stored under the 'vocab' key
        """
        deploy_state_dict = {'vocab': vocab}
        for task, task_obj in self.tasks.items():
            task_obj.load()
            if task == 'default':
                deploy_state_dict['model'] = task_obj.model.state_dict()
                deploy_state_dict['model_config'] = task_obj.config.model
                deploy_state_dict['pretrained_config'] = task_obj.model.pretrained.config.to_dict()
            # Store each target field's label vocabulary, skipping leading padding entries.
            for field_name, field in task_obj.fields:
                if field.is_target and hasattr(field, 'vocab') and hasattr(field.vocab, 'itos'):
                    pad_bias = getattr(field, 'pad_bias', 0)
                    deploy_state_dict[task] = field.vocab.itos[pad_bias:]
        torch.save(deploy_state_dict, path)
    def train_wrapper(self, dataloaders: Dict[str, torch.utils.data.DataLoader], epochs: int = 30,
                      tau: float = 1.0):
        """Train from the given data loaders until ``epochs`` or the stop condition fires.

        Tasks are sampled proportionally to ``len(loader) ** tau`` (temperature-style
        re-weighting of dataset sizes).

        :param tau: sampling-weight exponent
        :param dataloaders: PyTorch DataLoader per task name
        :param epochs: maximum number of training epochs
        """
        print("========================== Train ============================", file=sys.stderr)
        self.trainer.model.train()  # set the module to training mode
        train_start = time()
        # cycle has a memory leak
        dataiters = {k: cycle(v) for k, v in dataloaders.items()}
        if all(hasattr(v, '__len__') for v in dataloaders.values()):
            dataloader_sizes = {k: len(v) for k, v in dataloaders.items()}
            total_size = sum(v for k, v in dataloader_sizes.items())
            # Normalizer for the tau-scaled sampling distribution.
            Z = sum(pow(v, tau) for v in dataloader_sizes.values())
            tasknames, sampling_weights = zip(*((k, pow(v, tau) / Z)
                                                for k, v in dataloader_sizes.items()))
        else:
            raise NotImplementedError("Dataloader 需要实现 __len__ 方法")
        # An explicit epoch_size from the config overrides the combined loader size.
        if self.epoch_size:
            total_size = self.epoch_size
        # Scheduler needs the total step budget before training starts.
        self.tasks['default'].build_scheduler(epochs * total_size)
        self.tasks['default'].restore(self.trainer.state)
        self.trainer.init(epochs * total_size)
        self.trainer.state.current_epoch += 1
        while self.trainer.state.current_epoch <= epochs and not self.stop_condition(self.trainer.state):
            # ------------------------- EPOCH ----------------------------
            self.trainer.before_train()
            self.train_(total_size, tasknames, sampling_weights, dataiters)
            self.__run_post_epoch_callbacks()
            self.trainer.after_train()
            # ------------------------- EPOCH ----------------------------
            self.trainer.state.current_epoch += 1
        print("train time %.2f" % (time() - train_start))
def train_(self, total_size, tasknames, sampling_weights, dataiters): ofm = 0 # out of memory with tqdm(range(total_size), desc=f'Train({self.trainer.state.current_epoch}): ') as epoch_steps: for _ in epoch_steps: try: taskname = np.random.choice(tasknames, p=sampling_weights) dataiter = dataiters[taskname] batch = next(dataiter) self.trainer.state.last_train_loss = self.trainer.train(batch, task=taskname) self.trainer.state.global_step += 1 self.__run_post_iteration_callbacks() postfix = {metric: getattr(self.trainer.state, metric) for metric in self.progressbar_metrics} postfix["loss"] = self.trainer.state.last_train_loss postfix["ofm"] = ofm epoch_steps.set_postfix(postfix) except Exception as e: detail = e.args[0] if isinstance(detail, str) and detail.startswith("CUDA out of memory"): ofm += 1 epoch_steps.set_postfix({"ofm": ofm}) continue raise e
[文档] def evaluate_(self, data_loader: torch.utils.data.DataLoader, metrics: Iterable[Metric], task: str) -> \ Iterable[Metric]: """ 评估一个模型 :param task: 任务信息 :param data_loader: PyTorch DataLoader :param metrics: 进行评估的 Metrics :return: 计算得到的各项指标 """ for metric in metrics: metric.clear() ofm = 0 # out of memory self.trainer.before_eval(task) with torch.no_grad(), tqdm(data_loader, desc=f'{task}({self.trainer.state.current_epoch}): ')as pbar: for batch in pbar: try: x, y_pred, y = self.trainer.eval(batch, task) metric_values = {} for metric in metrics: metric.step(y_pred, y) metric_values.update(metric.compute()) if ofm > 0: metric_values['ofm'] = ofm pbar.set_postfix(metric_values) except Exception as e: detail = e.args[0] if isinstance(detail, str) and detail.startswith("CUDA out of memory"): ofm += 1 continue raise e self.trainer.after_eval(task) return metrics
[文档] def predict_(self, dataloader: torch.utils.data.DataLoader, outputs: str, task): """ 进行预测操作 :param dataloader: PyTorch DataLoader :param outputs: 输出的文件名 """ self.trainer.before_predict(task) with torch.no_grad(), open(outputs, mode='w', encoding='utf8') as f, \ tqdm(dataloader, dynamic_ncols=True, desc=f'Predict: ') as pbar: for batch in pbar: result = self.trainer.predict(batch, task) for pred in result: f.writelines("\t".join(pred) + "\n") self.trainer.after_predict(task)
    def test(self):
        """Smoke-test the data pipeline: print target vocabs and iterate every split."""
        for name, task in self.tasks.items():
            if not task.config.dataset:
                continue
            if name != 'default':
                task.load()
            train, valid = task.train_dataset(self.config.batch_size, False)
            for name, field in task.fields:
                if hasattr(field, 'vocab') and hasattr(field.vocab, 'itos'):
                    print(name, 'vocab size(with pad):', len(field.vocab.itos), file=sys.stderr)
                    print(field.vocab.itos, file=sys.stderr)
            # Drain every split once to surface loading/collation errors early.
            for train_iter in tqdm(train):
                pass
            for valid_iter in tqdm(valid):
                pass
            for test_iter in tqdm(task.load_dataset(batch_size=self.config.batch_size)):
                pass

    def add_progressbar_metric(self, name):
        # Trainer-state attribute to echo on the training progress bar.
        self.progressbar_metrics.add(name)

    def register_callback(self, callback: Callback):
        """Initialize a callback against this executor and register it."""
        callback.init(self)
        self.__callbacks.append(callback)

    def __run_post_iteration_callbacks(self):
        # Fire callbacks whose `iteration` interval divides the current global step.
        for callback in self.__callbacks:
            if callback.iteration is None:
                continue
            if callback.iteration != 0 and self.trainer.state.global_step % callback.iteration == 0:
                callback(self)

    def __run_post_epoch_callbacks(self):
        # NOTE(review): this gates epoch callbacks on global_step % epoch, not on
        # current_epoch — looks suspicious but may be intentional; confirm upstream.
        for callback in self.__callbacks:
            if callback.epoch is None:
                continue
            if (callback.epoch != 0 and self.trainer.state.global_step % callback.epoch == 0):
                callback(self)