Export group norm as ATen and add test (#15569)
author Lu Fang <lufang@fb.com>
Thu, 27 Dec 2018 22:42:01 +0000 (14:42 -0800)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 27 Dec 2018 22:44:29 +0000 (14:44 -0800)
Summary:
As a short-term solution, export group norm as an ATen op to unblock users.
Long term, we will add GroupNorm to ONNX.
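
For illustration, a minimal sketch of the export path this enables; the file
name and shapes here are illustrative, not part of this PR:

    import torch
    import torch.nn as nn

    model = nn.GroupNorm(3, 6)    # 3 groups over 6 channels
    x = torch.randn(2, 6, 224)    # (batch, channels, width)
    # With the new symbolic, the exporter emits a single ATen node carrying
    # operator_s="group_norm" instead of failing on an unsupported op.
    torch.onnx.export(model, x, "group_norm.onnx")

The exported model only runs on backends that understand the ATen fallback op
(such as Caffe2), which is why native ONNX support remains the long-term goal.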

Add an end-to-end test for this.
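
Roughly, the test exports the model and compares PyTorch output against Caffe2
running the exported graph. A hand-rolled sketch of the same check (assuming
the Caffe2 ONNX backend is installed; model and x continue from the snippet
above):

    import onnx
    import numpy as np
    import caffe2.python.onnx.backend as c2

    onnx_model = onnx.load("group_norm.onnx")
    caffe2_out = c2.prepare(onnx_model).run(x.numpy())[0]
    np.testing.assert_almost_equal(
        model(x).detach().numpy(), caffe2_out, decimal=3)
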
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15569

Differential Revision: D13554293

Pulled By: houseroad

fbshipit-source-id: b4974c9ea2a1b81338ca1e5c6747efe2715d7932

test/onnx/test_pytorch_onnx_caffe2.py
torch/onnx/symbolic.py

test/onnx/test_pytorch_onnx_caffe2.py
index 0e32f94..cc1950e 100644
@@ -150,7 +150,7 @@ class TestCaffe2Backend(unittest.TestCase):
             torch_out = (torch_out,)
 
         caffe2_out = run_embed_params(onnxir, model, input, state_dict, use_gpu)
-        for i, (x, y) in enumerate(zip(torch_out, caffe2_out)):
+        for x, y in zip(torch_out, caffe2_out):
             np.testing.assert_almost_equal(x.data.cpu().numpy(), y, decimal=3)
 
     def run_actual_test(self, model, train, batch_size, state_dict=None,
@@ -996,6 +996,13 @@ class TestCaffe2Backend(unittest.TestCase):
         x = torch.randn(2, 3, 4)
         self.run_model_test(ReduceSumNegativeIndices(), train=False, input=(x,), batch_size=BATCH_SIZE, use_gpu=False)
 
+    def test_group_norm(self):
+        # End-to-end check: export nn.GroupNorm through the ATen fallback and
+        # compare PyTorch and Caffe2 outputs.
+        c = torch.randn(BATCH_SIZE, 6, 224)
+        model = nn.GroupNorm(3, 6)
+        self.run_model_test(model, train=True, input=c, batch_size=BATCH_SIZE)
+
 
 # a bit of metaprogramming to set up all the rnn tests
 
torch/onnx/symbolic.py
index 0a259be..7a8b373 100644
@@ -1155,6 +1155,14 @@ def pixel_shuffle(g, self, upscale_factor):
                  upscale_factor])
 
 
+# ONNX has no native GroupNorm op yet, so export it as a custom ATen node
+# that ATen-aware backends such as Caffe2 can execute directly.
+@parse_args('v', 'i', 'v', 'v', 'f', 'i')
+def group_norm(g, input, num_groups, weight, bias, eps, cudnn_enabled):
+    return g.op("ATen", input, weight, bias, num_groups_i=num_groups,
+                eps_f=eps, cudnn_enabled_i=cudnn_enabled, operator_s="group_norm")
+
+
 def _generic_rnn(g, variant, input, initial_states, all_weights, has_biases,
                  num_layers, dropout, train, bidirectional, batch_first=None, batch_sizes=None):
     weights_per_layer = 4 if has_biases else 2