ltp.eval.metrics.common_loss 源代码

#! /usr/bin/env python
# -*- coding: utf-8 -*-
# Author: Yunlong Feng <ylfeng@ir.hit.edu.cn>
from torch import Tensor
from torch.nn.modules.loss import _Loss as Loss
from . import Metric


class CommonLoss(Metric, alias="Loss"):
    """Loss metric: averages the value of a loss function over batches.

    :param loss_function: loss function called as ``loss_function(y_pred, y)``
        on every batch; its result must be convertible with ``.item()``
    :param item: if True, :meth:`step` returns the batch loss as a Python
        scalar instead of the Tensor produced by ``loss_function``
    :param flat: if True, reshape predictions to ``(-1, y_pred.shape[-1])``
        and targets to ``(-1,)`` before computing the loss
    """

    def __init__(self, loss_function: Loss, item: bool = False, flat: bool = False):
        """Create a loss metric.

        :param loss_function: loss function called as ``loss_function(y_pred, y)``
        :param item: make :meth:`step` return a Python scalar instead of a Tensor
        :param flat: flatten predictions/targets before computing the loss
        """
        # Initial ``loss`` value handed to the Metric base class; its exact
        # meaning there is not visible here (presumably a "worst" starting
        # value for tracking the best loss) -- TODO confirm against Metric.
        super().__init__(loss=float('inf'))
        self._total = 0  # number of batches accumulated since the last clear()
        self._flat = flat
        self._loss_sum = 0.  # running sum of per-batch loss values
        self._return_item = item
        self.loss_function = loss_function

    def step(self, y_pred: Tensor, y: Tensor):
        """Compute the loss for one batch and add it to the running average.

        :param y_pred: predictions, passed first to ``loss_function``
        :param y: targets, passed second to ``loss_function``
        :return: the batch loss -- a Python scalar when ``item=True``,
            otherwise the Tensor returned by ``loss_function``
        """
        if self._flat:
            # Collapse all leading dimensions: predictions become
            # (N, last_dim) and targets become (N,).
            last_dim = y_pred.shape[-1]
            y_pred = y_pred.contiguous().view((-1, last_dim))
            y = y.contiguous().view(-1)
        loss = self.loss_function(y_pred, y)
        value = loss.item()
        self._loss_sum += value
        self._total += 1
        # Honor the documented ``item`` option; previously it was stored but
        # never used and step() returned None.
        return value if self._return_item else loss

    def compute(self):
        """Return the mean loss over all batches seen since the last clear().

        :return: ``{'loss': mean_loss}``
        :raises ZeroDivisionError: if :meth:`step` has not been called yet
        """
        if self._total == 0:
            raise ZeroDivisionError("Loss average is not computable.")
        # The guard above makes the division safe; the old
        # ``... if self._total != 0 else 0`` fallback was unreachable.
        return {'loss': self._loss_sum / self._total}

    def clear(self):
        """Reset the accumulated loss sum and batch count."""
        self._loss_sum = 0.
        self._total = 0