# ltp.data.fields.graph — source code (extracted from Sphinx-rendered docs)

#! /usr/bin/env python
# -*- coding: utf-8 -*-
# Author: Yunlong Feng <ylfeng@ir.hit.edu.cn>

from collections import Counter, OrderedDict
from itertools import chain
from typing import List

import torch

from torchtext.vocab import Vocab
from ltp.const import PAD
from . import Field
from ltp.data.dataset import Dataset


class GraphField(Field, alias='graph'):
    """Field for graph-structured (dependency-arc style) annotations.

    Each raw example is a list of per-node strings; string *i* encodes the
    outgoing edges of node ``i + 1`` (node 0 is reserved as a virtual root),
    with individual edges separated by ``edge_spliter`` and, optionally, a
    tag attached to each edge via ``tag_spliter`` (e.g. ``"3:nsubj"``).
    """

    vocab_cls = Vocab

    def __init__(self, name, edge_spliter, pad: str = PAD, tag_spliter=None, use_vocab=True, **kwargs):
        """Graph Field.

        Args:
            name (str): field name
            edge_spliter (str): separator between the edges of one node
            pad (str): token used to fill positions that carry no tag
            tag_spliter (str): separator between an edge's target and its tag
            use_vocab (bool): whether to build/use a tag vocabulary
            **kwargs: forwarded to the ``Field`` base class
        """
        super().__init__(name, **kwargs)
        self.edge_spliter = edge_spliter
        self.tag_spliter = tag_spliter
        # Tags only exist when a tag_spliter was supplied, so a vocabulary
        # is only meaningful in that configuration.
        self.use_vocab = (self.tag_spliter is not None) and use_vocab
        self.pad = pad
        if self.pad:
            self.pad_bias = True
        else:
            self.pad_bias = False

    def preprocess(self, arcs: List[str]):
        """Parse raw per-node edge strings into a flat list of edge tuples.

        Args:
            arcs: one string per node, each holding that node's edges.

        Returns:
            tuple: ``(node_size + 1, edges)`` where ``edges`` is a flat list
            of ``(src, target, label)`` tuples when tags are in use, or
            ``(src, target)`` tuples otherwise. Source indices are 1-based
            (0 is the virtual root).
        """
        node_size = len(arcs)
        res = [arcs_info.split(self.edge_spliter) for arcs_info in arcs]
        if self.use_vocab:
            res = [
                [(index + 1, *arc.split(self.tag_spliter)) for arc in arcs]
                for index, arcs in enumerate(res)
            ]
            res = [(int(src), int(target), label) for src, target, label in chain(*res)]
        else:
            res = [
                [(index + 1, int(arc)) for arc in arcs]
                for index, arcs in enumerate(res)
            ]
            # BUGFIX: flatten exactly like the vocab branch above, so that
            # ``process`` can iterate a flat list of edge tuples in both
            # configurations (the original returned a nested list here,
            # which made ``process`` crash when use_vocab was False).
            res = list(chain(*res))
        return node_size + 1, res

    def build_vocab(self, *args, **kwargs):
        """Build the edge-tag vocabulary from datasets and/or raw iterables.

        Positional args may be ``Dataset`` instances (all of their columns
        bound to this field are used) or plain iterables of preprocessed
        examples ``(node_size, edges)``.
        """
        counter = Counter()
        sources = []
        for arg in args:
            if isinstance(arg, Dataset):
                sources += [getattr(arg, name) for name, field in arg.fields.items() if field is self]
            else:
                sources.append(arg)
        for data in sources:
            for x in data:
                # x is (node_size, edges); the tag is the last element of
                # each edge tuple.
                tags = [arc[-1] for arc in x[1]]
                try:
                    counter.update(tags)
                except TypeError:
                    # BUGFIX: removed leftover debug ``print("error")``.
                    # Fall back for nested tag sequences.
                    counter.update(chain.from_iterable(tags))
        specials = list(
            OrderedDict.fromkeys(
                tok for tok in [self.pad] + kwargs.pop('specials', [])
                if tok is not None
            )
        )
        self.vocab = self.vocab_cls(counter, specials=specials, **kwargs)

    def process(self, batch, device=None, *args, **kwargs):
        """Batch preprocessed graphs into dense adjacency / tag tensors.

        Args:
            batch: iterable of ``(node_size, edges)`` pairs from ``preprocess``.
            device: target device for the resulting tensors.

        Returns:
            tuple: ``(edge_tensor, tag_tensor)``, each of shape
            ``(batch, max_nodes, max_nodes)``; ``edge_tensor[b, i, j] == 1``
            iff edge ``i -> j`` exists, and ``tag_tensor`` holds the tag's
            vocabulary index shifted by -1 (cells without a tag become -1,
            the pad special at index 0 becomes 0, etc.).
        """
        length, batched_edges = zip(*batch)
        max_edge_size = max(length)
        edge_tensor = torch.zeros(size=(len(length), max_edge_size, max_edge_size))
        tag_tensor = torch.zeros(size=(len(length), max_edge_size, max_edge_size), dtype=torch.long)
        for index, edges in enumerate(batched_edges):
            for edge in edges:
                edge_tensor[index, edge[0], edge[1]] = 1
                # BUGFIX: look up a tag only when a vocabulary exists —
                # untagged edges have no edge[2] and no self.vocab.
                if self.use_vocab:
                    tag_tensor[index, edge[0], edge[1]] = self.vocab[edge[2]]
        # The -1 shift maps "no tag" cells (still 0) to -1.
        return edge_tensor.to(device), tag_tensor.to(device) - 1