add 'abs' builtin
authorMichael Kösel <michaelkoesel@gmx.de>
Wed, 3 Apr 2019 17:11:33 +0000 (10:11 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 3 Apr 2019 19:47:13 +0000 (12:47 -0700)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18502

Differential Revision: D14750173

Pulled By: eellison

fbshipit-source-id: 359cf08938ada442ca1a3b3ea14022ce10229499

aten/src/ATen/core/interned_strings.h
test/test_jit.py
torch/csrc/jit/register_prim_ops.cpp
torch/csrc/jit/script/compiler.cpp

index db61cb0..9828b38 100644 (file)
@@ -74,6 +74,7 @@ namespace c10 {
   _(prim, MMBatchSide)             \
   _(prim, min)                     \
   _(prim, max)                     \
+  _(prim, abs)                     \
   _(prim, rangelist)               \
   _(aten, _grad_sum_to_size)       \
   _(aten, _ncf_unsqueeze)          \
index fef1c79..61331ca 100644 (file)
@@ -5031,6 +5031,26 @@ a")
                 code = funcs_template.format(func=func, scalar1=scalar1, scalar2=scalar2)
                 run_test(code)
 
+    def test_number_abs(self):
+        def func1(x):
+            # type: (float) -> float
+            return abs(x)
+
+        def func2(x):
+            # type: (int) -> int
+            return abs(x)
+
+        def func3(x):
+            return abs(x)
+
+        self.checkScript(func1, (-3.14,))
+        self.checkScript(func1, (3.14,))
+        self.checkScript(func2, (-10,))
+        self.checkScript(func2, (10,))
+        self.checkScript(func3, (torch.tensor([-5, -10, -20]),))
+        self.checkScript(func3, (torch.tensor([5, 10, 20]),))
+        self.checkScript(func3, (torch.tensor([-5, 10, -20]),))
+
     def test_number_div(self):
         self.checkScript(div_int_future, (), optimize=True)
         self.checkScript(div_float_future, (), optimize=True)
index 3337bdb..5b944b8 100644 (file)
@@ -19,6 +19,7 @@
 #include <c10/util/SmallVector.h>
 
 #include <algorithm>
+#include <cmath>
 #include <exception>
 #include <iostream>
 #include <limits>
@@ -1842,6 +1843,31 @@ RegisterOperators reg2({
     DEFINE_INT_OP(aten::__or__, a | b),
     DEFINE_INT_OP(aten::__xor__, a ^ b),
 
+    Operator(
+        "prim::abs(int x) -> int",
+        [](Stack& stack) {
+          int64_t x;
+          pop(stack, x);
+          push(stack, std::abs(x));
+          return 0;
+        }),
+    Operator(
+        "prim::abs(float x) -> float",
+        [](Stack& stack) {
+          float x;
+          pop(stack, x);
+          push(stack, std::abs(x));
+          return 0;
+        }),
+    Operator(
+        "prim::abs(Tensor x) -> Tensor",
+        [](Stack& stack) {
+          at::Tensor x;
+          pop(stack, x);
+          push(stack, x.abs());
+          return 0;
+        }),
+
     // NB: This is the python truediv operation
     Operator(
         "aten::div(int a, int b) -> float",
index 997d2df..b52b995 100644 (file)
@@ -402,6 +402,7 @@ struct Environment {
           {"hash", std::make_shared<BuiltinFunction>(aten::hash, at::nullopt)},
           {"min", std::make_shared<BuiltinFunction>(prim::min, at::nullopt)},
           {"max", std::make_shared<BuiltinFunction>(prim::max, at::nullopt)},
+          {"abs", std::make_shared<BuiltinFunction>(prim::abs, at::nullopt)},
           {"list", std::make_shared<BuiltinFunction>(aten::list, at::nullopt)},
           {"ord", std::make_shared<BuiltinFunction>(aten::ord, at::nullopt)},
           {"rangelist", std::make_shared<BuiltinFunction>(prim::rangelist, at::nullopt)},