#! /usr/bin/env python
# -*- coding: utf-8 -*-
# Author: Yunlong Feng <ylfeng@ir.hit.edu.cn>
from torch import Tensor
from ltp.nn import MLP
from ltp import nn
from . import Module
class BiaffineCRF(Module):
    """Biaffine scorer that also carries a CRF layer.

    The same input is passed through two separate MLPs, the two projections
    are combined by a biaffine layer into ``label_num`` scores, and the CRF
    module is handed back to the caller together with those scores.
    """

    def __init__(self, input_size, label_num, dropout: float = 0.2, hidden_size=None, **kwargs):
        """Build the two MLP projections, the biaffine layer and the CRF.

        Args:
            input_size: Feature size of the incoming representation; used as
                the input size of both MLPs.
            label_num: Number of labels; used as the output size of the
                biaffine layer and as the size of the CRF.
            dropout: Dropout value forwarded to the base class and both MLPs.
            hidden_size: Output size of both MLPs and input size (both sides)
                of the biaffine layer.
                NOTE(review): the default ``None`` is passed through
                unchanged -- confirm whether ``MLP`` / ``nn.Biaffine`` accept
                it or whether callers must always supply a value.
            **kwargs: Only ``activation`` is read. It must be a mapping and is
                expanded as keyword arguments to both ``MLP`` constructors;
                it defaults to ``{'LeakyReLU': {}}``. Any other keyword is
                ignored.
        """
        super().__init__(input_size, label_num, dropout)

        activation = kwargs.pop('activation', {'LeakyReLU': {}})

        # Two independent projections of the same input. Presumably "h" is
        # the head role and "d" the dependent role -- inferred from the names
        # only, TODO confirm.
        self.mlp_rel_h = MLP(input_size, hidden_size, dropout=dropout, **activation)
        self.mlp_rel_d = MLP(input_size, hidden_size, dropout=dropout, **activation)

        self.biaffine = nn.Biaffine(hidden_size, hidden_size, label_num)
        self.crf = nn.CRF(label_num)

    def forward(self, inputs: Tensor, length: Tensor, gold=None):
        """Score ``inputs`` with the biaffine layer.

        Args:
            inputs: Tensor fed to both MLP projections.
            length: Returned unchanged alongside the scores (presumably the
                sequence lengths -- TODO confirm against the caller).
            gold: Accepted for interface compatibility; not used here.

        Returns:
            A tuple ``(logits, length, crf)`` where ``logits`` is the biaffine
            output, ``length`` is the argument passed in, and ``crf`` is this
            module's CRF layer. The CRF is not applied to the logits here.
        """
        rel_h = self.mlp_rel_h(inputs)
        rel_d = self.mlp_rel_d(inputs)

        logits = self.biaffine(rel_h, rel_d)
        return logits, length, self.crf