ltp.predict.biaffine 源代码

#! /usr/bin/env python
# -*- coding: utf-8 -*_
# Author: Yunlong Feng <ylfeng@ir.hit.edu.cn>
import torch
from . import Predictor
from ltp.utils import length_to_mask


[文档]class Biaffine(Predictor, alias='biaffine'): def __init__(self, fields, text_field='text', tag_field='tag'): super().__init__(fields) self.text_field = text_field self.tag_vocab = self.fields[tag_field].vocab.itos def _convert_idx_to_name(self, y, array_len): return [[self.tag_vocab[idx] for idx in row[:row_len]] for row, row_len in zip(y, array_len)]
[文档] def predict(self, inputs, pred): srl_output, srl_length = pred mask = length_to_mask(srl_length) mask = mask.unsqueeze_(-1).expand(-1, -1, mask.size(1)) mask = (mask & mask.transpose(-1, -2)).flatten(end_dim=1) index = mask[:, 0] mask = mask[index] srl_output = srl_output.flatten(end_dim=1)[index] srl_labels = torch.argmax(srl_output, dim=-1).cpu().numpy() srl_labels = self._convert_idx_to_name(srl_labels, mask.sum(dim=1)) # srl_labels_res = [] # for length in srl_length: # srl_labels_res.append([]) # curr_srl_labels, srl_labels = srl_labels[:length], srl_labels[length:] # srl_labels_res[-1].extend([get_entities(labels) for labels in curr_srl_labels]) return srl_labels