add if in register_buffer like register_parameters (#16110)
authorFrankHui <757583912@qq.com>
Thu, 17 Jan 2019 19:31:31 +0000 (11:31 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 17 Jan 2019 19:50:12 +0000 (11:50 -0800)
Summary:
without this "if", code below will throw error " Linear' object has no attribute '_buffers' "
And with this if, error would be "cannot assign buffer before Module.\_\_init\_\_() call", which I think it's more accurate, just like register_parameter.
```
import math
import torch
from torch.nn.parameter import Parameter
from torch.nn import functional as F
from torch.nn import Module
class Linear(Module):
    def __init__(self, in_features, out_features, bias=True):

        self.in_features = in_features
        self.out_features = out_features
        self.register_buffer('test', torch.Tensor(out_features, in_features))
        self.weight = Parameter(torch.Tensor(out_features, in_features))
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)

        super(Linear, self).__init__()

        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input):
        return F.linear(input, self.weight, self.bias)

    def extra_repr(self):
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None
        )

linear = Linear(3,4)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16110

Differential Revision: D13715839

Pulled By: soumith

fbshipit-source-id: c300eff0a8655aade448354cf489a592f7db722a

torch/nn/modules/module.py

index e63222f..f436a10 100644 (file)
@@ -103,7 +103,10 @@ class Module(object):
             >>> self.register_buffer('running_mean', torch.zeros(num_features))
 
         """
-        if not isinstance(name, torch._six.string_classes):
+        if '_buffers' not in self.__dict__:
+            raise AttributeError(
+                "cannot assign buffer before Module.__init__() call")
+        elif not isinstance(name, torch._six.string_classes):
             raise TypeError("buffer name should be a string. "
                             "Got {}".format(torch.typename(name)))
         elif '.' in name: