#! /usr/bin/env python
# -*- coding: utf-8 -*-
# Author: Yunlong Feng <ylfeng@ir.hit.edu.cn>
import torch.nn.init as init
from torch import nn
def initial_parameter(net, initial_method=None):
    """Initialize the internal weights of a PyTorch model in place.

    The chosen initializer is applied to every tensor with more than one
    dimension; one-dimensional tensors (biases, normalization scales) are
    drawn with ``init.normal_`` because the fan-based and orthogonal
    initializers reject tensors with fewer than two dimensions.

    Args:
        net: PyTorch Module whose sub-modules are visited via ``net.apply``.
        initial_method(str): one of the following values
            - xavier_uniform
            - xavier_normal (default)
            - kaiming_normal, or msra
            - kaiming_uniform
            - orthogonal
            - sparse
            - normal
            - uniform
            Any other value (including ``None``) falls back to xavier_normal.

    NOTE(review): ``init.sparse_`` takes a required ``sparsity`` argument
    that is not supplied here, so the ``sparse`` option appears to fail at
    call time -- confirm and decide on a sparsity value before relying on it.
    """
    init_methods = {
        'xavier_uniform': init.xavier_uniform_,
        'xavier_normal': init.xavier_normal_,
        'kaiming_normal': init.kaiming_normal_,
        'msra': init.kaiming_normal_,
        'kaiming_uniform': init.kaiming_uniform_,
        'orthogonal': init.orthogonal_,
        'sparse': init.sparse_,
        'normal': init.normal_,
        'uniform': init.uniform_,
    }
    # Unknown names and None both fall back to Xavier normal.
    init_method = init_methods.get(initial_method, init.xavier_normal_)

    def init_tensor(tensor):
        # Multi-dimensional tensors are treated as weights, 1-D ones as biases.
        if tensor.dim() > 1:
            init_method(tensor)
        else:
            init.normal_(tensor)

    def weights_init(m):
        if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):  # for all the cnn
            init_method(m.weight.data)
            # Convolutions created with bias=False have ``bias`` set to None.
            if m.bias is not None:
                init.normal_(m.bias.data)
        elif isinstance(m, nn.LSTM):
            for w in m.parameters():
                init_tensor(w.data)
        elif hasattr(m, 'weight') and hasattr(m.weight, "requires_grad"):
            # Covers Linear/Embedding (2-D weight) as well as
            # BatchNorm/LayerNorm, whose weight is 1-D.
            init_tensor(m.weight.data)
        else:
            # Containers and custom modules: initialize every trainable
            # parameter reachable from this module.
            for w in m.parameters():
                if w.requires_grad:
                    init_tensor(w.data)

    net.apply(weights_init)