ltp.predict.segment 源代码

#! /usr/bin/env python
# -*- coding: utf-8 -*-
# Author: Yunlong Feng <ylfeng@ir.hit.edu.cn>
import torch
from . import Predictor
from ltp.utils.seqeval import get_entities
from transformers import PreTrainedTokenizer


class Segment(Predictor, alias='segment'):
    """Predictor that groups tokens into words from per-token B-W / I-W tags.

    The tokenizer attached to the configured text field converts token ids
    back to token strings; the two-entry label table ``['B-W', 'I-W']``
    maps predicted class indices to tag names.
    """

    def __init__(self, fields, text_field='text'):
        """Store the text field name, its tokenizer and the label table.

        Args:
            fields: Passed unchanged to ``Predictor.__init__``; afterwards
                ``self.fields[text_field]`` must expose a ``tokenizer``
                attribute.
            text_field: Key of the text field inside ``fields``; also the
                attribute name read from the batch in :meth:`predict`.
        """
        super().__init__(fields)
        self.text_field = text_field
        tokenizer: PreTrainedTokenizer = self.fields[text_field].tokenizer
        self.id2text = tokenizer
        self.id2label = ['B-W', 'I-W']

    def convert_idx_to_name(self, y, array_len, id2label):
        """Map each row of label indices to label names, cut to its length.

        Args:
            y: Iterable of index rows (one row per sequence).
            array_len: Iterable of lengths, paired with ``y`` row by row.
            id2label: Lookup table indexed by the values found in ``y``.

        Returns:
            A list with one list of label names per row; row ``i`` holds
            only the first ``array_len[i]`` entries.
        """
        names = []
        for row, row_len in zip(y, array_len):
            names.append([id2label[idx] for idx in row[:row_len]])
        return names

    def predict(self, inputs, pred):
        """Decode tag scores into a list of words for every sequence.

        Args:
            inputs: Batch object; the attribute named by ``self.text_field``
                is read from it and indexed with ``'text_length'`` and
                ``'input_ids'`` (both tensors).
            pred: Score tensor; the arg-max over its last dimension is taken
                as the label index of each position.

        Returns:
            A list with one entry per sequence, each a list of strings
            obtained by concatenating the tokens of one span reported by
            ``get_entities``.
        """
        batch = getattr(inputs, self.text_field)
        lengths = batch['text_length'].cpu().detach().numpy().tolist()
        label_ids = torch.argmax(pred, dim=-1).cpu().detach().numpy()
        token_ids = batch['input_ids'].cpu().detach().numpy()
        tag_rows = self.convert_idx_to_name(label_ids, lengths, id2label=self.id2label)

        result = []
        for ids, tags, length in zip(token_ids, tag_rows, lengths):
            # The first id of every row is skipped before taking `length`
            # ids (presumably a leading special token -- confirm against
            # the tokenizer in use).
            tokens = [
                self.id2text.convert_ids_to_tokens(int(token_id))
                for token_id in ids[1:1 + length]
            ]
            words = []
            for entity in get_entities(tags):
                # entity[1] / entity[2] are used as inclusive start / end
                # positions into `tokens`.
                start, end = entity[1], entity[2]
                words.append("".join(tokens[start:end + 1]))
            result.append(words)
        return result