
#! /usr/bin/env python
# -*- coding: utf-8 -*-
# Author: Yunlong Feng <>

from collections import Counter, OrderedDict
from itertools import chain
from typing import Union

import torch

from torchtext.vocab import Vocab
from ltp.const import PAD
from . import Field
from import Dataset

def dtype_to_attr(dtype):
    """Return the bare attribute name of a ``torch.dtype`` for serialization.

    ``str(torch.int32)`` is ``"torch.int32"``; only the part after the dot
    (``"int32"``) is kept, since ``getattr(torch, name)`` maps it back to the
    dtype object.

    :param dtype: a ``torch.dtype`` (or anything whose ``str()`` contains
        exactly one ``'.'``)
    :return: the text after the dot
    :raises ValueError: if ``str(dtype)`` does not split into exactly two parts
    """
    qualified_name = str(dtype)
    _module, attr_name = qualified_name.split('.')
    return attr_name

class SequenceField(Field, alias='sequence'):
    """Sequence field, usually used as a target.

    Each example of a minibatch is optionally wrapped with ``bos``/``eos``,
    padded to the longest example (bounded by ``max_length``) and converted to a
    tensor, either through a vocabulary (``use_vocab=True``) or by casting the
    raw values (``use_vocab=False``).

    :param name: field name
    :param bos: Begin Of Sentence token. Kept only if it is a ``str`` with
        ``use_vocab=True`` or an ``int`` with ``use_vocab=False``; otherwise
        None. Default None.
    :param eos: End Of Sentence token, same rules as ``bos``. Default None.
    :param unk: unknown token, same rules as ``bos``. Default None.
    :param pad: padding value. Falls back to ``PAD`` (``use_vocab=True``) or
        ``-1`` (``use_vocab=False``) when not of the matching type.
    :param dtype: ``torch.dtype`` of the output tensor; a string naming a
        ``torch`` attribute (e.g. ``"long"``) is also accepted.
    :param pad_bias: if True (default), 1 is subtracted from every value of the
        output tensor in ``process``.
    :param preprocessing: callable applied to a raw example in ``preprocess``
    :param postprocessing: callable applied to the numericalized batch, called
        as ``postprocessing(arr, vocab)`` (``vocab`` is None without a vocab)
    :param max_length: upper bound on the padded length, bos/eos included;
        longer examples are truncated. None (default) means no bound.
    :param include_lengths: if True, the per-example lengths are returned along
        with the tensor. Default False.
    :param labels: optional iterable of labels; when non-empty the vocabulary
        is built from it immediately.
    :param use_vocab: whether values are mapped through a vocabulary.
        Default True.
    :param is_target: whether this field is a target. Default True.
    :param kwargs: may contain ``specials`` (extra special tokens), used only
        when ``labels`` is given.
    """
    vocab_cls = Vocab
    # Python casts applied to raw values when use_vocab=False. Several keys are
    # aliases of the same dtype object (torch.float is torch.float32, ...).
    dtypes = {
        torch.float32: float,
        torch.float: float,
        torch.float64: float,
        torch.double: float,
        torch.float16: float,
        torch.half: float,
        torch.uint8: int,
        torch.int8: int,
        torch.int16: int,
        torch.short: int,
        torch.int32: int,
        torch.int: int,
        torch.int64: int,
        torch.long: int,
    }
    # NOTE(review): consumed by the Field base class, which is not visible here.
    ignore = ['dtype']

    def __init__(self, name, bos: Union[str, int] = None, eos: Union[str, int] = None,
                 unk: Union[str, int] = None, pad: Union[str, int] = None, dtype=torch.long,
                 pad_bias=True, preprocessing=None, postprocessing=None,
                 max_length: int = None, include_lengths=False, labels=None,
                 use_vocab=True, is_target: bool = True, **kwargs):
        super(SequenceField, self).__init__(name, preprocessing, postprocessing, is_target)
        # Special tokens must be strings when a vocab is used and ints otherwise;
        # values of the wrong type are dropped (None).
        token_type = str if use_vocab else int
        self.unk = unk if isinstance(unk, token_type) else None
        self.bos = bos if isinstance(bos, token_type) else None
        self.eos = eos if isinstance(eos, token_type) else None
        if use_vocab:
            self.pad = pad if isinstance(pad, str) else PAD
        else:
            self.pad = pad if isinstance(pad, int) else -1

        self.dtype = getattr(torch, dtype) if isinstance(dtype, str) else dtype
        self.use_vocab = use_vocab
        self.max_length = max_length
        self.include_lengths = include_lengths
        self.pad_bias = pad_bias

        if labels:
            counter = Counter()
            counter.update(labels)
            specials = self._build_specials(kwargs.pop('specials', []))
            self.vocab = self.vocab_cls(counter, specials=specials)

    def _build_specials(self, extra):
        """Return unk/pad/bos/eos plus ``extra`` in order, de-duplicated, without None."""
        candidates = [self.unk, self.pad, self.bos, self.eos] + list(extra)
        return list(OrderedDict.fromkeys(tok for tok in candidates if tok is not None))

    def build_vocab(self, *args, **kwargs):
        """Build ``self.vocab`` from datasets and/or iterables of examples.

        Does nothing if a vocabulary already exists. For a ``Dataset`` argument,
        only the columns whose field is this object are counted. ``specials`` in
        ``kwargs`` adds extra special tokens; remaining kwargs go to ``vocab_cls``.
        """
        if hasattr(self, 'vocab'):
            return
        counter = Counter()
        sources = []
        for arg in args:
            if isinstance(arg, Dataset):
                sources.extend(getattr(arg, name) for name, field in arg.fields.items()
                               if field is self)
            else:
                sources.append(arg)
        for data in sources:
            for x in data:
                try:
                    counter.update(x)
                except TypeError:
                    # Nested example (e.g. list of lists): count the inner tokens.
                    counter.update(chain.from_iterable(x))
        specials = self._build_specials(kwargs.pop('specials', []))
        self.vocab = self.vocab_cls(counter, specials=specials, **kwargs)

    def __setstate__(self, state):
        # dtype was stored as its attribute name (see __getstate__).
        state['dtype'] = getattr(torch, state['dtype'])
        return super(SequenceField, self).__setstate__(state)

    def __getstate__(self):
        attrs = super(SequenceField, self).__getstate__()
        # torch.dtype is stored by name so it can be restored with getattr(torch, ...).
        attrs['dtype'] = dtype_to_attr(self.dtype)
        return attrs

    def pad_batch(self, minibatch: list):
        """Wrap each example with bos/eos and right-pad with ``self.pad``.

        :param minibatch: iterable of sliceable examples
        :return: the padded list, or ``(padded, lengths)`` when
            ``include_lengths`` is set; a length counts bos/eos but not padding.
        """
        minibatch = list(minibatch)
        max_len = max((len(x) for x in minibatch), default=0)
        if self.max_length is not None:
            # Reserve room for bos/eos so the padded length stays <= max_length.
            # Clamp at 0: a negative bound would slice tokens off the end of x.
            reserved = 2 - (self.bos, self.eos).count(None)
            max_len = max(0, min(max_len, self.max_length - reserved))
        prefix = [] if self.bos is None else [self.bos]
        suffix = [] if self.eos is None else [self.eos]
        padded, lengths = [], []
        for x in minibatch:
            content = list(x[:max_len])
            n_pad = max(0, max_len - len(content))
            padded.append(prefix + content + suffix + [self.pad] * n_pad)
            lengths.append(len(padded[-1]) - n_pad)
        if self.include_lengths:
            return (padded, lengths)
        return padded

    def numericalize(self, arr, device=None):
        """Turn a padded batch (or ``(batch, lengths)``) into tensor(s) of ``self.dtype``.

        :raises ValueError: if ``include_lengths`` is set but ``arr`` is not a
            tuple, or if ``use_vocab`` is False and ``self.dtype`` has no entry
            in ``dtypes``.
        """
        if self.include_lengths and not isinstance(arr, tuple):
            raise ValueError("Field has include_lengths set to True, but "
                             "input data is not a tuple of "
                             "(data batch, batch lengths).")
        lengths = None
        if isinstance(arr, tuple):
            arr, lengths = arr
        if self.use_vocab:
            arr = [[self.vocab.stoi[x] for x in ex] for ex in arr]
            if self.postprocessing is not None:
                arr = self.postprocessing(arr, self.vocab)
        else:
            if self.dtype not in self.dtypes:
                raise ValueError(
                    f"Specified Field dtype {self.dtype} can not be used with "
                    "use_vocab=False because we do not know how to numericalize it. "
                )
            numericalization_func = self.dtypes[self.dtype]
            arr = [[numericalization_func(x) for x in ex] for ex in arr]
            if self.postprocessing is not None:
                arr = self.postprocessing(arr, None)
        var = torch.tensor(arr, dtype=self.dtype, device=device).contiguous()
        if self.include_lengths:
            # Lengths share the data tensor's dtype.
            lengths = torch.tensor(lengths, dtype=self.dtype, device=device)
            return var, lengths
        return var

    def preprocess(self, x):
        """Apply ``preprocessing`` to a single raw example, if set."""
        if self.preprocessing is not None:
            x = self.preprocessing(x)
        return x

    def process(self, batch, device=None):
        """Pad and numericalize ``batch``; subtract 1 from values when ``pad_bias``."""
        padded = self.pad_batch(batch)
        tensor = self.numericalize(padded, device=device)
        if not self.pad_bias:
            return tensor
        if isinstance(tensor, torch.Tensor):
            return tensor - 1
        # include_lengths: only the data tensor is shifted, not the lengths.
        tensor, length = tensor
        return tensor - 1, length