From cad7a4b0eab0001a98ef15a787c841d52e04652c Mon Sep 17 00:00:00 2001 From: Raghavan Raman Date: Fri, 10 Sep 2021 12:35:24 -0700 Subject: [PATCH] [nnc] Added an implementation of sign op (#64033) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/64033 Test Plan: Imported from OSS Reviewed By: bertmaher Differential Revision: D30579197 Pulled By: navahgar fbshipit-source-id: f9f7fa7f2ffa109cf4e441eb1af821b8b891d4d3 --- test/cpp/tensorexpr/test_kernel.cpp | 78 ++++++++++++++++++++++++ test/test_jit_fuser_te.py | 1 + tools/build_variables.bzl | 1 + torch/csrc/jit/tensorexpr/kernel.cpp | 80 +++++++++++++------------ torch/csrc/jit/tensorexpr/kernel.h | 4 ++ torch/csrc/jit/tensorexpr/operators/operators.h | 1 + torch/csrc/jit/tensorexpr/operators/unary.cpp | 26 ++++++++ torch/csrc/jit/tensorexpr/operators/unary.h | 15 +++++ 8 files changed, 168 insertions(+), 38 deletions(-) create mode 100644 torch/csrc/jit/tensorexpr/operators/unary.cpp create mode 100644 torch/csrc/jit/tensorexpr/operators/unary.h diff --git a/test/cpp/tensorexpr/test_kernel.cpp b/test/cpp/tensorexpr/test_kernel.cpp index 0bcef07..6c6f47e 100644 --- a/test/cpp/tensorexpr/test_kernel.cpp +++ b/test/cpp/tensorexpr/test_kernel.cpp @@ -1180,6 +1180,84 @@ TEST_F(Kernel, Softmax4D) { } } +TEST_F(Kernel, SignTest) { + const auto graph_template = R"IR( + graph(%0 : ${dtype}(${size}, strides=[1], device=cpu)): + %2 : ${dtype}(${size}, strides=[1]) = aten::sign(%0) + return (%2))IR"; + + auto run_test = [](const std::string& graph_string, const at::Tensor& input) { + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + TensorExprKernel k(graph); + StmtPtr s = k.getCodeGenStmt(); + + std::vector inputs = {input}; + std::vector stack = fmap(inputs); + k.run(stack); + auto o = stack[0].toTensor(); + auto ref = at::sign(input); + ASSERT_TRUE(at::allclose(o, ref)); + }; + auto common_options = at::TensorOptions() + .layout(at::kStrided) + .device(at::kCPU) + .requires_grad(false); + int default_input_size = 100; + for (auto scalar_type : {ScalarType::Float, ScalarType::Double}) { + at::Tensor corner_case_inputs; + TemplateEnv env; + auto options = common_options; + switch (scalar_type) { + case ScalarType::Float: { + env.s("dtype", "Float"); + options = options.dtype(at::kFloat); + std::vector input_float = { + 0.0f, + -0.0f, + std::numeric_limits::infinity(), + -std::numeric_limits::infinity(), + std::nanf("1"), + -std::nanf("1")}; + corner_case_inputs = at::from_blob( + input_float.data(), + {static_cast(input_float.size())}, + options); + auto rand_input = at::rand({default_input_size}, options); + auto input = at::cat({rand_input, corner_case_inputs}); + env.d("size", at::numel(input)); + const auto graph_string = format(graph_template, env); + run_test(graph_string, input); + break; + } + case ScalarType::Double: { + env.s("dtype", "Double"); + options = options.dtype(at::kDouble); + std::vector input_double = { + 0.0, + -0.0, + std::numeric_limits::infinity(), + -std::numeric_limits::infinity(), + std::nan("1"), + -std::nan("1")}; + corner_case_inputs = at::from_blob( + input_double.data(), + {static_cast(input_double.size())}, + options); + auto rand_input = at::rand({default_input_size}, options); + auto input = at::cat({rand_input, corner_case_inputs}); + env.d("size", at::numel(input)); + const auto graph_string = format(graph_template, env); + run_test(graph_string, input); + break; + } + default: + throw unsupported_dtype(); + } + } +} + TEST_F(Kernel, InlineProducerIntoReduction) { // Inline producer (mul) into reduction (sum). const auto graph_string = R"IR( diff --git a/test/test_jit_fuser_te.py b/test/test_jit_fuser_te.py index b7830f9..a45e9a6 100644 --- a/test/test_jit_fuser_te.py +++ b/test/test_jit_fuser_te.py @@ -2030,6 +2030,7 @@ works_list = [ 'round', 'rsqrt', 'sigmoid', + 'sign', 'sin', 'sinh', 'sqrt', diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl index c473157..ee3ae5a 100644 --- a/tools/build_variables.bzl +++ b/tools/build_variables.bzl @@ -310,6 +310,7 @@ core_sources_full_mobile_no_backend_interface = [ "torch/csrc/jit/tensorexpr/operators/norm.cpp", "torch/csrc/jit/tensorexpr/operators/reduction.cpp", "torch/csrc/jit/tensorexpr/operators/softmax.cpp", + "torch/csrc/jit/tensorexpr/operators/unary.cpp", "torch/csrc/jit/tensorexpr/reduction.cpp", "torch/csrc/jit/tensorexpr/registerizer.cpp", "torch/csrc/jit/tensorexpr/tensor.cpp", diff --git a/torch/csrc/jit/tensorexpr/kernel.cpp b/torch/csrc/jit/tensorexpr/kernel.cpp index cfb27ee..15ef427 100644 --- a/torch/csrc/jit/tensorexpr/kernel.cpp +++ b/torch/csrc/jit/tensorexpr/kernel.cpp @@ -41,25 +41,6 @@ static bool checkTypes(const ScalarType highType, const int typeConstraints) { return false; } -static ExprHandle promoteToDtype(ExprHandle e, ScalarType dt) { - if (e.dtype().scalar_type() == dt) { - return e; - } - - switch (dt) { -// NOLINTNEXTLINE -#define TYPE_CASE(Type, Name) \ - case ScalarType::Name: \ - e = cast(e); \ - break; - AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE); -#undef TYPE_CASE - default: - throw unsupported_dtype(); - } - return e; -} - } // namespace namespace torch { @@ -79,6 +60,25 @@ std::string buildErrorMessage(const std::string& s) { return s + ". " + generic_error_message; } +ExprHandle promoteToDtype(ExprHandle e, ScalarType dt) { + if (e.dtype().scalar_type() == dt) { + return e; + } + + switch (dt) { +// NOLINTNEXTLINE +#define TYPE_CASE(Type, Name) \ + case ScalarType::Name: \ + e = cast(e); \ + break; + AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE); +#undef TYPE_CASE + default: + throw unsupported_dtype(); + } + return e; +} + static int te_cuda_pointwise_loop_levels = -1; static int te_cuda_pointwise_block_count = -1; static int te_cuda_pointwise_block_size = -1; @@ -511,6 +511,24 @@ void promoteInputs(std::vector& inputs, const int typeConstraints) { } } +ExprHandle promoteIntegerToDefaultType(const ExprHandle& e) { + auto scalarType = static_cast(e.dtype().scalar_type()); + if (!c10::isIntegralType(scalarType, /*includeBool*/ true)) { + return e; + } + + auto defaultType = c10::typeMetaToScalarType(c10::get_default_dtype()); + + // We intend to promote Integers to floating-point types + TORCH_INTERNAL_ASSERT( + !c10::isIntegralType(defaultType, /*includeBool*/ true)); + + return Cast::make( + Dtype( + static_cast(defaultType), e.dtype().lanes()), + e); +} + ExprHandle demoteOutput( const ExprHandle& e, const c10::optional type) { @@ -746,6 +764,7 @@ std::vector TensorExprKernel::inferSizesForValue( case aten::lgamma: case aten::type_as: case aten::masked_fill: + case aten::sign: return sizesForValue(v->node()->input(0)); case aten::sub: @@ -880,25 +899,6 @@ std::vector TensorExprKernel::inferSizesForValue( } } -ExprHandle promoteIntegerToDefaultType(const ExprHandle& e) { - auto scalarType = static_cast(e.dtype().scalar_type()); - if (!c10::isIntegralType(scalarType, /*includeBool*/ true)) { - return e; - } - - auto defaultType = c10::typeMetaToScalarType(c10::get_default_dtype()); - - // We intend to promote Integers to floating-point types - TORCH_INTERNAL_ASSERT( - !c10::isIntegralType(defaultType, /*includeBool*/ true), - buildErrorMessage("Non-integer type")); - - return Cast::make( - Dtype( - static_cast(defaultType), e.dtype().lanes()), - e); -} - ExprHandle promoteHalfToFloat(const ExprHandle& e) { auto scalarType = static_cast(e.dtype().scalar_type()); auto floatType = static_cast(tensorexpr::ScalarType::Float); @@ -2128,6 +2128,10 @@ Tensor tensorexpr::computeOperandValue( kIntegralTypes | kFloatingPointTypes | kBoolType); } break; + case aten::sign: { + return computeSign(inputs, outputShape); + } break; + case aten::ceil: { return computeOneOperand( "aten_ceil", diff --git a/torch/csrc/jit/tensorexpr/kernel.h b/torch/csrc/jit/tensorexpr/kernel.h index 00faa87..803dc4a 100644 --- a/torch/csrc/jit/tensorexpr/kernel.h +++ b/torch/csrc/jit/tensorexpr/kernel.h @@ -76,6 +76,10 @@ void promoteInputs( std::vector& inputs, const int typeConstraints = kAllTypes); +ExprHandle promoteToDtype(ExprHandle e, ScalarType dt); + +ExprHandle promoteIntegerToDefaultType(const ExprHandle& e); + ExprHandle demoteOutput( const ExprHandle& e, const c10::optional type); diff --git a/torch/csrc/jit/tensorexpr/operators/operators.h b/torch/csrc/jit/tensorexpr/operators/operators.h index d94c958..8580ac1 100644 --- a/torch/csrc/jit/tensorexpr/operators/operators.h +++ b/torch/csrc/jit/tensorexpr/operators/operators.h @@ -5,3 +5,4 @@ #include #include #include +#include diff --git a/torch/csrc/jit/tensorexpr/operators/unary.cpp b/torch/csrc/jit/tensorexpr/operators/unary.cpp new file mode 100644 index 0000000..815ed45 --- /dev/null +++ b/torch/csrc/jit/tensorexpr/operators/unary.cpp @@ -0,0 +1,26 @@ +#include + +namespace torch { +namespace jit { +namespace tensorexpr { + +using namespace torch::jit::tensorexpr; + +Tensor computeSign( + const std::vector& inputValues, + const std::vector& outputShape) { + return Compute( + "aten_sign", c10::fmap(outputShape), [&](ParameterList& axes) { + std::vector indices(axes.begin(), axes.end()); + std::vector inputs = { + tensorOrConstant(inputValues[0], indices)}; + auto inp = inputs[0]; + auto zero = ExprHandle(immLike(inp, 0.0f)); + auto res = (zero < inp) - (inp < zero); + return promoteToDtype(res, inp.dtype().scalar_type()); + }); +} + +} // namespace tensorexpr +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/tensorexpr/operators/unary.h b/torch/csrc/jit/tensorexpr/operators/unary.h new file mode 100644 index 0000000..f5a0893 --- /dev/null +++ b/torch/csrc/jit/tensorexpr/operators/unary.h @@ -0,0 +1,15 @@ +#pragma once + +#include + +namespace torch { +namespace jit { +namespace tensorexpr { + +TORCH_API Tensor computeSign( + const std::vector& inputs, + const std::vector& outputShape); + +} // namespace tensorexpr +} // namespace jit +} // namespace torch -- 2.7.4