[nnc] Added an implementation of sign op (#64033)
authorRaghavan Raman <raghavanr@fb.com>
Fri, 10 Sep 2021 19:35:24 +0000 (12:35 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Fri, 10 Sep 2021 23:49:04 +0000 (16:49 -0700)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/64033

Test Plan: Imported from OSS

Reviewed By: bertmaher

Differential Revision: D30579197

Pulled By: navahgar

fbshipit-source-id: f9f7fa7f2ffa109cf4e441eb1af821b8b891d4d3

test/cpp/tensorexpr/test_kernel.cpp
test/test_jit_fuser_te.py
tools/build_variables.bzl
torch/csrc/jit/tensorexpr/kernel.cpp
torch/csrc/jit/tensorexpr/kernel.h
torch/csrc/jit/tensorexpr/operators/operators.h
torch/csrc/jit/tensorexpr/operators/unary.cpp [new file with mode: 0644]
torch/csrc/jit/tensorexpr/operators/unary.h [new file with mode: 0644]

index 0bcef07..6c6f47e 100644 (file)
@@ -1180,6 +1180,84 @@ TEST_F(Kernel, Softmax4D) {
   }
 }
 
+TEST_F(Kernel, SignTest) {
+  const auto graph_template = R"IR(
+      graph(%0 : ${dtype}(${size}, strides=[1], device=cpu)):
+        %2 : ${dtype}(${size}, strides=[1]) = aten::sign(%0)
+        return (%2))IR";
+
+  auto run_test = [](const std::string& graph_string, const at::Tensor& input) {
+    auto graph = std::make_shared<Graph>();
+    parseIR(graph_string, &*graph);
+
+    TensorExprKernel k(graph);
+    StmtPtr s = k.getCodeGenStmt();
+
+    std::vector<at::Tensor> inputs = {input};
+    std::vector<IValue> stack = fmap<IValue>(inputs);
+    k.run(stack);
+    auto o = stack[0].toTensor();
+    auto ref = at::sign(input);
+    ASSERT_TRUE(at::allclose(o, ref));
+  };
+  auto common_options = at::TensorOptions()
+                            .layout(at::kStrided)
+                            .device(at::kCPU)
+                            .requires_grad(false);
+  int default_input_size = 100;
+  for (auto scalar_type : {ScalarType::Float, ScalarType::Double}) {
+    at::Tensor corner_case_inputs;
+    TemplateEnv env;
+    auto options = common_options;
+    switch (scalar_type) {
+      case ScalarType::Float: {
+        env.s("dtype", "Float");
+        options = options.dtype(at::kFloat);
+        std::vector<float> input_float = {
+            0.0f,
+            -0.0f,
+            std::numeric_limits<float>::infinity(),
+            -std::numeric_limits<float>::infinity(),
+            std::nanf("1"),
+            -std::nanf("1")};
+        corner_case_inputs = at::from_blob(
+            input_float.data(),
+            {static_cast<long>(input_float.size())},
+            options);
+        auto rand_input = at::rand({default_input_size}, options);
+        auto input = at::cat({rand_input, corner_case_inputs});
+        env.d("size", at::numel(input));
+        const auto graph_string = format(graph_template, env);
+        run_test(graph_string, input);
+        break;
+      }
+      case ScalarType::Double: {
+        env.s("dtype", "Double");
+        options = options.dtype(at::kDouble);
+        std::vector<double> input_double = {
+            0.0,
+            -0.0,
+            std::numeric_limits<double>::infinity(),
+            -std::numeric_limits<double>::infinity(),
+            std::nan("1"),
+            -std::nan("1")};
+        corner_case_inputs = at::from_blob(
+            input_double.data(),
+            {static_cast<long>(input_double.size())},
+            options);
+        auto rand_input = at::rand({default_input_size}, options);
+        auto input = at::cat({rand_input, corner_case_inputs});
+        env.d("size", at::numel(input));
+        const auto graph_string = format(graph_template, env);
+        run_test(graph_string, input);
+        break;
+      }
+      default:
+        throw unsupported_dtype();
+    }
+  }
+}
+
 TEST_F(Kernel, InlineProducerIntoReduction) {
   // Inline producer (mul) into reduction (sum).
   const auto graph_string = R"IR(
index b7830f9..a45e9a6 100644 (file)
@@ -2030,6 +2030,7 @@ works_list = [
     'round',
     'rsqrt',
     'sigmoid',
+    'sign',
     'sin',
     'sinh',
     'sqrt',
index c473157..ee3ae5a 100644 (file)
@@ -310,6 +310,7 @@ core_sources_full_mobile_no_backend_interface = [
     "torch/csrc/jit/tensorexpr/operators/norm.cpp",
     "torch/csrc/jit/tensorexpr/operators/reduction.cpp",
     "torch/csrc/jit/tensorexpr/operators/softmax.cpp",
+    "torch/csrc/jit/tensorexpr/operators/unary.cpp",
     "torch/csrc/jit/tensorexpr/reduction.cpp",
     "torch/csrc/jit/tensorexpr/registerizer.cpp",
     "torch/csrc/jit/tensorexpr/tensor.cpp",
index cfb27ee..15ef427 100644 (file)
@@ -41,25 +41,6 @@ static bool checkTypes(const ScalarType highType, const int typeConstraints) {
   return false;
 }
 
-static ExprHandle promoteToDtype(ExprHandle e, ScalarType dt) {
-  if (e.dtype().scalar_type() == dt) {
-    return e;
-  }
-
-  switch (dt) {
-// NOLINTNEXTLINE
-#define TYPE_CASE(Type, Name) \
-  case ScalarType::Name:      \
-    e = cast<Type>(e);        \
-    break;
-    AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
-#undef TYPE_CASE
-    default:
-      throw unsupported_dtype();
-  }
-  return e;
-}
-
 } // namespace
 
 namespace torch {
@@ -79,6 +60,25 @@ std::string buildErrorMessage(const std::string& s) {
   return s + ". " + generic_error_message;
 }
 
+ExprHandle promoteToDtype(ExprHandle e, ScalarType dt) {
+  if (e.dtype().scalar_type() == dt) {
+    return e;
+  }
+
+  switch (dt) {
+// NOLINTNEXTLINE
+#define TYPE_CASE(Type, Name) \
+  case ScalarType::Name:      \
+    e = cast<Type>(e);        \
+    break;
+    AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
+#undef TYPE_CASE
+    default:
+      throw unsupported_dtype();
+  }
+  return e;
+}
+
 static int te_cuda_pointwise_loop_levels = -1;
 static int te_cuda_pointwise_block_count = -1;
 static int te_cuda_pointwise_block_size = -1;
@@ -511,6 +511,24 @@ void promoteInputs(std::vector<ExprHandle>& inputs, const int typeConstraints) {
   }
 }
 
+ExprHandle promoteIntegerToDefaultType(const ExprHandle& e) {
+  auto scalarType = static_cast<c10::ScalarType>(e.dtype().scalar_type());
+  if (!c10::isIntegralType(scalarType, /*includeBool*/ true)) {
+    return e;
+  }
+
+  auto defaultType = c10::typeMetaToScalarType(c10::get_default_dtype());
+
+  // We intend to promote Integers to floating-point types
+  TORCH_INTERNAL_ASSERT(
+      !c10::isIntegralType(defaultType, /*includeBool*/ true));
+
+  return Cast::make(
+      Dtype(
+          static_cast<tensorexpr::ScalarType>(defaultType), e.dtype().lanes()),
+      e);
+}
+
 ExprHandle demoteOutput(
     const ExprHandle& e,
     const c10::optional<ScalarType> type) {
@@ -746,6 +764,7 @@ std::vector<ExprHandle> TensorExprKernel::inferSizesForValue(
     case aten::lgamma:
     case aten::type_as:
     case aten::masked_fill:
+    case aten::sign:
       return sizesForValue(v->node()->input(0));
 
     case aten::sub:
@@ -880,25 +899,6 @@ std::vector<ExprHandle> TensorExprKernel::inferSizesForValue(
   }
 }
 
-ExprHandle promoteIntegerToDefaultType(const ExprHandle& e) {
-  auto scalarType = static_cast<c10::ScalarType>(e.dtype().scalar_type());
-  if (!c10::isIntegralType(scalarType, /*includeBool*/ true)) {
-    return e;
-  }
-
-  auto defaultType = c10::typeMetaToScalarType(c10::get_default_dtype());
-
-  // We intend to promote Integers to floating-point types
-  TORCH_INTERNAL_ASSERT(
-      !c10::isIntegralType(defaultType, /*includeBool*/ true),
-      buildErrorMessage("Non-integer type"));
-
-  return Cast::make(
-      Dtype(
-          static_cast<tensorexpr::ScalarType>(defaultType), e.dtype().lanes()),
-      e);
-}
-
 ExprHandle promoteHalfToFloat(const ExprHandle& e) {
   auto scalarType = static_cast<c10::ScalarType>(e.dtype().scalar_type());
   auto floatType = static_cast<c10::ScalarType>(tensorexpr::ScalarType::Float);
@@ -2128,6 +2128,10 @@ Tensor tensorexpr::computeOperandValue(
           kIntegralTypes | kFloatingPointTypes | kBoolType);
     } break;
 
+    case aten::sign: {
+      return computeSign(inputs, outputShape);
+    } break;
+
     case aten::ceil: {
       return computeOneOperand(
           "aten_ceil",
index 00faa87..803dc4a 100644 (file)
@@ -76,6 +76,10 @@ void promoteInputs(
     std::vector<ExprHandle>& inputs,
     const int typeConstraints = kAllTypes);
 
+ExprHandle promoteToDtype(ExprHandle e, ScalarType dt);
+
+ExprHandle promoteIntegerToDefaultType(const ExprHandle& e);
+
 ExprHandle demoteOutput(
     const ExprHandle& e,
     const c10::optional<ScalarType> type);
index d94c958..8580ac1 100644 (file)
@@ -5,3 +5,4 @@
 #include <torch/csrc/jit/tensorexpr/operators/norm.h>
 #include <torch/csrc/jit/tensorexpr/operators/reduction.h>
 #include <torch/csrc/jit/tensorexpr/operators/softmax.h>
+#include <torch/csrc/jit/tensorexpr/operators/unary.h>
diff --git a/torch/csrc/jit/tensorexpr/operators/unary.cpp b/torch/csrc/jit/tensorexpr/operators/unary.cpp
new file mode 100644 (file)
index 0000000..815ed45
--- /dev/null
@@ -0,0 +1,26 @@
+#include <torch/csrc/jit/tensorexpr/operators/unary.h>
+
+namespace torch {
+namespace jit {
+namespace tensorexpr {
+
+using namespace torch::jit::tensorexpr;
+
+Tensor computeSign(
+    const std::vector<ArgValue>& inputValues,
+    const std::vector<ExprHandle>& outputShape) {
+  return Compute(
+      "aten_sign", c10::fmap<DimArg>(outputShape), [&](ParameterList& axes) {
+        std::vector<ExprHandle> indices(axes.begin(), axes.end());
+        std::vector<ExprHandle> inputs = {
+            tensorOrConstant(inputValues[0], indices)};
+        auto inp = inputs[0];
+        auto zero = ExprHandle(immLike(inp, 0.0f));
+        auto res = (zero < inp) - (inp < zero);
+        return promoteToDtype(res, inp.dtype().scalar_type());
+      });
+}
+
+} // namespace tensorexpr
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/tensorexpr/operators/unary.h b/torch/csrc/jit/tensorexpr/operators/unary.h
new file mode 100644 (file)
index 0000000..f5a0893
--- /dev/null
@@ -0,0 +1,15 @@
+#pragma once
+
+#include <torch/csrc/jit/tensorexpr/kernel.h>
+
+namespace torch {
+namespace jit {
+namespace tensorexpr {
+
+TORCH_API Tensor computeSign(
+    const std::vector<ArgValue>& inputs,
+    const std::vector<ExprHandle>& outputShape);
+
+} // namespace tensorexpr
+} // namespace jit
+} // namespace torch