ltp.train.checkpoint 源代码

#! /usr/bin/env python
# -*- coding: utf-8 -*-
# Author: Yunlong Feng <ylfeng@ir.hit.edu.cn>
import sys
import os, regex, logging
from typing import List, Dict
from collections import OrderedDict

import torch

from ltp.core import Registrable
from ltp.const import checkpoint

# Directory used for checkpoints when the caller does not supply one.
# NOTE: the name is misspelled ("diretory") but is referenced by the classes
# in this module, so it is kept as-is.
default_save_diretory = './checkpoint'
# File name used by the "best" checkpoint plugin when the caller supplies none.
default_best_filename = 'best.pt.tar'

# Module-level logger.
logger = logging.getLogger(__name__)


class CheckpointPlugin(metaclass=Registrable):
    """
    Base class for checkpoint plugins.

    Subclasses implement :meth:`save` and :meth:`load`; the actual
    (de)serialisation is shared through :meth:`save_checkpoint` and
    :meth:`load_checkpoint`.

    :param directory: directory the checkpoint files are stored in
    :param filename: checkpoint file name
    """

    def __init__(self, directory: str = None, filename: str = None):
        if directory is None:
            directory = default_save_diretory
        if filename is None:
            # ``checkpoint`` is the default file name imported from ltp.const.
            filename = checkpoint
        self.directory = directory
        self.filename = filename
        os.makedirs(self.directory, exist_ok=True)

    def save(self, task, state):
        """Persist the current training state; must be provided by subclasses."""
        raise NotImplementedError

    def load(self, task, state):
        """Restore a previously saved training state; must be provided by subclasses."""
        raise NotImplementedError

    def save_checkpoint(self, task, state, filename: str):
        """
        Serialise model, optimizer, optional scheduler, executor state and
        field vocabularies of ``task`` into ``<directory>/<filename>``.

        :param task: object exposing ``model``, ``optimizer``, ``scheduler`` and ``fields``
        :param state: executor state stored under the ``executor_state`` key
        :param filename: name of the file to write inside ``self.directory``
        :return: path of the written checkpoint file
        """
        filepath = os.path.join(self.directory, filename)
        ckpt_dict = {
            'model_state_dict': task.model.state_dict(),
            'optimizer_state_dict': task.optimizer.state_dict(),
            'executor_state': state,
            'check_model_class': str(task.model.__class__),
            'check_optimizer_class': str(task.optimizer.__class__),
            # 'check_ldtp_version': __version__  # TODO
        }
        if task.scheduler:
            ckpt_dict['scheduler_state_dict'] = task.scheduler.state_dict()
            ckpt_dict['check_scheduler_class'] = str(task.scheduler.__class__)

        for name, field in task.fields:
            # Only fields that actually use a vocabulary have something to store.
            if not getattr(field, 'use_vocab', False):
                continue
            if hasattr(field, 'vocab') and hasattr(field.vocab, '__getstate__'):
                ckpt_dict[name] = field.vocab.__getstate__()

        torch.save(ckpt_dict, filepath)
        logger.debug(f"Have saved checkpoint to {filepath}!!!")
        return filepath

    def load_checkpoint(self, task, state, filename):
        """
        Restore ``task`` (and ``state``, when truthy) from ``<directory>/<filename>``.

        A missing file is logged as a warning and nothing is restored.
        Scheduler state and field vocabularies that are absent from the
        checkpoint are skipped with a warning instead of raising ``KeyError``.

        :param task: object exposing ``model``, ``optimizer``, ``scheduler``, ``fields`` and ``device``
        :param state: executor state object providing ``load_state_dict``; may be falsy
        :param filename: name of the file to read inside ``self.directory``
        """
        filepath = os.path.join(self.directory, filename)
        if not os.path.isfile(filepath):
            logger.warning(f"Restored From checkpoint {filepath} failed!!!")
            return

        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # that come from a trusted source.
        ckpt_dict = torch.load(filepath, map_location=task.device)

        task.model.load_state_dict(ckpt_dict['model_state_dict'])
        task.optimizer.load_state_dict(ckpt_dict['optimizer_state_dict'])
        if state:
            state.load_state_dict(ckpt_dict['executor_state'])

        if task.scheduler:
            # save_checkpoint only writes this key when a scheduler existed at
            # save time, so it can legitimately be missing here.
            if 'scheduler_state_dict' in ckpt_dict:
                task.scheduler.load_state_dict(ckpt_dict['scheduler_state_dict'])
            else:
                logger.warning(f"Checkpoint {filepath} has no scheduler state, scheduler not restored")

        for name, field in task.fields:
            if not getattr(field, 'use_vocab', False):
                continue
            field.build_vocab()
            if hasattr(field, 'vocab') and hasattr(field.vocab, '__setstate__'):
                if name in ckpt_dict:
                    field.vocab.__setstate__(ckpt_dict[name])
                else:
                    logger.warning(f"Checkpoint {filepath} has no vocab for field {name}")

        logger.info(f"Have restored from checkpoint({filepath})")
class CheckpointManager(object):
    """Owns a set of named checkpoint plugins and dispatches save/load to them."""

    plugins: Dict[str, CheckpointPlugin]

    def __init__(self, directory=None, filename=None, plugins: List = None, **kwargs):
        """
        Checkpoint manager.

        :param directory: path the checkpoints are saved to
        :param filename: file name to save under; plugins usually turn it into other forms
        :param plugins: List [best, restore]
        :param best: dict of extra arguments passed to the ``best`` plugin;
            ``metric`` is required when the ``best`` plugin is used
        """
        if directory is None:
            directory = default_save_diretory
        self.directory = directory
        self.filename = filename
        self.plugins = OrderedDict()
        # Always defined, so load() is a safe no-op when no plugin was registered
        # (e.g. ``plugins=[]``); the first registered plugin becomes the mode.
        self.mode = None

        if plugins is not None:
            for plugin_name in plugins:
                self._register(plugin_name, kwargs.get(plugin_name))
        elif not kwargs:
            self.mode = 'restore'
            self.plugins['restore'] = RestoreCheckpointPlugin(directory=directory, filename=filename)
        else:
            for plugin_name, plugin_config in kwargs.items():
                self._register(plugin_name, plugin_config)

    def _register(self, plugin_name, plugin_config):
        """Instantiate the plugin registered under ``plugin_name`` and store it."""
        # Work on a copy so the caller's configuration dict is not mutated.
        config = dict(plugin_config) if plugin_config else {}
        config.setdefault('filename', self.filename)
        config.setdefault('directory', self.directory)
        if self.mode is None:
            self.mode = plugin_name
        plugin_class = CheckpointPlugin.by_name(plugin_name)
        self.plugins[plugin_name] = plugin_class(**config)

    def __getitem__(self, item):
        """
        Select what the next :meth:`load` restores from and return ``self``.

        ``item`` is either the name of a registered plugin or, otherwise, a
        checkpoint file name inside ``self.directory``.
        """
        if item in self.plugins:
            self.mode = item
        else:
            self.mode = LoadCheckpointPlugin(directory=self.directory, filename=item)
        return self

    def load(self, task, state):
        """Load the model through the currently selected plugin or file loader."""
        if isinstance(self.mode, str):
            self.plugins[self.mode].load(task, state)
        elif isinstance(self.mode, LoadCheckpointPlugin):
            self.mode.load(task, state)

    def save(self, task, state):
        """Ask every registered plugin to save, in registration order."""
        for plugin in self.plugins.values():
            plugin.save(task, state)
class LoadCheckpointPlugin(CheckpointPlugin, alias='_load'):
    """Plugin that only restores from the file named ``self.filename``."""

    def load(self, task, state):
        """Restore ``task`` and ``state`` from ``self.filename``."""
        target = self.filename
        self.load_checkpoint(task, state, target)

    def save(self, task, state):
        """Saving is intentionally a no-op for this plugin."""
        return None
class BestCheckpointPlugin(CheckpointPlugin, alias='best'):
    """
    Keeps a single checkpoint: the one with the highest metric seen so far.

    Files are named ``<base>_<metric formatted as %.4f>.<ext>``, derived from
    ``filename``.

    :param metric: name of the metric read from ``state`` via ``state.get``;
        a list of names is averaged
    :param directory: directory the checkpoint files are stored in
    :param filename: template file name, ``default_best_filename`` when omitted
    """

    def __init__(self, metric, directory: str = None, filename: str = None):
        if filename is None:
            filename = default_best_filename
        super(BestCheckpointPlugin, self).__init__(directory=directory, filename=filename)
        self.metric_name = metric
        self.best_metric = None

        # Everything before the first dot is the base, the rest are extensions.
        base, *ext = self.filename.split('.')

        # '%' is doubled so a literal percent sign survives the %-formatting in save().
        fmt_base = base.replace('%', '%%')
        fmt_ext = [part.replace('%', '%%') for part in ext]
        self.format = f"{fmt_base}_%.4f." + '.'.join(fmt_ext)  # base_{metric}.ext

        # The name parts are escaped so they match literally; the optional '-'
        # lets negative metrics written by ``self.format`` be found again.
        re_base = regex.escape(base)
        re_ext = [regex.escape(part) for part in ext]
        self.regex = regex.compile(f"{re_base}_(-?\\d+\\.\\d+)\\." + '\\.'.join(re_ext))
        self.last_filepath = None

    def load(self, task, state):
        """Restore from the checkpoint in ``self.directory`` with the highest metric in its name."""
        if not os.path.isdir(self.directory):
            return

        best_metric = None
        best_checkpoint = None
        for filename in os.listdir(self.directory):
            # listdir yields bare names, so resolve against the directory
            # before asking whether the entry is itself a directory.
            if os.path.isdir(os.path.join(self.directory, filename)):
                continue
            match = self.regex.fullmatch(filename)
            if match is None:
                continue
            found_metric = float(match.group(1))
            if best_metric is None or found_metric > best_metric:
                best_metric = found_metric
                best_checkpoint = filename

        if best_checkpoint is not None:
            self.load_checkpoint(task, state, best_checkpoint)

    def _current_metric(self, state):
        """Return the metric value from ``state``, or ``None`` when it is unavailable."""
        if isinstance(self.metric_name, list):
            values = [state.get(metric_name) for metric_name in self.metric_name]
            if not values or any(value is None for value in values):
                return None
            return sum(values) / len(values)
        return state.get(self.metric_name)

    def save(self, task, state):
        """
        Save a checkpoint when the metric improved on the best seen so far.

        Best-effort: a missing metric or a failed write is logged and does not
        raise. The previous best file is removed only after the new one has
        been written successfully.
        """
        try:
            metric = self._current_metric(state)
            improved = metric is not None and (self.best_metric is None or metric > self.best_metric)
        except (AttributeError, KeyError, TypeError):
            metric = None
            improved = False

        if metric is None:
            logger.warning(f"No metric {self.metric_name}")
            return
        if not improved:
            return

        try:
            filename = self.format % metric
            new_filepath = self.save_checkpoint(task, state, filename)
        except Exception:
            # Keep training alive, but report the real failure with its traceback.
            logger.exception(f"Saving best checkpoint for metric {self.metric_name} failed")
            return

        previous_filepath = self.last_filepath
        self.best_metric = metric
        self.last_filepath = new_filepath

        # Two metrics can round to the same file name; never delete the file just written.
        if previous_filepath is not None and previous_filepath != new_filepath:
            try:
                if os.path.exists(previous_filepath):
                    os.remove(previous_filepath)
            except OSError:
                logger.warning(f"Could not remove previous best checkpoint {previous_filepath}")
class RestoreCheckpointPlugin(CheckpointPlugin, alias="restore"):
    """
    Saves one checkpoint per call, numbered by ``state.global_step``, and
    restores from the one with the highest number.

    Files are named ``<base>_<step>.<ext>``, derived from ``self.filename``.
    """

    def __init__(self, *args, **kwargs):
        super(RestoreCheckpointPlugin, self).__init__(*args, **kwargs)

        # Everything before the first dot is the base, the rest are extensions.
        base, *ext = self.filename.split('.')

        # '%' is doubled so a literal percent sign survives the %-formatting in save().
        fmt_base = base.replace('%', '%%')
        fmt_ext = [part.replace('%', '%%') for part in ext]
        self.format = f"{fmt_base}_%d." + '.'.join(fmt_ext)

        # The name parts are escaped so they are matched literally.
        re_base = regex.escape(base)
        re_ext = [regex.escape(part) for part in ext]
        self.regex = regex.compile(f"{re_base}_(\\d+)\\." + '\\.'.join(re_ext))

    def load(self, task, state):
        """Restore from the checkpoint in ``self.directory`` with the highest number in its name."""
        if not os.path.isdir(self.directory):
            return

        last_iter = None
        last_checkpoint = None
        for filename in os.listdir(self.directory):
            # listdir yields bare names, so resolve against the directory
            # before asking whether the entry is itself a directory.
            if os.path.isdir(os.path.join(self.directory, filename)):
                continue
            match = self.regex.fullmatch(filename)
            if match is None:
                continue
            found_epoch = int(match.group(1))
            if last_iter is None or found_epoch > last_iter:
                last_iter = found_epoch
                last_checkpoint = filename

        if last_checkpoint is not None:
            self.load_checkpoint(task, state, last_checkpoint)

    def save(self, task, state):
        """Write a checkpoint whose name carries ``state.global_step``."""
        filename = self.format % state.global_step
        self.save_checkpoint(task, state, filename)