# ltp.train.optim source code

#! /usr/bin/env python
# -*- coding: utf-8 -*-
# Author: Yunlong Feng <ylfeng@ir.hit.edu.cn>
import inspect
from torch import optim
from ltp.core import Registrable


class Optimizer(optim.Optimizer, metaclass=Registrable):
    """Registrable base class for optimizers.

    Concrete optimizer classes are attached to it with
    ``Optimizer.weak_register``.
    """

    @classmethod
    def from_extra(cls, extra: dict, subcls=None):
        """Build the parameter keyword argument for an optimizer constructor.

        Args:
            extra: Mapping that holds the model under the ``'model'`` key. The
                model must expose ``parameters()`` / ``named_parameters()``
                (presumably a ``torch.nn.Module`` -- confirm with callers).
            subcls: Optimizer class about to be instantiated. Its constructor
                signature decides which keyword is produced. ``None`` means
                "assume it takes ``params``".

        Returns:
            dict: ``{'params': [param, ...]}`` when the constructor accepts
            ``params`` (or ``subcls`` is None); otherwise
            ``{'named_params': [(name, param), ...]}`` when it accepts
            ``named_params``; otherwise ``{}``. Only parameters whose
            ``requires_grad`` is set are included.
        """
        if subcls is None:
            accepted = ('params',)
        else:
            accepted = inspect.signature(subcls).parameters

        if 'params' in accepted:
            # Materialize as a list: a lazy ``filter`` iterator can be consumed
            # only once, so any caller that inspects or reuses the result would
            # silently hand the optimizer an empty parameter list.
            model = extra['model']
            return {'params': [p for p in model.parameters() if p.requires_grad]}
        if 'named_params' in accepted:
            model = extra['model']
            return {
                'named_params': [
                    (n, p) for n, p in model.named_parameters() if p.requires_grad
                ]
            }
        # The constructor takes neither keyword; contribute nothing.
        return {}
def _register_from(module, names, skip_missing=False):
    """Register ``module.<name>`` under ``name`` for every name in *names*.

    Registration happens in the order given. With ``skip_missing`` set, names
    the module does not define are skipped instead of raising
    ``AttributeError``.
    """
    for name in names:
        if skip_missing and not hasattr(module, name):
            continue
        Optimizer.weak_register(name, getattr(module, name))


# Optimizers shipped with PyTorch itself.
_register_from(optim, (
    "Adadelta", "Adagrad", "Adam", "AdamW", "SparseAdam", "Adamax",
    "ASGD", "SGD", "Rprop", "RMSprop", "LBFGS",
))

# Optional extra optimizers from the third-party ``torch_optimizer`` package.
try:
    import torch_optimizer
except ImportError:
    # Package not installed: only the PyTorch built-ins are available.
    pass
else:
    # skip_missing: a given torch_optimizer release may not ship every class;
    # register the ones it has instead of silently stopping at the first gap.
    _register_from(torch_optimizer, (
        "AccSGD", "AdaBound", "AdaMod", "DiffGrad", "Lamb", "NovoGrad",
        "PID", "QHM", "RAdam", "SGDW", "Yogi",
    ), skip_missing=True)

# Kept after the registrations above -- presumably because these submodules
# import ``Optimizer`` from this package (circular import); confirm before moving.
from .pretrained_optim import PretrainedOptim, BertAdamW
from .task_optim import BertAdamW4CRF

__all__ = ['Optimizer', 'PretrainedOptim', 'BertAdamW', 'BertAdamW4CRF']