Source code for ltp.modules.graph

#! /usr/bin/env python
# -*- coding: utf-8 -*-
# Author: Yunlong Feng <ylfeng@ir.hit.edu.cn>
from typing import Optional
from torch import Tensor
from ltp.nn import MLP, Bilinear

from . import Module


# Graph module: scores head-dependent arcs and their relation labels
# via bilinear attention over projected encoder representations.
class Graph(Module):
    def __init__(self, input_size, label_num, dropout, arc_hidden_size, rel_hidden_size, **kwargs):
        super(Graph, self).__init__(input_size, label_num, dropout)
        arc_dropout = kwargs.pop('arc_dropout', dropout)
        rel_dropout = kwargs.pop('rel_dropout', dropout)
        activation = kwargs.pop('activation', {})

        # MLPs projecting the encoder output into head (h) and dependent (d) spaces
        self.mlp_arc_h = MLP(input_size, arc_hidden_size, arc_dropout, **activation)
        self.mlp_arc_d = MLP(input_size, arc_hidden_size, arc_dropout, **activation)
        self.mlp_rel_h = MLP(input_size, rel_hidden_size, rel_dropout, **activation)
        self.mlp_rel_d = MLP(input_size, rel_hidden_size, rel_dropout, **activation)

        # Bilinear attention: one score per candidate arc, and label_num scores per arc for relations
        self.arc_atten = Bilinear(arc_hidden_size, arc_hidden_size, 1, bias_x=True, bias_y=False, expand=True)
        self.rel_atten = Bilinear(rel_hidden_size, rel_hidden_size, label_num, bias_x=True, bias_y=True, expand=True)

    def forward(self, embedding: Tensor, seq_lens: Tensor = None, gold: Optional = None):
        arc_h = self.mlp_arc_h(embedding)
        arc_d = self.mlp_arc_d(embedding)
        rel_h = self.mlp_rel_h(embedding)
        rel_d = self.mlp_rel_d(embedding)

        # s_arc: arc scores (batch, seq, seq); s_rel: label scores (batch, seq, seq, label_num)
        s_arc = self.arc_atten(arc_d, arc_h).squeeze_(1)
        s_rel = self.rel_atten(rel_d, rel_h).permute(0, 2, 3, 1)

        return s_arc, s_rel, seq_lens
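A minimal usage sketch of the module above. The hyper-parameter values (hidden sizes, label count, encoder dimension) are illustrative placeholders, not the defaults shipped with LTP, and it assumes the base Module accepts only the three arguments forwarded to it here; output shapes follow from the squeeze/permute calls in forward.

import torch
from ltp.modules.graph import Graph

# Illustrative sizes only; real configs come from the LTP task definition.
model = Graph(input_size=768, label_num=56, dropout=0.1,
              arc_hidden_size=500, rel_hidden_size=100)

batch_size, seq_len = 2, 10
embedding = torch.randn(batch_size, seq_len, 768)  # encoder output for a padded batch
seq_lens = torch.tensor([10, 7])                   # true sentence lengths

s_arc, s_rel, seq_lens = model(embedding, seq_lens)
print(s_arc.shape)  # expected: (batch_size, seq_len, seq_len) arc scores
print(s_rel.shape)  # expected: (batch_size, seq_len, seq_len, label_num) relation scores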