Source code for ltp.data.fields.text

#! /usr/bin/env python
# -*- coding: utf-8 -*-
# Author: Yunlong Feng <ylfeng@ir.hit.edu.cn>

from typing import Union, Dict
from itertools import chain
import numpy as np
import torch
import torch.nn.utils.rnn as rnn
from transformers import AutoTokenizer, PreTrainedTokenizer

from ltp.core.exceptions import DataUnsupported
from . import Field


class TextField(Field, alias='text'):
    """Text field

    Args:
        tokenizer: Tokenizer
        name: Field name
        return_length: whether to also return the text length
        word_info: whether to return the word index
        is_target: whether this is a target field
    """
    tokenizer_cls = AutoTokenizer

    def __init__(self, tokenizer: Union[str, PreTrainedTokenizer, Dict[str, str]], name='text',
                 return_tokens=False, return_word_idn=False, return_length=True, word_info=True,
                 is_target=False):
        super(TextField, self).__init__(name, is_target)
        self.word_info = word_info
        self.return_length = return_length
        self.return_tokens = return_tokens
        self.return_word_idn = return_word_idn

        if isinstance(tokenizer, str):
            self.tokenizer = self.tokenizer_cls.from_pretrained(tokenizer, use_fast=True)
        elif isinstance(tokenizer, PreTrainedTokenizer):
            self.tokenizer = tokenizer
            assert self.tokenizer.is_fast
        elif isinstance(tokenizer, Dict):
            tokenizer['pretrained_model_name_or_path'] = tokenizer.pop('path')
            tokenizer['use_fast'] = True
            self.tokenizer = self.tokenizer_cls.from_pretrained(**tokenizer)

    def preprocess(self, x):
        # Encode each word of the sentence separately so that word/subword
        # alignment can be recovered afterwards.
        sentence = self.tokenizer.batch_encode_plus(
            x, add_special_tokens=False, return_attention_masks=False, return_token_type_ids=False
        )
        subword_len = [len(subword.offsets) for subword in sentence.encodings]
        word_lengths = np.cumsum([0] + subword_len, dtype=np.int64)
        word_start, text_length = word_lengths[:-1], word_lengths[-1]
        # 0 marks the first subword of each word, 1 marks continuation subwords.
        word_start_idn = list(chain.from_iterable([0] + [1] * (length - 1) for length in subword_len))

        if text_length > 510:
            # 510 = BERT's 512 max length minus the [CLS] and [SEP] tokens.
            raise DataUnsupported("Text is too long!!")

        # mixed_sentence = ' '.join(x)
        # resplit = self.tokenizer.encode(
        #     mixed_sentence, add_special_tokens=False, return_attention_masks=False, return_token_type_ids=False
        # )
        # if len(resplit) != len(word_start_idn):
        #     print("X: ", x)
        #     print("Mixed: ", mixed_sentence)

        return ' '.join(x), torch.as_tensor(text_length), \
               torch.as_tensor(word_start_idn), torch.as_tensor(word_start), torch.as_tensor(len(x))

    def process(self, batch, device=None):
        sentence, text_length, word_start_idn, word_index, word_length = zip(*batch)
        tokenized = self.tokenizer.batch_encode_plus(list(sentence), return_tensors='pt')
        res = {
            'input_ids': tokenized['input_ids'].to(device),
            'token_type_ids': tokenized['token_type_ids'].to(device),
            'attention_mask': tokenized['attention_mask'].to(device),
        }
        if self.return_length:
            res['text_length'] = torch.stack(text_length).to(device)
        if self.return_word_idn:
            res['word_idn'] = rnn.pad_sequence(word_start_idn, batch_first=True).to(device)
        if self.word_info:
            res['word_index'] = rnn.pad_sequence(word_index, batch_first=True).to(device)
            if self.return_length:
                res['word_length'] = torch.stack(word_length).to(device)
        return res
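A minimal usage sketch of TextField (not part of the library source): preprocess takes one sentence as a list of already-segmented words and returns the joined text plus subword bookkeeping tensors, and process collates several preprocessed examples into padded batch tensors. The checkpoint name 'hfl/chinese-bert-wwm-ext', the example words, and the import path are assumptions for illustration only.

# Usage sketch only; checkpoint name and import path are assumptions.
from ltp.data.fields.text import TextField

field = TextField(tokenizer='hfl/chinese-bert-wwm-ext')

# One sentence as a list of already-segmented words.
example = field.preprocess(['我', '爱', '北京'])

# Collate two copies into a padded batch on CPU.
batch = field.process([example, example], device='cpu')
print(batch['input_ids'].shape)   # (batch_size, subword length + special tokens)
print(batch['word_index'])        # start position of each word in the subword sequence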
class MixedTextField(TextField):
    def process(self, batch, device=None):
        dataset, sentence, text_length, word_start_idn, word_index, word_length = zip(*batch)
        cls = torch.as_tensor(self.tokenizer.convert_tokens_to_ids(dataset), device=device).unsqueeze_(1)
        tokenized = self.tokenizer.batch_encode_plus(list(sentence), return_tensors='pt')
        input_ids = tokenized['input_ids'].to(device=device)[:, 1:]
        res = {
            'input_ids': torch.cat([cls, input_ids], dim=-1),
            'token_type_ids': tokenized['token_type_ids'].to(device),
            'attention_mask': tokenized['attention_mask'].to(device),
        }
        if self.return_length:
            res['text_length'] = torch.stack(text_length).to(device)
        if self.return_word_idn:
            res['word_idn'] = rnn.pad_sequence(word_start_idn, batch_first=True).to(device)
        if self.word_info:
            res['word_index'] = rnn.pad_sequence(word_index, batch_first=True).to(device)
            if self.return_length:
                res['word_length'] = torch.stack(word_length).to(device)
        return res
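MixedTextField differs from TextField.process only in replacing the leading [CLS] id of input_ids with a per-example dataset token id, so a single encoder can tell apart the corpus each example came from. Below is a standalone sketch of that replacement; the checkpoint name and the '[unused1]' dataset token are assumptions, not taken from this module.

import torch
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained('hfl/chinese-bert-wwm-ext', use_fast=True)  # assumed checkpoint
enc = tok.batch_encode_plus(['我 爱 北京'], return_tensors='pt')

dataset = ['[unused1]']  # one dataset-marker token per example (assumed token name)
cls = torch.as_tensor(tok.convert_tokens_to_ids(dataset)).unsqueeze_(1)

# Drop the original [CLS] column and prepend the dataset token id instead.
input_ids = torch.cat([cls, enc['input_ids'][:, 1:]], dim=-1)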