math module support (#19115)
authorAlexandr Morev <hikjik@yandex.ru>
Tue, 16 Apr 2019 17:19:04 +0000 (10:19 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 16 Apr 2019 17:44:07 +0000 (10:44 -0700)
Summary:
This PR refer to issue [#19026](https://github.com/pytorch/pytorch/issues/19026)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19115

Differential Revision: D14936053

Pulled By: driazati

fbshipit-source-id: 68d5f33ced085fcb8c10ff953bc7e99df055eccc

test/test_jit.py
torch/csrc/jit/register_prim_ops.cpp
torch/jit/__init__.py

index ff4a9b3..429cb26 100644 (file)
@@ -5159,10 +5159,76 @@ a")
 
     def test_math_ops(self):
 
-        def test_floor():
-            return math.floor(1.5)
+        def test_floor(x):
+            # type: (float) -> float
+            return math.floor(x)
+
+        def test_ceil(x):
+            # type: (float) -> float
+            return math.ceil(x)
+
+        def test_log_int(x):
+            # type: (int) -> float
+            return math.log(x)
+
+        def test_log_float(x):
+            # type: (float) -> float
+            return math.log(x)
+
+        def test_log1p_int(x):
+            # type: (int) -> float
+            return math.log1p(x)
+
+        def test_log1p_float(x):
+            # type: (float) -> float
+            return math.log1p(x)
+
+        def test_log10_int(x):
+            # type: (int) -> float
+            return math.log10(x)
 
-        self.checkScript(test_floor, ())
+        def test_log10_float(x):
+            # type: (float) -> float
+            return math.log10(x)
+
+        def test_exp_int(x):
+            # type: (int) -> float
+            return math.exp(x)
+
+        def test_exp_float(x):
+            # type: (float) -> float
+            return math.exp(x)
+
+        def test_sqrt_int(x):
+            # type: (int) -> float
+            return math.sqrt(x)
+
+        def test_sqrt_float(x):
+            # type: (float) -> float
+            return math.sqrt(x)
+
+        def test_pow_float(x, y):
+            # type: (float, float) -> float
+            return math.pow(x, y)
+
+        def test_pow_int(x, y):
+            # type: (float, int) -> float
+            return math.pow(x, y)
+
+        self.checkScript(test_floor, (1.5,))
+        self.checkScript(test_ceil, (1.5,))
+        self.checkScript(test_log_int, (2,))
+        self.checkScript(test_log_float, (2.0,))
+        self.checkScript(test_log1p_int, (1,))
+        self.checkScript(test_log1p_float, (1.0,))
+        self.checkScript(test_log10_int, (2,))
+        self.checkScript(test_log10_float, (2.0,))
+        self.checkScript(test_exp_int, (2,))
+        self.checkScript(test_exp_float, (2.0,))
+        self.checkScript(test_sqrt_int, (2,))
+        self.checkScript(test_sqrt_float, (2.0,))
+        self.checkScript(test_pow_float, (2.0, 2.0))
+        self.checkScript(test_pow_int, (2.0, 2))
 
     def test_if_nest_while(self):
         def func(a, b):
index 31142ea..5ce746d 100644 (file)
@@ -1888,11 +1888,123 @@ RegisterOperators reg2({
         }),
 
     Operator(
-        "aten::floor(float a) -> int",
+        "aten::pow(float a, float b) -> float",
+        [](Stack& stack) {
+          double a, b;
+          pop(stack, a, b);
+          push(stack, std::pow(a, b));
+          return 0;
+        }),
+    Operator(
+        "aten::pow(float a, int b) -> float",
+        [](Stack& stack) {
+          double a;
+          int b;
+          pop(stack, a, b);
+          push(stack, std::pow(a, b));
+          return 0;
+        }),
+
+    Operator(
+        "aten::floor(float a) -> float",
+        [](Stack& stack) {
+          double a;
+          pop(stack, a);
+          push(stack, std::floor(a));
+          return 0;
+        }),
+
+    Operator(
+        "aten::ceil(float a) -> float",
+        [](Stack& stack) {
+          double a;
+          pop(stack, a);
+          push(stack, std::ceil(a));
+          return 0;
+        }),
+
+    Operator(
+        "aten::log(float a) -> float",
+        [](Stack& stack) {
+          double a;
+          pop(stack, a);
+          push(stack, std::log(a));
+          return 0;
+        }),
+    Operator(
+        "aten::log(int a) -> float",
+        [](Stack& stack) {
+          int64_t a;
+          pop(stack, a);
+          push(stack, std::log(a));
+          return 0;
+        }),
+
+    Operator(
+        "aten::log1p(float a) -> float",
+        [](Stack& stack) {
+          double a;
+          pop(stack, a);
+          push(stack, std::log1p(a));
+          return 0;
+        }),
+    Operator(
+        "aten::log1p(int a) -> float",
+        [](Stack& stack) {
+          int64_t a;
+          pop(stack, a);
+          push(stack, std::log1p(a));
+          return 0;
+        }),
+
+    Operator(
+        "aten::log10(float a) -> float",
+        [](Stack& stack) {
+          double a;
+          pop(stack, a);
+          push(stack, std::log10(a));
+          return 0;
+        }),
+    Operator(
+        "aten::log10(int a) -> float",
+        [](Stack& stack) {
+          int64_t a;
+          pop(stack, a);
+          push(stack, std::log10(a));
+          return 0;
+        }),
+
+    Operator(
+        "aten::exp(float a) -> float",
+        [](Stack& stack) {
+          double a;
+          pop(stack, a);
+          push(stack, std::exp(a));
+          return 0;
+        }),
+    Operator(
+        "aten::exp(int a) -> float",
+        [](Stack& stack) {
+          int64_t a;
+          pop(stack, a);
+          push(stack, std::exp(a));
+          return 0;
+        }),
+
+    Operator(
+        "aten::sqrt(float a) -> float",
         [](Stack& stack) {
           double a;
           pop(stack, a);
-          push(stack, static_cast<int64_t>(std::floor(a)));
+          push(stack, std::sqrt(a));
+          return 0;
+        }),
+    Operator(
+        "aten::sqrt(int a) -> float",
+        [](Stack& stack) {
+          int64_t a;
+          pop(stack, a);
+          push(stack, std::sqrt(a));
           return 0;
         }),
 
@@ -1903,7 +2015,7 @@ RegisterOperators reg2({
     DEFINE_COMPARISON_OP(aten::le, a <= b),
     DEFINE_COMPARISON_OP(aten::ge, a >= b),
 
-    DEFINE_BOOL_OP(aten::__and__, a&& b),
+    DEFINE_BOOL_OP(aten::__and__, a && b),
     DEFINE_BOOL_OP(aten::__or__, a || b),
     DEFINE_BOOL_OP(aten::__xor__, a != b),
 
index 5989dbd..5a6ffcf 100644 (file)
@@ -1493,6 +1493,13 @@ def _get_builtin_table():
     _builtin_table[id(torch.nn.functional._no_grad_embedding_renorm_)] = "aten::_no_grad_embedding_renorm_"
 
     _builtin_table[id(math.floor)] = "aten::floor"
+    _builtin_table[id(math.ceil)] = "aten::ceil"
+    _builtin_table[id(math.log)] = "aten::log"
+    _builtin_table[id(math.log1p)] = "aten::log1p"
+    _builtin_table[id(math.log10)] = "aten::log10"
+    _builtin_table[id(math.exp)] = "aten::exp"
+    _builtin_table[id(math.sqrt)] = "aten::sqrt"
+    _builtin_table[id(math.pow)] = "aten::pow"
     _builtin_table[id(torch.nn.functional.interpolate)] = "aten::__interpolate"
     _builtin_table[id(torch.nn.functional.upsample_nearest)] = "aten::__upsample_nearest"
     _builtin_table[id(torch.nn.functional.upsample)] = "aten::__upsample"