Add LogSigmoid support in ONNX symbolic (#14830)
authorzrphercule <zrphercule@gmail.com>
Thu, 6 Dec 2018 22:04:44 +0000 (14:04 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 6 Dec 2018 22:17:33 +0000 (14:17 -0800)
Summary:
Add LogSigmoid:

torch.LogSigmoid(x) = onnx.Log(onnx.Sigmoid(x))
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14830

Differential Revision: D13353891

Pulled By: zrphercule

fbshipit-source-id: bf456170b9e6c4edad07b3333cd5797f8e0fa97f

test/onnx/expect/TestOperators.test_log_sigmoid.expect [new file with mode: 0644]
test/onnx/test_operators.py
torch/onnx/symbolic.py

diff --git a/test/onnx/expect/TestOperators.test_log_sigmoid.expect b/test/onnx/expect/TestOperators.test_log_sigmoid.expect
new file mode 100644 (file)
index 0000000..d09c473
--- /dev/null
@@ -0,0 +1,63 @@
+ir_version: 3
+producer_name: "pytorch"
+producer_version: "0.4"
+graph {
+  node {
+    input: "0"
+    output: "1"
+    op_type: "Sigmoid"
+  }
+  node {
+    input: "1"
+    output: "2"
+    op_type: "Log"
+  }
+  name: "torch-jit-export"
+  input {
+    name: "0"
+    type {
+      tensor_type {
+        elem_type: FLOAT
+        shape {
+          dim {
+            dim_value: 1
+          }
+          dim {
+            dim_value: 2
+          }
+          dim {
+            dim_value: 3
+          }
+          dim {
+            dim_value: 4
+          }
+        }
+      }
+    }
+  }
+  output {
+    name: "2"
+    type {
+      tensor_type {
+        elem_type: FLOAT
+        shape {
+          dim {
+            dim_value: 1
+          }
+          dim {
+            dim_value: 2
+          }
+          dim {
+            dim_value: 3
+          }
+          dim {
+            dim_value: 4
+          }
+        }
+      }
+    }
+  }
+}
+opset_import {
+  version: 9
+}
index c25d845..9a27011 100644 (file)
@@ -467,6 +467,10 @@ class TestOperators(TestCase):
         x = torch.randn(1, 2, 3, 4)
         self.assertONNX(torch.nn.RReLU(), x)
 
+    def test_log_sigmoid(self):
+        x = torch.randn(1, 2, 3, 4)
+        self.assertONNX(torch.nn.LogSigmoid(), x)
+
     def test_linear(self):
         x = torch.randn(3, 4)
         self.assertONNX(torch.nn.Linear(4, 5, bias=True), x)
index 745d37e..aace3b6 100644 (file)
@@ -1355,3 +1355,9 @@ def randn(g, *shapes):
 def rrelu(g, input, lower, upper, training, generator):
     p = g.op('RandomUniformLike', input, high_f=upper, low_f=lower)
     return g.op('PRelu', input, p)
+
+
+@parse_args('v')
+def log_sigmoid(g, input):
+    p = g.op('Sigmoid', input)
+    return g.op('Log', p)