ltp.predict.simple 源代码

#! /usr/bin/env python
# -*- coding: utf-8 -*_
# Author: Yunlong Feng <ylfeng@ir.hit.edu.cn>
import torch
from . import Predictor


[文档]class Simple(Predictor, alias='simple'): """ 简单预测器 :param fields: Tuple(str, Field) 数据的 Field 对象,配置文件不需要填写 :param text_field: str 提供文本以及长度的 Field,默认是第一个 Filed :param target_filed: str 提供解码词典的 Field, 默认是第一个 Target Field (提供他的词表) """ def __init__(self, fields, text_field: str = 'text', target_filed: str = None): """ 简单预测器 """ super(Simple, self).__init__(fields) self.text_field = self.fields[text_field] if target_filed is None: label_field = self.target_fields[0] else: label_field = self.fields.get(target_filed, default=self.target_fields[0]) if label_field.use_vocab: self.itos = label_field.vocab.itos[1:] def convert_idx_to_name(self, y, array_len, id2label): if id2label: return [[id2label[idx] for idx in row[:row_len]] for row, row_len in zip(y, array_len)] else: return [[idx for idx in row[:row_len]] for row, row_len in zip(y, array_len)]
[文档] def predict(self, inputs, pred): text_field = getattr(inputs, self.text_field) target_len = text_field[1] if isinstance(text_field, tuple) else text_field target_len = target_len.cpu().detach().numpy().tolist() pred = torch.argmax(pred, dim=-1).cpu().detach().numpy() pred = self.convert_idx_to_name(pred, target_len, self.itos) return pred