The Difference Between Tensor.data and Tensor.detach()


References:

  1. detach() and data in PyTorch
  2. Differences between .data and .detach #6990
import torch

Tensor.detach()

a = torch.tensor([1,2,3.], requires_grad = True)
out = a.sigmoid()
c = out.detach()
c.zero_()
tensor([0., 0., 0.])
out   # modified by c.zero_() !!
tensor([0., 0., 0.], grad_fn=<SigmoidBackward>)
out.sum().backward()  # Requires the original value of out, but that was overwritten by c.zero_()
---------------------------------------------------------------------------

RuntimeError                              Traceback (most recent call last)

<ipython-input-25-ada644350cc4> in <module>
----> 1 out.sum().backward()  # Requires the original value of out, but that was overwritten by c.zero_()


F:\Anaconda3\envs\pyg\lib\site-packages\torch\tensor.py in backward(self, gradient, retain_graph, create_graph)
    219                 retain_graph=retain_graph,
    220                 create_graph=create_graph)
--> 221         torch.autograd.backward(self, gradient, retain_graph, create_graph)
    222 
    223     def register_hook(self, hook):


F:\Anaconda3\envs\pyg\lib\site-packages\torch\autograd\__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
    130     Variable._execution_engine.run_backward(
    131         tensors, grad_tensors_, retain_graph, create_graph,
--> 132         allow_unreachable=True)  # allow_unreachable flag
    133 
    134 


RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [3]], which is output 0 of SigmoidBackward, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

And, for comparison:

Tensor.data

a = torch.tensor([1,2,3.], requires_grad = True)
out = a.sigmoid()
c = out.data
c.zero_()
tensor([0., 0., 0.])
out  # out  was modified by c.zero_()
tensor([0., 0., 0.], grad_fn=<SigmoidBackward>)
out.sum().backward()
a.grad  # The result is very, very wrong because `out` changed!
tensor([0., 0., 0.])
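As a quick sanity check of my own (not part of the quoted example): the sigmoid's local derivative is sigmoid(x) * (1 - sigmoid(x)), i.e. it is computed from the output. Since the output was zeroed, backward computes 0 * (1 - 0) = 0 for every element, which is why a.grad comes out as all zeros. Without the in-place write, the gradient would be:

import torch

a = torch.tensor([1, 2, 3.], requires_grad=True)
out = a.sigmoid()
out.sum().backward()   # no tampering with out this time
print(a.grad)          # sigmoid(a) * (1 - sigmoid(a)); roughly tensor([0.1966, 0.1050, 0.0452])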

The references above explain this clearly, but there is in fact one more subtle difference; see the experiments below.

First, for context: my setup is PyTorch 1.1.0 and Python 3.

So what exactly is the difference between detach() and data? First of all, both return a plain tensor with no gradient tracking, as shown below:

t = torch.tensor([0., 1.], requires_grad=True)
t2 = t.detach()
t3 = t.data
print(t2.requires_grad, t3.requires_grad)  # output: False, False
False False

In fact, both of these new tensors, t2 and t3, share the same underlying data memory with the original tensor t.
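A quick way to verify this sharing (my own check, not from the original references) is to compare storage pointers and watch an in-place write propagate:

t = torch.tensor([0., 1.], requires_grad=True)
t2 = t.detach()
t3 = t.data
print(t.data_ptr() == t2.data_ptr() == t3.data_ptr())  # True: all three share one storage
t2[0] = 42.    # in-place write through the detached view
print(t)       # tensor([42.,  1.], requires_grad=True)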

Second, an in-place operation after detach() does not always raise an error, and in some cases the backward gradient computation is even completely correct! Why is that? It comes down to a simple fact: when computing gradients, there are two kinds of local derivatives. The first kind,

the sigmoid: y = sigmoid(x), with dy/dx = y(1 - y),

and the second kind,

the quadratic: y = x², with dy/dx = 2x.

You can surely see the difference: the variable that backprop needs is not the same. The first kind needs the output y, while the second kind needs the input x. Let's run an experiment:

x = torch.tensor(0., requires_grad=True)
y = x.sigmoid()
y.detach().zero_()
print(y)
y.backward()
tensor(0., grad_fn=<SigmoidBackward>)



---------------------------------------------------------------------------

RuntimeError                              Traceback (most recent call last)

<ipython-input-34-bac852285750> in <module>
      3 y.detach().zero_()
      4 print(y)
----> 5 y.backward()


F:\Anaconda3\envs\pyg\lib\site-packages\torch\tensor.py in backward(self, gradient, retain_graph, create_graph)
    219                 retain_graph=retain_graph,
    220                 create_graph=create_graph)
--> 221         torch.autograd.backward(self, gradient, retain_graph, create_graph)
    222 
    223     def register_hook(self, hook):


F:\Anaconda3\envs\pyg\lib\site-packages\torch\autograd\__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
    130     Variable._execution_engine.run_backward(
    131         tensors, grad_tensors_, retain_graph, create_graph,
--> 132         allow_unreachable=True)  # allow_unreachable flag
    133 
    134 


RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor []], which is output 0 of SigmoidBackward, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

The reason: we modified y's data in place, and the backward computation depends on that data, hence the error. What if we switch to a different operation?

x = torch.tensor(1., requires_grad=True)
y = x ** 2
y.detach().zero_()
print(y)
y.backward()
print(x.grad)
tensor(0., grad_fn=<PowBackward0>)
tensor(2.)

This time the output above is produced successfully: PowBackward0 uses the input x (dy/dx = 2x) rather than the output y, so zeroing y in place does not touch anything the backward pass needs, and x.grad = 2 is still correct.

In short, detach() is the safer choice: in some cases it can raise an error, though not in all of them. In practice, directly modifying a node of the graph in place is rarely needed (if you have a real use case, please share it in the comments); detached tensors are mostly used to compute auxiliary variables for debugging or logging. Given my limited experience there are bound to be omissions, and I welcome discussion, comments, and corrections.
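As a concrete illustration of that "auxiliary variable" use case (a minimal sketch of my own, with made-up logits and labels, not code from this post): detach() is typically used to pull values out of the graph for logging or metrics, so that they neither receive gradients nor keep the graph alive:

import torch
import torch.nn.functional as F

logits = torch.randn(4, 3, requires_grad=True)   # hypothetical model output
labels = torch.tensor([0, 2, 1, 0])

loss = F.cross_entropy(logits, labels)

# auxiliary, gradient-free quantities for logging/debugging:
# detach() cuts them out of the graph, so they cannot affect backward()
acc = (logits.detach().argmax(dim=1) == labels).float().mean()
print(float(loss.detach()), float(acc))

loss.backward()   # backward on the real loss still works normally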

One more example to drive the point home

x = torch.tensor([1,1,1],dtype=torch.float,requires_grad=True)
y = torch.tensor([2,4,6],dtype=torch.float,requires_grad=True).view(-1,1)
y
tensor([[2.],
        [4.],
        [6.]], grad_fn=<ViewBackward>)
z = torch.matmul(x,y)
tmp = y.detach()
tmp.mul_(2)
tensor([[ 4.],
        [ 8.],
        [12.]])
y.requires_grad
True
z.backward()
---------------------------------------------------------------------------

RuntimeError                              Traceback (most recent call last)

<ipython-input-43-40c0c9b0bbab> in <module>
----> 1 z.backward()


F:\Anaconda3\envs\pyg\lib\site-packages\torch\tensor.py in backward(self, gradient, retain_graph, create_graph)
    219                 retain_graph=retain_graph,
    220                 create_graph=create_graph)
--> 221         torch.autograd.backward(self, gradient, retain_graph, create_graph)
    222 
    223     def register_hook(self, hook):


F:\Anaconda3\envs\pyg\lib\site-packages\torch\autograd\__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
    130     Variable._execution_engine.run_backward(
    131         tensors, grad_tensors_, retain_graph, create_graph,
--> 132         allow_unreachable=True)  # allow_unreachable flag
    133 
    134 


RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [3, 1]], which is output 0 of ViewBackward, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
x.grad  # prints nothing: backward() failed, so x.grad is still None

import torch
x = torch.tensor([1,1,1],dtype=torch.float,requires_grad=True)
y = torch.tensor([2,4,6],dtype=torch.float,requires_grad=True)
y
tensor([2., 4., 6.], requires_grad=True)
z = torch.add(x,y)
tmp = y.detach()
tmp.mul_(2)
tensor([ 4.,  8., 12.])
y
tensor([ 4.,  8., 12.], requires_grad=True)
z.sum().backward()
x.grad
tensor([1., 1., 1.])
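Here backward() succeeds even though y was modified in place through detach(), because AddBackward does not save its inputs at all: the local derivative of x + y with respect to x is simply 1, so there is nothing for the version check to invalidate, and x.grad comes out correct. For contrast (my own follow-up, not part of the original post), repeating the matmul experiment but modifying y through .data instead of .detach() means autograd never notices the in-place change: backward() runs without complaint and silently uses the doubled values, so x.grad would likely be [4., 8., 12.] rather than the correct [2., 4., 6.].

x = torch.tensor([1, 1, 1], dtype=torch.float, requires_grad=True)
y = torch.tensor([2, 4, 6], dtype=torch.float, requires_grad=True).view(-1, 1)
z = torch.matmul(x, y)
y.data.mul_(2)    # unreported in-place change, unlike y.detach().mul_(2)
z.backward()      # no RuntimeError this time
print(x.grad)     # likely tensor([ 4.,  8., 12.]) -- silently wrong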
