ltp.models.model 源代码

#! /usr/bin/env python
# -*- coding: utf-8 -*-
# Author: Yunlong Feng <ylfeng@ir.hit.edu.cn>
from typing import Union

from torch.nn import Module, Linear
from transformers import PretrainedConfig, PreTrainedModel, AutoConfig, AutoModel
from ltp.core import Registrable
from ltp import modules

"""
Terminology:
    text: raw-text-like input, i.e. text at the character (Char Based) level
    word: a word (usually a whole word); for example, word_index is the index at which a word starts
"""


class Model(Module, metaclass=Registrable):
    """Base class for models.

    Uses ``Registrable`` as its metaclass; per the original documentation,
    models that inherit from this class are registered automatically.
    """

    pretrained: PreTrainedModel

    def create_pretrained(self, pretrained: Union[str, None] = None,
                          config: Union[str, dict, PretrainedConfig, None] = None,
                          freeze: bool = False):
        """Build the pretrained backbone from a checkpoint or a configuration.

        Sources are tried in this order:

        1. ``pretrained`` is a string: passed to ``AutoModel.from_pretrained``
           (treated as a path or a Hugging Face hub model name).
        2. ``config`` is a ``PretrainedConfig``: passed to
           ``AutoModel.from_config``.
        3. ``config`` is a string: loaded with ``AutoConfig.from_pretrained``
           (treated as a path or a Hugging Face hub model name), then passed
           to ``AutoModel.from_config``.
        4. ``config`` is a dict: expanded as keyword arguments into
           ``AutoConfig.for_model``, then passed to ``AutoModel.from_config``.

        Args:
            pretrained: Path or hub name of a pretrained model.
            config: Configuration object, configuration path/hub name, or
                keyword arguments for ``AutoConfig.for_model``.
            freeze: When true, ``requires_grad`` is set to ``False`` on every
                parameter of the created model.

        Returns:
            The created model. It is returned, not stored on ``self``.

        Raises:
            NotImplementedError: If neither argument has a supported type.
        """
        if isinstance(pretrained, str):
            model = AutoModel.from_pretrained(pretrained)
        elif isinstance(config, PretrainedConfig):
            model = AutoModel.from_config(config)
        elif isinstance(config, str):
            # A string cannot be expanded with ``**``; load it as a config
            # path / hub name instead of routing it through ``for_model``.
            model = AutoModel.from_config(AutoConfig.from_pretrained(config))
        elif isinstance(config, dict):
            model = AutoModel.from_config(AutoConfig.for_model(**config))
        else:
            raise NotImplementedError(
                "create_pretrained needs `pretrained` as str or `config` as "
                "str, dict or PretrainedConfig"
            )

        if freeze:
            for param in model.parameters():
                param.requires_grad = False

        return model

    def create_decoder(self, input_size, label_num, dropout=0.1, **kwargs):
        """Create a decoder module selected by name.

        The decoder type is read from ``kwargs['decoder']`` (default
        ``'Linear'``). Arguments for that decoder are read from the keyword
        argument named after the decoder type, e.g. ``Linear={...}``.

        Args:
            input_size: First positional argument given to the decoder class.
            label_num: Second positional argument given to the decoder class.
            dropout: Used as the decoder's ``dropout`` argument unless the
                decoder arguments already contain one.
            **kwargs: ``decoder`` plus an optional per-decoder argument dict.

        Returns:
            An instance of the class returned by
            ``modules.Module.by_name(decoder_type)``.

        Raises:
            NotImplementedError: If the decoder type is falsy (e.g. ``None``
                or an empty string).
        """
        decoder_type = kwargs.pop('decoder', 'Linear')
        if not decoder_type:
            raise NotImplementedError("no decoder type was specified")

        # Copy so that filling in the default dropout does not mutate a
        # configuration dict owned by the caller.
        decoder_args = dict(kwargs.pop(decoder_type, None) or {})
        decoder_args.setdefault('dropout', dropout)

        decoder_class = modules.Module.by_name(decoder_type)
        return decoder_class(input_size, label_num, **decoder_args)

    def __getitem__(self, item):
        """Store ``item`` as ``self.task`` and return ``self``.

        Returning ``self`` allows call chaining such as ``model[task](...)``.
        """
        self.task = item
        return self

    def forward(self, *args, **kwargs):
        """Forward pass; this base implementation always raises.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError()