ltp.utils.initial 源代码

#! /usr/bin/env python
# -*- coding: utf-8 -*-
# Author: Yunlong Feng <ylfeng@ir.hit.edu.cn>
import torch.nn.init as init
from torch import nn


def initial_parameter(net, initial_method=None):
    """Initialize the weights of a PyTorch model in place.

    Every module of ``net`` is visited via ``net.apply``. Each module only
    initializes its *own* parameters, so a container never overrides the
    per-layer rules applied to its children.

    Args:
        net: PyTorch Module to initialize.
        initial_method(str): one of
            - xavier_uniform
            - xavier_normal (default)
            - kaiming_normal, or msra
            - kaiming_uniform
            - orthogonal
            - sparse (sparsity fixed at 0.1; torch only supports 2-D tensors)
            - normal
            - uniform
            Any other value, including None, falls back to xavier_normal.

    Per-module rules:
        - Conv1d/Conv2d/Conv3d: weight uses the chosen method; bias (when
          present) uses ``init.normal_``.
        - LSTM: parameters with more than one dimension use the chosen method,
          the remaining (bias) parameters use ``init.normal_``.
        - Other modules with a ``weight`` tensor: the weight uses the chosen
          method when it has at least 2 dimensions. 1-D weights (e.g.
          LayerNorm/BatchNorm scales) are left untouched because fan-based
          initializers cannot handle them. Other parameters of such modules
          are left untouched.
        - Any remaining module: its own trainable parameters follow the LSTM
          rule.
    """
    methods = {
        'xavier_uniform': init.xavier_uniform_,
        'xavier_normal': init.xavier_normal_,
        'kaiming_normal': init.kaiming_normal_,
        'msra': init.kaiming_normal_,
        'kaiming_uniform': init.kaiming_uniform_,
        'orthogonal': init.orthogonal_,
        # init.sparse_ has a required `sparsity` argument; calling it with only
        # the tensor raises TypeError, so bind a value here.
        'sparse': lambda tensor: init.sparse_(tensor, sparsity=0.1),
        'normal': init.normal_,
        'uniform': init.uniform_,
    }
    init_method = methods.get(initial_method, init.xavier_normal_)

    def _init_by_rank(module, skip_frozen):
        # Matrix-like parameters get the chosen method, vectors (biases) normal_.
        # recurse=False: children are visited by net.apply on their own.
        for param in module.parameters(recurse=False):
            if skip_frozen and not param.requires_grad:
                continue
            if param.dim() > 1:
                init_method(param.data)  # weight
            else:
                init.normal_(param.data)  # bias

    def weights_init(m):
        if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):  # for all the cnn
            init_method(m.weight.data)
            if m.bias is not None:  # convs built with bias=False have no bias
                init.normal_(m.bias.data)
        elif isinstance(m, nn.LSTM):
            _init_by_rank(m, skip_frozen=False)
        else:
            weight = getattr(m, 'weight', None)
            # hasattr(..., "requires_grad") distinguishes a tensor from None.
            if hasattr(weight, "requires_grad"):
                if weight.dim() > 1:
                    init_method(weight.data)
            else:
                _init_by_rank(m, skip_frozen=True)

    net.apply(weights_init)