}
}
+// Checks that the TensorExprKernel lowering of aten::sign produces the same
+// values as eager at::sign for Float and Double inputs. Each input is a block
+// of random values followed by corner cases: +/-0, +/-infinity and NaNs.
+TEST_F(Kernel, SignTest) {
+  const auto graph_template = R"IR(
+ graph(%0 : ${dtype}(${size}, strides=[1], device=cpu)):
+ %2 : ${dtype}(${size}, strides=[1]) = aten::sign(%0)
+ return (%2))IR";
+
+  // Parses `graph_string`, compiles it with TensorExprKernel, runs it on
+  // `input` and compares the kernel output with the at::sign reference.
+  auto run_test = [](const std::string& graph_string, const at::Tensor& input) {
+    auto graph = std::make_shared<Graph>();
+    parseIR(graph_string, &*graph);
+
+    TensorExprKernel k(graph);
+    StmtPtr s = k.getCodeGenStmt();
+
+    std::vector<at::Tensor> inputs = {input};
+    std::vector<IValue> stack = fmap<IValue>(inputs);
+    k.run(stack);
+    auto o = stack[0].toTensor();
+    auto ref = at::sign(input);
+    ASSERT_TRUE(at::allclose(o, ref));
+  };
+  auto common_options = at::TensorOptions()
+                            .layout(at::kStrided)
+                            .device(at::kCPU)
+                            .requires_grad(false);
+  // Tensor sizes are 64-bit in ATen; `long` is only 32 bits wide on LLP64
+  // platforms, so sizes are spelled as int64_t throughout this test.
+  const int64_t default_input_size = 100;
+  for (auto scalar_type : {ScalarType::Float, ScalarType::Double}) {
+    at::Tensor corner_case_inputs;
+    TemplateEnv env;
+    auto options = common_options;
+    switch (scalar_type) {
+      case ScalarType::Float: {
+        env.s("dtype", "Float");
+        options = options.dtype(at::kFloat);
+        std::vector<float> input_float = {
+            0.0f,
+            -0.0f,
+            std::numeric_limits<float>::infinity(),
+            -std::numeric_limits<float>::infinity(),
+            std::nanf("1"),
+            -std::nanf("1")};
+        // from_blob only wraps the vector's storage; `input_float` stays alive
+        // until at::cat below has produced the combined input tensor.
+        corner_case_inputs = at::from_blob(
+            input_float.data(),
+            {static_cast<int64_t>(input_float.size())},
+            options);
+        auto rand_input = at::rand({default_input_size}, options);
+        auto input = at::cat({rand_input, corner_case_inputs});
+        env.d("size", at::numel(input));
+        const auto graph_string = format(graph_template, env);
+        run_test(graph_string, input);
+        break;
+      }
+      case ScalarType::Double: {
+        env.s("dtype", "Double");
+        options = options.dtype(at::kDouble);
+        std::vector<double> input_double = {
+            0.0,
+            -0.0,
+            std::numeric_limits<double>::infinity(),
+            -std::numeric_limits<double>::infinity(),
+            std::nan("1"),
+            -std::nan("1")};
+        // Same lifetime consideration as in the Float case above.
+        corner_case_inputs = at::from_blob(
+            input_double.data(),
+            {static_cast<int64_t>(input_double.size())},
+            options);
+        auto rand_input = at::rand({default_input_size}, options);
+        auto input = at::cat({rand_input, corner_case_inputs});
+        env.d("size", at::numel(input));
+        const auto graph_string = format(graph_template, env);
+        run_test(graph_string, input);
+        break;
+      }
+      default:
+        throw unsupported_dtype();
+    }
+  }
+}
+
TEST_F(Kernel, InlineProducerIntoReduction) {
// Inline producer (mul) into reduction (sum).
const auto graph_string = R"IR(
'round',
'rsqrt',
'sigmoid',
+ 'sign',
'sin',
'sinh',
'sqrt',
"torch/csrc/jit/tensorexpr/operators/norm.cpp",
"torch/csrc/jit/tensorexpr/operators/reduction.cpp",
"torch/csrc/jit/tensorexpr/operators/softmax.cpp",
+ "torch/csrc/jit/tensorexpr/operators/unary.cpp",
"torch/csrc/jit/tensorexpr/reduction.cpp",
"torch/csrc/jit/tensorexpr/registerizer.cpp",
"torch/csrc/jit/tensorexpr/tensor.cpp",
return false;
}
-static ExprHandle promoteToDtype(ExprHandle e, ScalarType dt) { // Casts `e` to scalar type `dt`; this is the file-local (static) copy that the diff removes.
- if (e.dtype().scalar_type() == dt) { // already the requested scalar type: no cast needed
- return e;
- }
-
- switch (dt) {
-// NOLINTNEXTLINE
-#define TYPE_CASE(Type, Name) \
- case ScalarType::Name: \
- e = cast<Type>(e); \
- break;
- AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE); // expands TYPE_CASE once per scalar type the macro enumerates (presumably the standard types plus Bool, Half, BFloat16 - confirm against the macro)
-#undef TYPE_CASE
- default:
- throw unsupported_dtype(); // `dt` matched none of the generated cases
- }
- return e;
-}
-
} // namespace
namespace torch {
return s + ". " + generic_error_message;
}
+// Returns `e` converted to scalar type `dt`. An expression that already has
+// that scalar type is handed back untouched; a `dt` with no generated cast
+// case raises unsupported_dtype.
+ExprHandle promoteToDtype(ExprHandle e, ScalarType dt) {
+  const bool needs_cast = e.dtype().scalar_type() != dt;
+  if (needs_cast) {
+    switch (dt) {
+// NOLINTNEXTLINE
+#define TYPE_CASE(Type, Name) \
+  case ScalarType::Name:      \
+    return cast<Type>(e);
+      AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
+#undef TYPE_CASE
+      default:
+        throw unsupported_dtype();
+    }
+  }
+  return e;
+}
+
static int te_cuda_pointwise_loop_levels = -1;
static int te_cuda_pointwise_block_count = -1;
static int te_cuda_pointwise_block_size = -1;
}
}
+// Promotes an integral expression (bool counts as integral here) to the
+// current default dtype, keeping the number of lanes. Expressions of any
+// other scalar type are returned unchanged.
+ExprHandle promoteIntegerToDefaultType(const ExprHandle& e) {
+  auto scalarType = static_cast<c10::ScalarType>(e.dtype().scalar_type());
+  if (!c10::isIntegralType(scalarType, /*includeBool*/ true)) {
+    return e;
+  }
+
+  auto defaultType = c10::typeMetaToScalarType(c10::get_default_dtype());
+
+  // We intend to promote Integers to floating-point types
+  // The diagnostic message matches the one this assertion carried before the
+  // function was moved, so a failure stays as descriptive as it used to be.
+  TORCH_INTERNAL_ASSERT(
+      !c10::isIntegralType(defaultType, /*includeBool*/ true),
+      buildErrorMessage("Non-integer type"));
+
+  return Cast::make(
+      Dtype(
+          static_cast<tensorexpr::ScalarType>(defaultType), e.dtype().lanes()),
+      e);
+}
+
ExprHandle demoteOutput(
const ExprHandle& e,
const c10::optional<ScalarType> type) {
case aten::lgamma:
case aten::type_as:
case aten::masked_fill:
+ case aten::sign:
return sizesForValue(v->node()->input(0));
case aten::sub:
}
}
-ExprHandle promoteIntegerToDefaultType(const ExprHandle& e) { // Casts integral (incl. bool) expressions to the default dtype; this is the copy the diff removes from this position.
- auto scalarType = static_cast<c10::ScalarType>(e.dtype().scalar_type());
- if (!c10::isIntegralType(scalarType, /*includeBool*/ true)) { // non-integral input passes through unchanged
- return e;
- }
-
- auto defaultType = c10::typeMetaToScalarType(c10::get_default_dtype());
-
- // We intend to promote Integers to floating-point types
- TORCH_INTERNAL_ASSERT(
- !c10::isIntegralType(defaultType, /*includeBool*/ true),
- buildErrorMessage("Non-integer type")); // message argument attached to the assertion
-
- return Cast::make( // cast to the default scalar type, preserving the lane count of `e`
- Dtype(
- static_cast<tensorexpr::ScalarType>(defaultType), e.dtype().lanes()),
- e);
-}
-
ExprHandle promoteHalfToFloat(const ExprHandle& e) {
auto scalarType = static_cast<c10::ScalarType>(e.dtype().scalar_type());
auto floatType = static_cast<c10::ScalarType>(tensorexpr::ScalarType::Float);
kIntegralTypes | kFloatingPointTypes | kBoolType);
} break;
+ case aten::sign: {
+ return computeSign(inputs, outputShape);
+ } break;
+
case aten::ceil: {
return computeOneOperand(
"aten_ceil",
std::vector<ExprHandle>& inputs,
const int typeConstraints = kAllTypes);
+ExprHandle promoteToDtype(ExprHandle e, ScalarType dt); // Casts `e` to scalar type `dt`; returns `e` unchanged if it already has that type, throws unsupported_dtype for an unhandled `dt`.
+
+ExprHandle promoteIntegerToDefaultType(const ExprHandle& e); // Casts integral (including bool) expressions to the current default dtype; other expressions are returned unchanged.
+
ExprHandle demoteOutput(
const ExprHandle& e,
const c10::optional<ScalarType> type);
#include <torch/csrc/jit/tensorexpr/operators/norm.h>
#include <torch/csrc/jit/tensorexpr/operators/reduction.h>
#include <torch/csrc/jit/tensorexpr/operators/softmax.h>
+#include <torch/csrc/jit/tensorexpr/operators/unary.h>
--- /dev/null
+#include <torch/csrc/jit/tensorexpr/operators/unary.h>
+
+namespace torch {
+namespace jit {
+namespace tensorexpr {
+
+using namespace torch::jit::tensorexpr;
+
+// Lowering for aten::sign. For every point of `outputShape` it reads the
+// first input value x, evaluates (0 < x) - (x < 0) and casts that result back
+// to the scalar type of x. The produced tensor is named "aten_sign".
+Tensor computeSign(
+    const std::vector<ArgValue>& inputValues,
+    const std::vector<ExprHandle>& outputShape) {
+  auto signAt = [&](ParameterList& axes) {
+    std::vector<ExprHandle> indices(axes.begin(), axes.end());
+    ExprHandle x = tensorOrConstant(inputValues[0], indices);
+    // Zero immediate built with immLike from the input expression.
+    ExprHandle zero(immLike(x, 0.0f));
+    auto isPositive = zero < x;
+    auto isNegative = x < zero;
+    return promoteToDtype(isPositive - isNegative, x.dtype().scalar_type());
+  };
+  return Compute("aten_sign", c10::fmap<DimArg>(outputShape), signAt);
+}
+
+} // namespace tensorexpr
+} // namespace jit
+} // namespace torch
--- /dev/null
+#pragma once
+
+#include <torch/csrc/jit/tensorexpr/kernel.h>
+
+namespace torch {
+namespace jit {
+namespace tensorexpr {
+
+TORCH_API Tensor computeSign( // Builds the "aten_sign" tensor: elementwise (0 < x) - (x < 0) on the first input, cast back to the input's scalar type.
+ const std::vector<ArgValue>& inputs,
+ const std::vector<ExprHandle>& outputShape);
+
+} // namespace tensorexpr
+} // namespace jit
+} // namespace torch