ltp.eval.metrics.accuracy 源代码

#! /usr/bin/env python
# -*- coding: utf-8 -*_
# Author: Yunlong Feng <ylfeng@ir.hit.edu.cn>

import torch
from . import Metric


[文档]class Accuracy(Metric): def __init__(self, pad_value: int = None): """ 用于词性标注 :param pad_value: 被忽略的目标值 """ super(Accuracy, self).__init__(acc=0.) self._total = 0 self._pad_value = pad_value self._total_correct = 0
[文档] def step(self, y_pred: torch.Tensor, y: torch.Tensor): y_pred = torch.argmax(y_pred, dim=-1) if y.size() != y_pred.size(): raise TypeError("y and y_pred should have the same shape") correct = torch.eq(y_pred, y) mask = torch.ones_like(y, dtype=torch.bool) if self._pad_value is None else y.ne(self._pad_value) self._total_correct += torch.sum(correct[mask]).item() self._total += mask.sum().item()
[文档] def compute(self): return {'acc': (self._total_correct / self._total) if self._total != 0 else 0}
[文档] def clear(self): self._total_correct = 0 self._total = 0