Arce 发表于 2021-7-11 17:25:22

Pytorch修改指定模块权重的方法,即 torch.Tensor.detach()和Tensor.requires_grad方法的用法

  
0、前言
  在学习pytorch的计算图和自动求导机制时,我们要想在心中建立一个“计算过程的图像”,需要深入了解其中的每个细节,这次主要说一下tensor的requires_grad参数。
无论如何定义计算过程、如何定义计算图,要谨记我们的核心目的是为了计算某些tensor的梯度。在pytorch的计算图中,其实只有两种元素:数据(tensor)和运算,运算就是加减乘除、开方、幂指对、三角函数等可求导运算,而tensor可细分为两类:叶子节点(leaf node)和非叶子节点。使用backward()函数反向传播计算tensor的梯度时,并不计算所有tensor的梯度,而是只计算满足这几个条件的tensor的梯度:1.类型为叶子节点、2.requires_grad=True、3.依赖该tensor的所有tensor的requires_grad=True
叶子节点和tensor的requires_grad参数

一、detach()那么这个函数有什么作用?
  假如A网络输出了一个Tensor类型的变量a, a要作为输入传入到B网络中,如果我想通过损失函数反向传播修改B网络的参数,但是不想修改A网络的参数,这个时候就可以使用detcah()方法
a = A(input)
a = a.detach()

b = B(a)
loss = criterion(b, target)
loss.backward()
  以下代码就说明了反向传播到y就结束了,没有到达x,所以x的grad属性为None
import torch as t
x = t.ones(1, requires_grad=True)
x.requires_grad   #True
y = t.ones(1, requires_grad=True)
y.requires_grad   #True

x = x.detach()   #分离之后
x.requires_grad   #False

y = x+y      #tensor([2.])
y.requires_grad   #我还是True
y.retain_grad()   #y不是叶子张量,要加上这一行

z = t.pow(y, 2)
z.backward()  #反向传播

y.grad      #tensor([4.])
x.grad      #None

二、Tensor.requires_grad属性
  既然谈到了修改模型的权重问题,那么还有一种情况是:
假如A网络输出了一个Tensor类型的变量a, a要作为输入传入到B网络中,如果我想通过损失函数反向传播修改A网络的参数,但是不想修改B网络的参数,这个时候又应该怎么办了?
这时可以使用Tensor.requires_grad属性,只需要将requires_grad修改为False即可
for param in B.parameters():
param.requires_grad = False

a = A(input)
b = B(a)
loss = criterion(b, target)
loss.backward()


  
文档来源:51CTO技术博客https://blog.51cto.com/u_11495341/3036156
页: [1]
查看完整版本: Pytorch修改指定模块权重的方法,即 torch.Tensor.detach()和Tensor.requires_grad方法的用法