From da4ff17eee9f27bb5582418fb168c57c7e6b3e67 Mon Sep 17 00:00:00 2001 From: Alexandr Morev Date: Tue, 16 Apr 2019 10:19:04 -0700 Subject: [PATCH] math module support (#19115) Summary: This PR refer to issue [#19026](https://github.com/pytorch/pytorch/issues/19026) Pull Request resolved: https://github.com/pytorch/pytorch/pull/19115 Differential Revision: D14936053 Pulled By: driazati fbshipit-source-id: 68d5f33ced085fcb8c10ff953bc7e99df055eccc --- test/test_jit.py | 72 ++++++++++++++++++++- torch/csrc/jit/register_prim_ops.cpp | 118 ++++++++++++++++++++++++++++++++++- torch/jit/__init__.py | 7 +++ 3 files changed, 191 insertions(+), 6 deletions(-) diff --git a/test/test_jit.py b/test/test_jit.py index ff4a9b3..429cb26 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -5159,10 +5159,76 @@ a") def test_math_ops(self): - def test_floor(): - return math.floor(1.5) + def test_floor(x): + # type: (float) -> float + return math.floor(x) + + def test_ceil(x): + # type: (float) -> float + return math.ceil(x) + + def test_log_int(x): + # type: (int) -> float + return math.log(x) + + def test_log_float(x): + # type: (float) -> float + return math.log(x) + + def test_log1p_int(x): + # type: (int) -> float + return math.log1p(x) + + def test_log1p_float(x): + # type: (float) -> float + return math.log1p(x) + + def test_log10_int(x): + # type: (int) -> float + return math.log10(x) - self.checkScript(test_floor, ()) + def test_log10_float(x): + # type: (float) -> float + return math.log10(x) + + def test_exp_int(x): + # type: (int) -> float + return math.exp(x) + + def test_exp_float(x): + # type: (float) -> float + return math.exp(x) + + def test_sqrt_int(x): + # type: (int) -> float + return math.sqrt(x) + + def test_sqrt_float(x): + # type: (float) -> float + return math.sqrt(x) + + def test_pow_float(x, y): + # type: (float, float) -> float + return math.pow(x, y) + + def test_pow_int(x, y): + # type: (float, int) -> float + return math.pow(x, y) + + self.checkScript(test_floor, (1.5,)) + self.checkScript(test_ceil, (1.5,)) + self.checkScript(test_log_int, (2,)) + self.checkScript(test_log_float, (2.0,)) + self.checkScript(test_log1p_int, (1,)) + self.checkScript(test_log1p_float, (1.0,)) + self.checkScript(test_log10_int, (2,)) + self.checkScript(test_log10_float, (2.0,)) + self.checkScript(test_exp_int, (2,)) + self.checkScript(test_exp_float, (2.0,)) + self.checkScript(test_sqrt_int, (2,)) + self.checkScript(test_sqrt_float, (2.0,)) + self.checkScript(test_pow_float, (2.0, 2.0)) + self.checkScript(test_pow_int, (2.0, 2)) def test_if_nest_while(self): def func(a, b): diff --git a/torch/csrc/jit/register_prim_ops.cpp b/torch/csrc/jit/register_prim_ops.cpp index 31142ea..5ce746d 100644 --- a/torch/csrc/jit/register_prim_ops.cpp +++ b/torch/csrc/jit/register_prim_ops.cpp @@ -1888,11 +1888,123 @@ RegisterOperators reg2({ }), Operator( - "aten::floor(float a) -> int", + "aten::pow(float a, float b) -> float", + [](Stack& stack) { + double a, b; + pop(stack, a, b); + push(stack, std::pow(a, b)); + return 0; + }), + Operator( + "aten::pow(float a, int b) -> float", + [](Stack& stack) { + double a; + int b; + pop(stack, a, b); + push(stack, std::pow(a, b)); + return 0; + }), + + Operator( + "aten::floor(float a) -> float", + [](Stack& stack) { + double a; + pop(stack, a); + push(stack, std::floor(a)); + return 0; + }), + + Operator( + "aten::ceil(float a) -> float", + [](Stack& stack) { + double a; + pop(stack, a); + push(stack, std::ceil(a)); + return 0; + }), + + Operator( + "aten::log(float a) -> float", + [](Stack& stack) { + double a; + pop(stack, a); + push(stack, std::log(a)); + return 0; + }), + Operator( + "aten::log(int a) -> float", + [](Stack& stack) { + int64_t a; + pop(stack, a); + push(stack, std::log(a)); + return 0; + }), + + Operator( + "aten::log1p(float a) -> float", + [](Stack& stack) { + double a; + pop(stack, a); + push(stack, std::log1p(a)); + return 0; + }), + Operator( + "aten::log1p(int a) -> float", + [](Stack& stack) { + int64_t a; + pop(stack, a); + push(stack, std::log1p(a)); + return 0; + }), + + Operator( + "aten::log10(float a) -> float", + [](Stack& stack) { + double a; + pop(stack, a); + push(stack, std::log10(a)); + return 0; + }), + Operator( + "aten::log10(int a) -> float", + [](Stack& stack) { + int64_t a; + pop(stack, a); + push(stack, std::log10(a)); + return 0; + }), + + Operator( + "aten::exp(float a) -> float", + [](Stack& stack) { + double a; + pop(stack, a); + push(stack, std::exp(a)); + return 0; + }), + Operator( + "aten::exp(int a) -> float", + [](Stack& stack) { + int64_t a; + pop(stack, a); + push(stack, std::exp(a)); + return 0; + }), + + Operator( + "aten::sqrt(float a) -> float", [](Stack& stack) { double a; pop(stack, a); - push(stack, static_cast(std::floor(a))); + push(stack, std::sqrt(a)); + return 0; + }), + Operator( + "aten::sqrt(int a) -> float", + [](Stack& stack) { + int64_t a; + pop(stack, a); + push(stack, std::sqrt(a)); return 0; }), @@ -1903,7 +2015,7 @@ RegisterOperators reg2({ DEFINE_COMPARISON_OP(aten::le, a <= b), DEFINE_COMPARISON_OP(aten::ge, a >= b), - DEFINE_BOOL_OP(aten::__and__, a&& b), + DEFINE_BOOL_OP(aten::__and__, a && b), DEFINE_BOOL_OP(aten::__or__, a || b), DEFINE_BOOL_OP(aten::__xor__, a != b), diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py index 5989dbd..5a6ffcf 100644 --- a/torch/jit/__init__.py +++ b/torch/jit/__init__.py @@ -1493,6 +1493,13 @@ def _get_builtin_table(): _builtin_table[id(torch.nn.functional._no_grad_embedding_renorm_)] = "aten::_no_grad_embedding_renorm_" _builtin_table[id(math.floor)] = "aten::floor" + _builtin_table[id(math.ceil)] = "aten::ceil" + _builtin_table[id(math.log)] = "aten::log" + _builtin_table[id(math.log1p)] = "aten::log1p" + _builtin_table[id(math.log10)] = "aten::log10" + _builtin_table[id(math.exp)] = "aten::exp" + _builtin_table[id(math.sqrt)] = "aten::sqrt" + _builtin_table[id(math.pow)] = "aten::pow" _builtin_table[id(torch.nn.functional.interpolate)] = "aten::__interpolate" _builtin_table[id(torch.nn.functional.upsample_nearest)] = "aten::__upsample_nearest" _builtin_table[id(torch.nn.functional.upsample)] = "aten::__upsample" -- 2.7.4