# Source code for module ``ltp.data.fields.biaffine``

#! /usr/bin/env python
# -*- coding: utf-8 -*-
# Author: Yunlong Feng <ylfeng@ir.hit.edu.cn>

from collections import Counter, OrderedDict
from itertools import chain
from typing import List

import torch

from torchtext.vocab import Vocab
from ltp.data.dataset import Dataset

from . import Field


class BiaffineField(Field, alias='biaffine'):
    """Biaffine field: turns per-predicate label columns into a dense
    (predicate, word) label grid for biaffine classifiers.

    Args:
        name: label name.
        use_vocab: whether a vocabulary is used for the labels.
        pad: label used to fill positions that carry no real label.
        labels: optional predefined label inventory; when given, the
            vocabulary is built immediately instead of being collected
            from dataset statistics via ``build_vocab``.
    """
    vocab_cls = Vocab

    def __init__(self, name, use_vocab=True, pad='O', labels=None, **kwargs):
        super().__init__(name, **kwargs)
        self.pad = pad
        self.use_vocab = use_vocab
        if labels:
            # A label inventory was supplied up front: build the vocab now.
            self.vocab = self.vocab_cls(
                Counter(labels), specials=self._pop_specials(kwargs)
            )

    def _pop_specials(self, kwargs):
        """Return the special tokens (pad first), deduped, order preserved.

        Pops ``specials`` out of *kwargs* so it is not forwarded twice.
        ``None`` entries are dropped (e.g. when ``pad`` is disabled).
        """
        return list(
            OrderedDict.fromkeys(
                tok
                for tok in [self.pad] + kwargs.pop('specials', [])
                if tok is not None
            )
        )

    def preprocess(self, inputs: List[List[str]]):
        """Split raw columns into sentence length, predicates and labels.

        ``inputs`` holds the head-marker column first (``'Y'`` marks a
        predicate position), followed by one label column per predicate.

        Returns:
            Tuple of (sentence length, predicate word indices, list of
            ``(predicate, word_index, label)`` triples).
        """
        head_marks, *label_columns = inputs
        sentence_length = len(head_marks)
        predicates = [i for i, mark in enumerate(head_marks) if mark == 'Y']
        target_labels = []
        # Pair the i-th predicate with its own label column.
        for predicate, column in zip(predicates, label_columns):
            target_labels.extend(
                (predicate, index, tag) for index, tag in enumerate(column)
            )
        return sentence_length, predicates, target_labels

    def build_vocab(self, *args, **kwargs):
        """Build the label vocabulary from datasets and/or raw iterables.

        ``Dataset`` arguments contribute every column bound to this field;
        any other argument is treated as an iterable of preprocessed
        examples.
        """
        counter = Counter()
        sources = []
        for arg in args:
            if isinstance(arg, Dataset):
                sources += [
                    getattr(arg, name)
                    for name, field in arg.fields.items()
                    if field is self
                ]
            else:
                sources.append(arg)
        for data in sources:
            for example in data:
                # preprocess() output: the last element is the list of
                # (predicate, word_index, label) triples.
                tags = [triple[-1] for triple in example[-1]]
                try:
                    counter.update(tags)
                except TypeError:
                    # Tags are nested (unhashable) sequences: flatten one
                    # level and retry.  (Replaces a leftover debug print.)
                    counter.update(chain.from_iterable(tags))
        self.vocab = self.vocab_cls(
            counter, specials=self._pop_specials(kwargs), **kwargs
        )

    def process(self, batch, device=None, *args, **kwargs):
        """Convert a preprocessed batch into a dense label tensor.

        Returns:
            ``(label_tensor, label_set)`` where ``label_tensor`` is a
            ``(batch, max_len, max_len)`` long tensor of label indices and
            ``label_set`` collects ``(sent_idx, predicate, word_idx,
            label_idx)`` tuples for every non-pad cell.
        """
        sentence_lengths, heads, labels = zip(*batch)
        label_set = set()
        batch_size = len(batch)
        max_sentences_length = max(sentence_lengths)
        label_tensor = torch.zeros(
            [batch_size] + [max_sentences_length] * 2, dtype=torch.long
        )
        # (sent_idx, predicate, word_index) => span_label
        for sent_idx, sent_target_labels in enumerate(labels):
            for predicate, index, label in sent_target_labels:
                label_index = self.vocab.stoi[label]
                if label_index == 0:
                    # Index 0 is assumed to be the pad label (first special);
                    # leave the cell at its zero initialization.
                    continue
                label_tensor[sent_idx, predicate, index] = label_index
                label_set.add((sent_idx, predicate, index, label_index))
        return label_tensor.to(device), label_set