improve ONNX tests on torch.Linear
authorLu Fang <lufang@fb.com>
Thu, 6 Dec 2018 01:04:39 +0000 (17:04 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 6 Dec 2018 01:07:10 +0000 (17:07 -0800)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/14821

Reviewed By: zrphercule

Differential Revision: D13348773

Pulled By: houseroad

fbshipit-source-id: 611ca6e28f715e5518649c8c16f702ac3433308c

test/onnx/expect/TestOperators.test_linear.expect [new file with mode: 0644]
test/onnx/test_operators.py
test/onnx/test_pytorch_onnx_caffe2.py

diff --git a/test/onnx/expect/TestOperators.test_linear.expect b/test/onnx/expect/TestOperators.test_linear.expect
new file mode 100644 (file)
index 0000000..bb8902e
--- /dev/null
@@ -0,0 +1,105 @@
+ir_version: 3
+producer_name: "pytorch"
+producer_version: "0.4"
+graph {
+  node {
+    input: "input"
+    input: "1"
+    input: "2"
+    output: "3"
+    op_type: "Gemm"
+    attribute {
+      name: "alpha"
+      f: 1
+      type: FLOAT
+    }
+    attribute {
+      name: "beta"
+      f: 1
+      type: FLOAT
+    }
+    attribute {
+      name: "transB"
+      i: 1
+      type: INT
+    }
+  }
+  name: "torch-jit-export"
+  initializer {
+    dims: 5
+    dims: 4
+    data_type: FLOAT
+    name: "1"
+    raw_data: "\212\332\356>@\265u>p\303E\275 \320\306\274\354\201\221>\004\354\261\276\2746*>8\247)\276\340\035\224>\024\2446\276\200\211\312<\224\344,>D\356\257>\320\202\226\275\364\213\351>z\226\330\276\310\250\266\275\352F\377\276\000\250)=\244K\021>"
+  }
+  initializer {
+    dims: 5
+    data_type: FLOAT
+    name: "2"
+    raw_data: "\324BO\276@\245T>\350\377\245\275\374u\336\276&\212\304>"
+  }
+  input {
+    name: "input"
+    type {
+      tensor_type {
+        elem_type: FLOAT
+        shape {
+          dim {
+            dim_value: 3
+          }
+          dim {
+            dim_value: 4
+          }
+        }
+      }
+    }
+  }
+  input {
+    name: "1"
+    type {
+      tensor_type {
+        elem_type: FLOAT
+        shape {
+          dim {
+            dim_value: 5
+          }
+          dim {
+            dim_value: 4
+          }
+        }
+      }
+    }
+  }
+  input {
+    name: "2"
+    type {
+      tensor_type {
+        elem_type: FLOAT
+        shape {
+          dim {
+            dim_value: 5
+          }
+        }
+      }
+    }
+  }
+  output {
+    name: "3"
+    type {
+      tensor_type {
+        elem_type: FLOAT
+        shape {
+          dim {
+            dim_value: 3
+          }
+          dim {
+            dim_value: 5
+          }
+        }
+      }
+    }
+  }
+}
+opset_import {
+  version: 9
+}
index 91b51e5..c25d845 100644 (file)
@@ -467,6 +467,10 @@ class TestOperators(TestCase):
         x = torch.randn(1, 2, 3, 4)
         self.assertONNX(torch.nn.RReLU(), x)
 
+    def test_linear(self):
+        x = torch.randn(3, 4)
+        self.assertONNX(torch.nn.Linear(4, 5, bias=True), x)
+
 
 if __name__ == '__main__':
     no_onnx_dep_flag = '--no-onnx'
index f993e23..0e32f94 100644 (file)
@@ -189,8 +189,22 @@ class TestCaffe2Backend(unittest.TestCase):
                                 use_gpu=use_gpu_, example_outputs=example_outputs)
 
     def test_linear(self):
-        model = nn.Linear(1, 1)
-        input = torch.randn(1, 1, requires_grad=True)
+        class MyModel(torch.nn.Module):
+            def __init__(self):
+                super(MyModel, self).__init__()
+                self.many_fc = nn.Sequential(
+                    nn.Linear(4, 5, bias=True),
+                    nn.ReLU(inplace=True),
+                    nn.Linear(5, 6, bias=True),
+                    nn.ReLU(inplace=True),
+                    nn.Linear(6, 7, bias=True),
+                )
+
+            def forward(self, input):
+                return self.many_fc(input)
+
+        model = MyModel()
+        input = torch.randn(3, 4, requires_grad=True)
         self.run_model_test(model, train=False, batch_size=0, input=input)
 
     def test_lstm_cell(self):