ltp.eval.metrics.tree 源代码

#! /usr/bin/env python
# -*- coding: utf-8 -*_
# Author: Yunlong Feng <ylfeng@ir.hit.edu.cn>
from typing import Tuple

import torch
from . import Metric
from ltp.utils import length_to_mask


[文档]class TreeMetrics(Metric, alias='tree'): def __init__(self, eisner=False): """ Tree Metric(LAS, UAS) :param pad_value: 被忽略的目标值 """ super(TreeMetrics, self).__init__(LAS=0., UAS=0.) self._eisner = eisner self._head_true = 0 self._label_true = 0 self._union_true = 0 self._all = 0 @property def UAS(self): return (self._head_true / self._all) if self._all != 0 else 0 @property def LAS(self): return (self._union_true / self._all) if self._all != 0 else 0
[文档] def step(self, y_pred: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], y: Tuple[torch.Tensor, torch.Tensor]): arc_pred, label_pred, seq_len = y_pred mask = length_to_mask(seq_len + 1) mask[:, 0] = False if self._eisner: from ltp.utils import eisner arc_pred = eisner(arc_pred, mask) else: arc_pred = torch.argmax(arc_pred, dim=-1) label_pred = torch.argmax(label_pred, dim=-1) arc_real, label_real = y label_pred = label_pred.gather(-1, arc_pred.unsqueeze(-1)).squeeze(-1) mask = mask.narrow(-1, 1, mask.size(1) - 1) arc_pred = arc_pred.narrow(-1, 1, arc_pred.size(1) - 1) label_pred = label_pred.narrow(-1, 1, label_pred.size(1) - 1) head_true = (arc_pred == arc_real)[mask] label_true = (label_pred == label_real)[mask] self._head_true += torch.sum(head_true).item() self._label_true += torch.sum(label_true).item() self._union_true += torch.sum(label_true[head_true]).item() self._all += torch.sum(mask).item()
[文档] def clear(self): self._head_true = 0 self._label_true = 0 self._union_true = 0 self._all = 0
[文档] def compute(self): return {'LAS': self.LAS, 'UAS': self.UAS}