Add torch.nn.RReLU support in symbolic (#14781)
author zrphercule <zrphercule@gmail.com>
Wed, 5 Dec 2018 20:59:44 +0000 (12:59 -0800)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 5 Dec 2018 21:10:07 +0000 (13:10 -0800)
Summary:
This adds support for exporting torch.nn.RReLU to ONNX.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14781
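
For reference, a minimal sketch of how the new export path can be exercised (the output file name and tensor shape below are illustrative, not taken from this PR):

    import torch
    import torch.nn as nn

    # RReLU with the default slope range [1/8, 1/3]; the exporter lowers it to
    # a RandomUniformLike node (sampling the slope) followed by PRelu.
    model = nn.RReLU(lower=0.125, upper=1.0 / 3.0)
    x = torch.randn(1, 2, 3, 4)
    torch.onnx.export(model, x, "rrelu.onnx")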

Reviewed By: houseroad

Differential Revision: D13343872

Pulled By: zrphercule

fbshipit-source-id: 1e96b957de4fc2f5ba3959d42329807975419ae3

test/onnx/expect/TestOperators.test_rrelu.expect [new file with mode: 0644]
test/onnx/test_operators.py
torch/onnx/symbolic.py

diff --git a/test/onnx/expect/TestOperators.test_rrelu.expect b/test/onnx/expect/TestOperators.test_rrelu.expect
new file mode 100644 (file)
index 0000000..f0fe234
--- /dev/null
@@ -0,0 +1,74 @@
+ir_version: 3
+producer_name: "pytorch"
+producer_version: "0.4"
+graph {
+  node {
+    input: "input"
+    output: "1"
+    op_type: "RandomUniformLike"
+    attribute {
+      name: "high"
+      f: 0.333333343
+      type: FLOAT
+    }
+    attribute {
+      name: "low"
+      f: 0.125
+      type: FLOAT
+    }
+  }
+  node {
+    input: "input"
+    input: "1"
+    output: "2"
+    op_type: "PRelu"
+  }
+  name: "torch-jit-export"
+  input {
+    name: "input"
+    type {
+      tensor_type {
+        elem_type: FLOAT
+        shape {
+          dim {
+            dim_value: 1
+          }
+          dim {
+            dim_value: 2
+          }
+          dim {
+            dim_value: 3
+          }
+          dim {
+            dim_value: 4
+          }
+        }
+      }
+    }
+  }
+  output {
+    name: "2"
+    type {
+      tensor_type {
+        elem_type: FLOAT
+        shape {
+          dim {
+            dim_value: 1
+          }
+          dim {
+            dim_value: 2
+          }
+          dim {
+            dim_value: 3
+          }
+          dim {
+            dim_value: 4
+          }
+        }
+      }
+    }
+  }
+}
+opset_import {
+  version: 9
+}
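
The expect file above reflects RReLU's forward rule: f(x) = x for x >= 0 and f(x) = a * x for x < 0, with a sampled uniformly from [lower, upper]. A rough reference sketch of that sampled-slope semantics in plain PyTorch (the function name is illustrative):

    import torch

    def rrelu_reference(x, lower=0.125, upper=1.0 / 3.0):
        # Sample a slope per element from U(lower, upper), mirroring the
        # RandomUniformLike node, then apply it to negative values only,
        # mirroring PRelu.
        a = torch.empty_like(x).uniform_(lower, upper)
        return torch.where(x >= 0, x, a * x)
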
diff --git a/test/onnx/test_operators.py b/test/onnx/test_operators.py
index 0221699..91b51e5 100644 (file)
@@ -463,6 +463,11 @@ class TestOperators(TestCase):
         x = torch.randn(1, 2, 3, 4)
         self.assertONNX(lambda x: torch.randn(1, 2, 3, 4) + x, x)
 
+    def test_rrelu(self):
+        x = torch.randn(1, 2, 3, 4)
+        self.assertONNX(torch.nn.RReLU(), x)
+
+
 if __name__ == '__main__':
     no_onnx_dep_flag = '--no-onnx'
     _onnx_dep = no_onnx_dep_flag not in common.UNITTEST_ARGS
diff --git a/torch/onnx/symbolic.py b/torch/onnx/symbolic.py
index ba79a64..745d37e 100644 (file)
@@ -37,9 +37,10 @@ _sum = sum
 
 
 def _parse_arg(value, desc):
+    if desc == 'none':
+        return value
     if desc == 'v' or not _is_value(value):
         return value
-
     if value.node().kind() != 'onnx::Constant':
         raise RuntimeError("ONNX symbolic expected a constant value in the trace")
     tval = value.node()['value']
@@ -1348,3 +1349,9 @@ def randn(g, *shapes):
     shapes_list = list(shapes)
     shape = _maybe_get_const(shapes_list[0], "is")
     return g.op('RandomNormal', shape_i=shape)
+
+
+@parse_args('v', 'f', 'f', 'i', 'none')
+def rrelu(g, input, lower, upper, training, generator):
+    p = g.op('RandomUniformLike', input, high_f=upper, low_f=lower)
+    return g.op('PRelu', input, p)
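
One quick way to sanity-check the lowering produced by the new symbolic, using the onnx Python package (the in-memory buffer and input shape here are illustrative):

    import io
    import onnx
    import torch
    import torch.nn as nn

    # Export RReLU and inspect the resulting graph without touching disk.
    f = io.BytesIO()
    torch.onnx.export(nn.RReLU(), torch.randn(1, 2, 3, 4), f)

    model = onnx.load_from_string(f.getvalue())
    ops = [node.op_type for node in model.graph.node]
    # Matches the expect file: the slope is sampled with RandomUniformLike
    # and applied with PRelu.
    assert ops == ['RandomUniformLike', 'PRelu']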