ltp.models.seq_tag_model 源代码

#! /usr/bin/env python
# -*- coding: utf-8 -*_
# Author: Yunlong Feng <ylfeng@ir.hit.edu.cn>
from typing import Tuple, Union, Dict, Optional

import torch
from transformers import AutoConfig, AutoModel, PretrainedConfig

from .model import Model
from ltp.modules import Module


[文档]class SequenceTaggingModel(Model, alias="_sequence_tagging"): """ 基本序列标注模型,封装了各种解码器,将来可能要进行解耦,如果使用crf解码器,需要换用支持 CRF 的 Trainer 使用 init_decoder 函数初始化解码器 :param decoder: [None=Linear,lan=lan decoder] 默认是简单线性层分类,目前支持 lan decoder :param hidden_size: decoder = lan 时必填, lan hidden size :param num_heads: decoder = lan 时使用, lan 多头注意力模型 heads,默认为 5 :param num_layers: decoder = lan 时使用, lan decoder 层数,默认为 3 :param lan: decoder = lan 时使用,lan decoder 其他参数 :param arc_hidden_size: decoder = biaffine 时必填 :param rel_hidden_size: decoder = biaffine 时必填 :param rel_num: decoder = biaffine 时默认为label_num,rel 数目 :param bias: decoder = Linear 时,传入Linear层 """ decoder: Module
[文档] def init_decoder(self, input_size, label_num, dropout=0.1, **kwargs): """ 基本序列标注模型,封装了各种解码器 """ self.decoder = self.create_decoder(input_size, label_num, dropout, **kwargs)
def hidden2logits(self, logits, seq_len, gold: Optional = None): return self.decoder.forward(logits, seq_len, gold) def forward_(self, *inputs, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]: raise NotImplementedError("未实现")
[文档] def forward(self, text: Dict[str, torch.Tensor], gold: Optional = None): features, seq_lens = self.forward_(text) return self.hidden2logits(features, seq_lens, gold)
[文档]class SimpleTaggingModel(SequenceTaggingModel, alias="SimpleTagging"): """ 基本序列标注模型 :param pretrained: 预训练模型路径或名称,参照 huggingface/transformers :param config: 预训练模型路径或名称,参照 huggingface/transformers :param label_num: 分类标签数目 :param dropout: pretrained Embedding dropout,默认0.1 :param word_base: 是否是以词为基础的模型,如果是,输入时需要传入 word index :param decoder: [None=Linear,lan=lan decoder] 默认是简单线性层分类,目前支持 lan decoder :param hidden_size: decoder = lan 时必填, lan hidden size :param num_heads: decoder = lan 时使用, lan 多头注意力模型 heads,默认为 5 :param num_layers: decoder = lan 时使用, lan decoder 层数,,默认为 3 :param lan: decoder = lan 时使用,lan decoder 其他参数 :param arc_hidden_size: decoder = graph 时必填 :param rel_hidden_size: decoder = graph 时必填 :param rel_num: decoder = graph 时默认为label_num,rel 数目 :param bias: decoder = Linear 时,传入Linear层 """ def __init__(self, label_num, pretrained: str = None, config: Union[str, PretrainedConfig] = None, dropout=0.1, word_base=False, use_cls=False, freeze=False, **kwargs): """ 基本序列标注模型 """ super(SimpleTaggingModel, self).__init__() self.word_base = word_base self.use_cls = use_cls self.gard_enable = not freeze self.pretrained = self.create_pretrained(pretrained, config=config, freeze=freeze) config = self.pretrained.config self.emb_dropout = torch.nn.Dropout(p=dropout) self.init_decoder( input_size=config.hidden_size, label_num=label_num, dropout=dropout, **kwargs ) def forward_(self, text: Dict[str, torch.Tensor]): with torch.set_grad_enabled(self.gard_enable): pretrained_output, *_ = self.pretrained( text['input_ids'], attention_mask=text['attention_mask'], token_type_ids=text['token_type_ids'] ) # remove [CLS] [SEP] pretrained_output = torch.narrow( pretrained_output, 1, 1 - self.use_cls, pretrained_output.size(1) - 2 + self.use_cls ) pretrained_output = self.emb_dropout(pretrained_output) if self.word_base: word_idx = text['word_index'] if self.use_cls: cls_tensor = torch.zeros((word_idx.shape[0], 1), dtype=word_idx.dtype, device=word_idx.device) word_idx = torch.cat([cls_tensor, word_idx + 1], dim=-1) word_idx = word_idx.unsqueeze(-1).expand(-1, -1, pretrained_output.shape[-1]) pretrained_output = torch.gather(pretrained_output, dim=1, index=word_idx) seq_lens = text['word_length'] else: seq_lens = text['text_length'] return pretrained_output, seq_lens