第一句子网 - 唯美句子、句子迷、好句子大全
第一句子网 > pytorch中的register_parameter()和parameter()

pytorch中的register_parameter()和parameter()

时间:2023-11-11 05:31:52

相关推荐

pytorch中的register_parameter()和parameter()

前言

这两个都是一个东西,使用上有细微差别。

对了,他两的主要作用是:将一个不可训练的类型Tensor转换成可以训练的类型parameter,并将这个parameter绑定到这个module里面,相当于变成了模型的一部分,成为了模型中可以根据训练进行变化的参数。

差别

Parameter()

ParameterTensor,即Tensor拥有的属性它都有,⽐如可以根据data来访问参数数值,⽤grad来访问参数梯度。

举例:

# 随便定义一个网络net = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 1)) # list让它可以访问weight_0 = list(net[0].parameters())[0]print(weight_0.data)print(weight_0.grad)'output'tensor([[ 0.2719, -0.0898, -0.2462, 0.0655],[-0.4669, -0.2703, 0.3230, 0.2067],[-0.2708, 0.1171, -0.0995, 0.3913]])None

register_parameter(name, param)

向我们建立的网络module添加parameter

最大的区别parameter可以通过注册网络时候的name获取。

举例

举例如下

class Example(nn.Module):def __init__(self):super(Example, self).__init__()print('看看我们的模型有哪些parameter:\t', self._parameters, end='\n')self.W1_params = nn.Parameter(torch.rand(2,3))print('增加W1后看看:',self._parameters, end='\n')self.register_parameter('W2_params' , nn.Parameter(torch.rand(2,3)))print('增加W2后看看:',self._parameters, end='\n')def forward(self, x):return x

输出:

mymodel = Example()'''看看我们的模型有哪些parameter: OrderedDict()增加W1后看看: OrderedDict([('W1_params', Parameter containing:tensor([[0.0479, 0.9264, 0.1193],[0.5004, 0.7336, 0.6464]], requires_grad=True))])增加W2后看看: OrderedDict([('W1_params', Parameter containing:tensor([[0.0479, 0.9264, 0.1193],[0.5004, 0.7336, 0.6464]], requires_grad=True)), ('W2_params', Parameter containing:tensor([[0.1028, 0.2370, 0.8500],[0.6116, 0.0463, 0.4229]], requires_grad=True))])'''

打印出来看看

for k,v in mymodel.named_parameters():print(k,v)W1_params Parameter containing:tensor([[0.4610, 0.2772, 0.5786],[0.7729, 0.0163, 0.4747]], requires_grad=True)W2_params Parameter containing:tensor([[0.4058, 0.8788, 0.2411],[0.5428, 0.9389, 0.5968]], requires_grad=True)

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。