ltp.train.stop_condition.early_stopping 源代码

#! /usr/bin/env python
# -*- coding: utf-8 -*-
# Author: Yunlong Feng <ylfeng@ir.hit.edu.cn>
from typing import Union, List

# Fallback number of decimal places that EarlyStopping rounds scores to before
# comparing them, used when no explicit ``float_rounding`` is supplied.
default_float_rounding = 4


class EarlyStopping(object):
    """Early-stopping condition driven by a monitored metric.

    The instance is called once per evaluation with a ``state`` object and
    reports whether the run should stop. A score counts as an improvement
    when, after rounding to ``float_rounding`` decimal places, it is greater
    than or equal to the best score seen so far. Once ``patience``
    consecutive calls bring no improvement, the call returns ``True``.

    NOTE(review): the default metric name is ``'loss'``, yet ``better`` treats
    larger values as improvements -- confirm this is intended by callers.

    Attributes:
        metric: Name of the monitored metric, or a list of names whose
            values are summed into a single score.
        patience: Number of consecutive non-improving calls tolerated.
        float_rounding: Decimal places scores are rounded to for comparison.
        counter: Consecutive non-improving calls observed so far.
        best_score: Best score recorded so far, ``None`` before the first call.
    """

    metric: Union[str, List[str]]
    patience: int
    float_rounding: int

    def __init__(self, patience: int = None, metric: Union[str, List[str]] = None,
                 float_rounding: int = None):
        """Configure the stopping condition.

        Args:
            patience: Non-improving calls tolerated before stopping;
                defaults to 5 when ``None``.
            metric: Metric name or list of metric names to monitor;
                defaults to ``'loss'`` when ``None``.
            float_rounding: Decimal places used when comparing scores;
                defaults to the module-level ``default_float_rounding``.

        Raises:
            ValueError: If ``patience`` is smaller than 1.
        """
        if patience is None:
            patience = 5
        if patience < 1:
            raise ValueError("Argument patience should be positive integer.")

        self.counter = 0
        self.best_score = None
        self.metric = 'loss' if metric is None else metric
        self.patience = patience
        self.float_rounding = default_float_rounding if float_rounding is None else float_rounding

    def better(self, metric):
        """Return whether ``metric`` is at least as good as the best score.

        Both values are rounded to ``float_rounding`` decimal places first, so
        differences below that precision count as ties, and ties count as
        improvements. When no best score has been recorded yet, any value is
        considered better.
        """
        if self.best_score is None:
            # Nothing to compare against; rounding None would raise TypeError.
            return True
        return round(metric, self.float_rounding) >= round(self.best_score, self.float_rounding)

    def __call__(self, state):
        """Update the tracker from ``state`` and report whether to stop.

        Args:
            state: Object exposing ``get(name)`` for metric lookup
                (presumably a dict of evaluation results -- verify against
                the caller).

        Returns:
            ``True`` once ``patience`` consecutive calls have failed to
            improve on the best score, otherwise ``False``.
        """
        if isinstance(self.metric, list):
            # Several metrics are collapsed into one score by summation.
            score = sum(state.get(metric_name) for metric_name in self.metric)
        else:
            score = state.get(self.metric)

        if self.best_score is None:
            # First observation only seeds the baseline. This branch is kept
            # separate so that subclasses overriding ``better`` are never
            # asked to compare against a missing baseline.
            self.best_score = score
            return False

        if self.better(score):
            self.best_score = score
            self.counter = 0
            return False

        self.counter += 1
        return self.counter >= self.patience