From 2e6221a232d39917e2736b248c53fa85dfb8986e Mon Sep 17 00:00:00 2001 From: Bert Maher Date: Sat, 28 Aug 2021 19:57:10 -0700 Subject: [PATCH] [nnc] Make 64-bit dimensions work (#64077) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/64077 We were assuming kernel dimensions fit in 32 bits (the old fuser made this assumption too), but we should be able to support 64. ghstack-source-id: 136933272 Test Plan: unit tests; new IR level test with huge sizes Reviewed By: ZolotukhinM Differential Revision: D30596689 fbshipit-source-id: 23b7e393a2ebaecb0c391a6b1f0c4b05a98bcc94 --- test/cpp/tensorexpr/test_kernel.cpp | 40 +++++-- test/cpp/tensorexpr/test_llvm.cpp | 78 +++++++------ test/cpp/tensorexpr/test_loopnest.cpp | 7 +- test/cpp/tensorexpr/test_reductions.cpp | 1 - torch/csrc/jit/tensorexpr/block_codegen.cpp | 11 +- torch/csrc/jit/tensorexpr/bounds_inference.cpp | 2 +- torch/csrc/jit/tensorexpr/bounds_overlap.cpp | 13 ++- torch/csrc/jit/tensorexpr/cuda_codegen.cpp | 19 +-- torch/csrc/jit/tensorexpr/eval.cpp | 56 ++++++--- torch/csrc/jit/tensorexpr/eval.h | 14 +++ torch/csrc/jit/tensorexpr/expr.h | 2 +- torch/csrc/jit/tensorexpr/ir.cpp | 6 +- torch/csrc/jit/tensorexpr/ir.h | 24 ++++ torch/csrc/jit/tensorexpr/ir_printer.cpp | 8 ++ torch/csrc/jit/tensorexpr/ir_simplifier.cpp | 129 +++++++++------------ torch/csrc/jit/tensorexpr/kernel.cpp | 81 ++++++------- torch/csrc/jit/tensorexpr/kernel.h | 4 +- torch/csrc/jit/tensorexpr/llvm_codegen.cpp | 39 ++++--- torch/csrc/jit/tensorexpr/llvm_jit.h | 4 +- torch/csrc/jit/tensorexpr/loopnest.cpp | 114 +++++++++--------- .../csrc/jit/tensorexpr/mem_dependency_checker.cpp | 27 ++--- torch/csrc/jit/tensorexpr/registerizer.cpp | 5 +- torch/csrc/jit/tensorexpr/tensor.cpp | 7 +- 23 files changed, 397 insertions(+), 294 deletions(-) diff --git a/test/cpp/tensorexpr/test_kernel.cpp b/test/cpp/tensorexpr/test_kernel.cpp index 625fadb..f4d3b16 100644 --- a/test/cpp/tensorexpr/test_kernel.cpp +++ b/test/cpp/tensorexpr/test_kernel.cpp @@ -198,6 +198,22 @@ TEST_F(Kernel, _3) { } } +TEST_F(Kernel, Huge) { + const auto graph_string = R"IR( + graph(%x.1 : Float(4000000000, strides=[1], requires_grad=0, device=cpu)): + %1 : int = prim::Constant[value=0]() + %2 : Float(1, 4000000000, strides=[4000000000, 1], requires_grad=0, device=cpu) = aten::unsqueeze(%x.1, %1) + %3 : Float(1, 4000000000, strides=[4000000000, 1], requires_grad=0, device=cpu) = aten::relu(%2) + return (%3))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + TensorExprKernel k(graph); + std::ostringstream oss; + oss << *k.getCodeGenStmt(); + const std::string& verification_pattern = "# CHECK: 4000000000"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + TEST_F(Kernel, ParallelStrided) { const auto graph_string = R"IR( graph(%0 : Float(5, 3, 40005, strides=[120015, 40005, 1], device=cpu), @@ -786,9 +802,9 @@ TEST_F(Kernel, SumOneAxis) { // Check the IR we produced const std::string& verification_pattern = R"IR( -# CHECK: for (int v = 0; v < +# CHECK: for (int64_t v = 0ll; v < # CHECK-NEXT: sum -# CHECK-NEXT: for (int v_1 = 0; v_1 < +# CHECK-NEXT: for (int64_t v_1 = 0ll; v_1 < # CHECK-NEXT: sum)IR"; torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); @@ -847,10 +863,10 @@ TEST_F(Kernel, SumMultipleAxes) { // Check the IR we produced const std::string& verification_pattern = R"IR( -# CHECK: int v = 0 -# CHECK: int v_1 = 0 -# CHECK: int v_2 = 0 -# CHECK: int v_3 = 0 +# CHECK: int64_t v = 0 +# CHECK: int64_t v_1 = 0 +# 
CHECK: int64_t v_2 = 0 +# CHECK: int64_t v_3 = 0 # CHECK: sum)IR"; torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); @@ -1115,8 +1131,8 @@ TEST_F(Kernel, InlineProducerIntoReduction) { // We should have only one loop in the end. const std::string& verification_pattern = R"IR( - # CHECK: for (int v = 0; v < 5; - # CHECK-NEXT: for (int v_1 = 0; v_1 < 3; + # CHECK: for (int64_t v = 0ll; v < 5 + # CHECK-NEXT: for (int64_t v_1 = 0ll; v_1 < 3 # CHECK-NEXT: sum # CHECK-NOT: for)IR"; torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); @@ -1154,11 +1170,11 @@ TEST_F(Kernel, InlineReductionIntoConsumer) { // We should have two loops in the end. const std::string& verification_pattern = R"IR( - # CHECK: for (int v = 0; v < 5; - # CHECK-NEXT: for (int v_1 = 0; v_1 < 3; + # CHECK: for (int64_t v = 0ll; v < 5 + # CHECK-NEXT: for (int64_t v_1 = 0ll; v_1 < 3 # CHECK-NEXT: sum - # CHECK: for (int v_2 = 0; v_2 < 5; - # CHECK-NEXT: for (int v_3 = 0; v_3 < 3; + # CHECK: for (int64_t v_2 = 0ll; v_2 < 5 + # CHECK-NEXT: for (int64_t v_3 = 0ll; v_3 < 3 # CHECK-NEXT: aten_mul # CHECK-NOT: for)IR"; torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); diff --git a/test/cpp/tensorexpr/test_llvm.cpp b/test/cpp/tensorexpr/test_llvm.cpp index 139763b..0e5cf5e 100644 --- a/test/cpp/tensorexpr/test_llvm.cpp +++ b/test/cpp/tensorexpr/test_llvm.cpp @@ -1501,42 +1501,54 @@ TEST(LLVM, RFactorVectorizedReduction) { ExpectAllNear(b_v, b_ref, 1e-5); } -TEST(LLVM, SimpleParallel) { - for (int test_cfg = 0; test_cfg < 4; test_cfg++) { - // Compute a simple operation, and try all loop-axis combination to be - // parallel or sequential. - const int M = 4; - const int N = 6; - Tensor f = Compute( - "f", {{M, "m"}, {N, "n"}}, [](const VarHandle& m, const VarHandle& n) { - return cast(m + n); - }); - LoopNest loop_nest({f}); - auto const& loops = loop_nest.getLoopStmtsFor(f); - ForPtr m = loops[0]; - ForPtr n = loops[1]; - if (test_cfg & 0x1) { - m->set_parallel(); - } - if (test_cfg & 0x2) { - n->set_parallel(); - } - loop_nest.prepareForCodegen(); - StmtPtr stmt = loop_nest.root_stmt(); - LLVMCodeGen cg(stmt, {f}); +template +static void testSimpleParallel() { + // Compute a simple operation, and try all loop-axis combination to be + // parallel or sequential. 
+ const int M = 4; + const int N = 6; + Tensor f = Compute( + "f", {{M, "m"}, {N, "n"}}, [](const VarHandle& m, const VarHandle& n) { + return cast(m + n); + }); + LoopNest loop_nest({f}); + auto const& loops = loop_nest.getLoopStmtsFor(f); + ForPtr m = loops[0]; + ForPtr n = loops[1]; + if (outer) { + m->set_parallel(); + } + if (inner) { + n->set_parallel(); + } + loop_nest.prepareForCodegen(); + StmtPtr stmt = loop_nest.root_stmt(); + LLVMCodeGen cg(stmt, {f}); - PaddedBuffer f_v(M, N, "f_v"); - std::vector args({f_v.data()}); - int value = cg.value(args); - ASSERT_EQ(value, 0); - PaddedBuffer f_ref(M, N, "f_ref"); - for (int m = 0; m < M; m++) { - for (int n = 0; n < N; n++) { - f_ref(m, n) = m + n; - } + PaddedBuffer f_v(M, N, "f_v"); + std::vector args({f_v.data()}); + int value = cg.value(args); + ASSERT_EQ(value, 0); + PaddedBuffer f_ref(M, N, "f_ref"); + for (int m = 0; m < M; m++) { + for (int n = 0; n < N; n++) { + f_ref(m, n) = m + n; } - ExpectAllNear(f_v, f_ref, 1e-5); } + ExpectAllNear(f_v, f_ref, 1e-5); +} + +TEST(LLVM, SimpleParallelSS) { + testSimpleParallel(); +} +TEST(LLVM, SimpleParallelSP) { + testSimpleParallel(); +} +TEST(LLVM, SimpleParallelPS) { + testSimpleParallel(); +} +TEST(LLVM, SimpleParallelPP) { + testSimpleParallel(); } TEST(LLVM, CompositeParallel) { diff --git a/test/cpp/tensorexpr/test_loopnest.cpp b/test/cpp/tensorexpr/test_loopnest.cpp index 28934f6..c2b33e2 100644 --- a/test/cpp/tensorexpr/test_loopnest.cpp +++ b/test/cpp/tensorexpr/test_loopnest.cpp @@ -4734,8 +4734,8 @@ TEST(LoopNest, VectorizeUse) { } const char* int64Loop = R"IR( -# CHECK: for (int64_t n = 0; n < 12; n++) { -# CHECK: b[n] = (a[n]) + 1; +# CHECK: for (int64_t n = 0ll; n < 12ll; n++) { +# CHECK: b[n] = (a[n]) + 1ll; # CHECK: } )IR"; @@ -4744,7 +4744,8 @@ TEST(LoopNest, Int64Direct) { Placeholder a("a", kLong, {N}); Placeholder b("b", kLong, {N}); VarHandle n("n", kLong); - StmtPtr s = For::make(n, 0, N, b.store({n}, a.load({n}) + LongImm::make(1l))); + StmtPtr s = For::make( + n, LongImm::make(0l), N, b.store({n}, a.load({n}) + LongImm::make(1l))); s = IRSimplifier::simplify(s); std::ostringstream oss; oss << *s; diff --git a/test/cpp/tensorexpr/test_reductions.cpp b/test/cpp/tensorexpr/test_reductions.cpp index 411b58d..3d2c0ec 100644 --- a/test/cpp/tensorexpr/test_reductions.cpp +++ b/test/cpp/tensorexpr/test_reductions.cpp @@ -1712,7 +1712,6 @@ TEST(Reductions, ReductionRfactorCacheTempOuter) { #CHECK-NOT: tmp )IR"; torch::jit::testing::FileCheck().run(expected_ir, oss.str()); - SimpleIREvaluator cg(s, {b, c, m, n, k}); cg.call({in, out, M, N, K}); diff --git a/torch/csrc/jit/tensorexpr/block_codegen.cpp b/torch/csrc/jit/tensorexpr/block_codegen.cpp index 1ae3330..51b7b77 100644 --- a/torch/csrc/jit/tensorexpr/block_codegen.cpp +++ b/torch/csrc/jit/tensorexpr/block_codegen.cpp @@ -76,7 +76,7 @@ void BlockAnalysis::visit(ForPtr v) { v->body()->accept(this); } else if (loop_options.is_gpu_thread_index()) { auto block_size = v->stop(); - block_size_ = to(block_size)->value(); + block_size_ = *intValue(block_size); v->body()->accept(this); } else { IRVisitor::visit(v); @@ -185,15 +185,14 @@ void BlockPrinter::PrintArguments(const std::unordered_set& bufs) { // The dims for the multi-dim tensors for (unsigned long d = 0; d < num_dims; d++) { - auto dim_val = to(multidimbuf->dim(d)); - this->dim_values_map.emplace(this->dim_names[d], dim_val->value()); + auto dim_val = *intValue(multidimbuf->dim(d)); + this->dim_values_map.emplace(this->dim_names[d], dim_val); } // The dimensions for 
the flattened tensors - auto val = to(buf->dim(0)); + auto val = *intValue(buf->dim(0)); if (block_analysis_->is_buf_store_target(buf)) { - this->dim_values_map.emplace( - this->flat_dim_names[num_dims - 1], val->value()); + this->dim_values_map.emplace(this->flat_dim_names[num_dims - 1], val); } } diff --git a/torch/csrc/jit/tensorexpr/bounds_inference.cpp b/torch/csrc/jit/tensorexpr/bounds_inference.cpp index 55dbacf..649fd0e 100644 --- a/torch/csrc/jit/tensorexpr/bounds_inference.cpp +++ b/torch/csrc/jit/tensorexpr/bounds_inference.cpp @@ -185,7 +185,7 @@ std::vector getBoundExtents( std::vector extents; for (size_t i = 0; i < starts.size(); ++i) { ExprPtr dim = IRSimplifier::simplify( - alloc(alloc(stops[i], starts[i]), alloc(1))); + alloc(alloc(stops[i], starts[i]), immLike(stops[i], 1))); extents.push_back(dim); } diff --git a/torch/csrc/jit/tensorexpr/bounds_overlap.cpp b/torch/csrc/jit/tensorexpr/bounds_overlap.cpp index 4ac5c6b..fdfff12 100644 --- a/torch/csrc/jit/tensorexpr/bounds_overlap.cpp +++ b/torch/csrc/jit/tensorexpr/bounds_overlap.cpp @@ -130,8 +130,8 @@ std::vector subtractBound(Bound a, Bound b, OverlapKind overlap) { auto vars = VarFinder::find(lowDiff); if (vars.size() == 1) { lowDiff = IRSimplifier::simplify(alloc( - SubstituteInClone(b.start, {{*vars.begin(), alloc(1)}}), - SubstituteInClone(a.start, {{*vars.begin(), alloc(1)}}))); + SubstituteInClone(b.start, {{*vars.begin(), immLike(b.start, 1)}}), + SubstituteInClone(a.start, {{*vars.begin(), immLike(a.start, 1)}}))); } } @@ -139,8 +139,8 @@ std::vector subtractBound(Bound a, Bound b, OverlapKind overlap) { auto vars = VarFinder::find(highDiff); if (vars.size() == 1) { highDiff = IRSimplifier::simplify(alloc( - SubstituteInClone(b.end, {{*vars.begin(), alloc(1)}}), - SubstituteInClone(a.end, {{*vars.begin(), alloc(1)}}))); + SubstituteInClone(b.end, {{*vars.begin(), immLike(b.end, 1)}}), + SubstituteInClone(a.end, {{*vars.begin(), immLike(a.end, 1)}}))); } } @@ -157,12 +157,13 @@ std::vector subtractBound(Bound a, Bound b, OverlapKind overlap) { if (hasHead) { res.emplace_back( - a.start, IRSimplifier::simplify(alloc(b.start, alloc(1)))); + a.start, + IRSimplifier::simplify(alloc(b.start, immLike(b.start, 1)))); } if (hasTail) { ExprPtr tailStart = - IRSimplifier::simplify(alloc(b.end, alloc(1))); + IRSimplifier::simplify(alloc(b.end, immLike(b.end, 1))); res.emplace_back(tailStart, a.end); } diff --git a/torch/csrc/jit/tensorexpr/cuda_codegen.cpp b/torch/csrc/jit/tensorexpr/cuda_codegen.cpp index b342f14..30d4207 100644 --- a/torch/csrc/jit/tensorexpr/cuda_codegen.cpp +++ b/torch/csrc/jit/tensorexpr/cuda_codegen.cpp @@ -45,18 +45,9 @@ class ScopedVarName { VarPtr var_ = nullptr; }; -static int as_int(ExprPtr expr) { - auto v = to(expr); - if (!v) { - throw malformed_input( - "cuda_codegen: non Int expr interpreted as int", expr); - } - - return v->value(); -} - static bool is_zero(ExprPtr expr) { - return as_int(expr) == 0; + auto v = intValue(expr); + return v && *v == 0; } static const at::cuda::NVRTC& nvrtc() { @@ -222,11 +213,11 @@ void CudaPrinter::print_flat_alloc(AllocatePtr alloc) { // TODO: this should be merged with the storage flattener. 
int64_t flat_size = 1; for (auto dim : dims) { - IntImmPtr dim_i = to(dim); + auto dim_i = intValue(dim); if (dim_i) { - flat_size *= dim_i->value(); + flat_size *= *dim_i; } else { - throw std::runtime_error("Only IntImm dimensions are supported for now"); + throw std::runtime_error("Only integer dimensions are supported for now"); } } os() << dtypeToCppString(alloc->dtype()) << " " << (*alloc->buffer_var()) diff --git a/torch/csrc/jit/tensorexpr/eval.cpp b/torch/csrc/jit/tensorexpr/eval.cpp index 05c3ff8..e42ce77 100644 --- a/torch/csrc/jit/tensorexpr/eval.cpp +++ b/torch/csrc/jit/tensorexpr/eval.cpp @@ -10,6 +10,17 @@ namespace tensorexpr { RegisterCodeGen ir_eval_codegen_reg("simple_ir_eval"); +int64_t Value::intValue() const { +#define TYPE_CASE(Type, Name) \ + if (dtype_ == k##Name) { \ + return int64_t{Name##values[0]}; \ + } + AT_FORALL_INT_TYPES(TYPE_CASE); +#undef TYPE_CASE + throw unsupported_dtype(); + return 0; +} + template inline typename std::enable_if::value, T>::type mod_value( T lhs, @@ -537,15 +548,16 @@ class SimpleIREvaluatorImpl : public IRVisitor { TORCH_API void visit(ForPtr v) override { ExprPtr var_node = v->var(); v->start()->accept(this); - int start = value_.as(); + auto dtype = value_.dtype(); + auto start = value_.intValue(); v->stop()->accept(this); - int stop = value_.as(); + auto stop = value_.intValue(); if (eval_context_.count(var_node)) { throw malformed_input("could not find var_node in For context", v); } - for (int i = start; i < stop; i++) { - eval_context_[var_node] = Value(i); + for (auto i = start; i < stop; i++) { + eval_context_[var_node] = Value(dtype, i); if (v->body()) { v->body()->accept(this); } @@ -555,9 +567,9 @@ class SimpleIREvaluatorImpl : public IRVisitor { TORCH_API void visit(RampPtr v) override { v->base()->accept(this); - int base = value().as(); + auto base = value().intValue(); v->stride()->accept(this); - int stride = value().as(); + auto stride = value().intValue(); int lanes = v->lanes(); std::vector values(lanes); @@ -609,6 +621,24 @@ class SimpleIREvaluatorImpl : public IRVisitor { } } + template + std::vector toLongVec(T&& t) { + return std::vector{std::begin(t), std::end(t)}; + } + + std::vector indexVec(const Value& v) { + switch (v.dtype().scalar_type()) { +#define TYPE_CASE(Type, Name) \ + case ScalarType::Name: \ + return toLongVec(v.as_vec()); + AT_FORALL_INT_TYPES(TYPE_CASE); +#undef TYPE_CASE + default: + throw unsupported_dtype(); + } + return {}; + } + TORCH_API void visit(LoadPtr v) override { auto iter = buffer_mapping_.find(v->buf()); if (iter == buffer_mapping_.end()) { @@ -618,7 +648,7 @@ class SimpleIREvaluatorImpl : public IRVisitor { ExprPtr flat_idx = flatten_index(v->buf()->dims(), v->indices()); flat_idx->accept(this); - std::vector index = value().as_vec(); + auto index = indexVec(value()); ScalarType v_sdtype = v->dtype().scalar_type(); switch (v_sdtype) { #define TYPE_CASE(Type, Name) \ @@ -647,7 +677,7 @@ class SimpleIREvaluatorImpl : public IRVisitor { ExprPtr flat_idx = flatten_index(v->buf()->dims(), v->indices()); flat_idx->accept(this); - std::vector index = value().as_vec(); + auto index = indexVec(value()); ScalarType v_sdtype = v->value()->dtype().scalar_type(); switch (v_sdtype) { @@ -696,7 +726,7 @@ class SimpleIREvaluatorImpl : public IRVisitor { buf_dtypes.push_back((int8_t)b->dtype().scalar_type()); for (ExprPtr dim_expr : b->dims()) { dim_expr->accept(this); - buf_dims.push_back(value().as()); + buf_dims.push_back(value().intValue()); } } for (ExprPtr a : v->args()) { @@ -706,7 
+736,7 @@ class SimpleIREvaluatorImpl : public IRVisitor { if (value().dtype() == kLong) { val = value().as(); } else if (value().dtype() == kInt) { - val = value().as(); + val = value().intValue(); } else { throw malformed_input( "extra_args in ExternalCalls must have int64 dtype", v); @@ -789,10 +819,10 @@ class SimpleIREvaluatorImpl : public IRVisitor { void visit(AllocatePtr v) override { BufPtr b = v->buf(); std::vector dims = b->dims(); - int total_byte_size = b->dtype().byte_size(); + int64_t total_byte_size = b->dtype().byte_size(); for (auto& dim : dims) { dim->accept(this); - total_byte_size *= value_.as(); + total_byte_size *= value_.intValue(); } auto int_count = (total_byte_size + sizeof(int) - 1) / sizeof(int); std::unique_ptr> buffer(new std::vector(int_count)); @@ -824,7 +854,7 @@ class SimpleIREvaluatorImpl : public IRVisitor { void visit(CondPtr v) override { v->condition()->accept(this); - if (value().as()) { + if (value().intValue()) { if (v->true_stmt()) { v->true_stmt()->accept(this); } diff --git a/torch/csrc/jit/tensorexpr/eval.h b/torch/csrc/jit/tensorexpr/eval.h index 38ec99b..494ba28 100644 --- a/torch/csrc/jit/tensorexpr/eval.h +++ b/torch/csrc/jit/tensorexpr/eval.h @@ -29,6 +29,18 @@ class Value { Intvalues.push_back(0); } + template + Value(Dtype dtype, T v) : dtype_(dtype) { +#define TYPE_CASE(Type, Name) \ + if (dtype == k##Name) { \ + Name##values.push_back(v); \ + return; \ + } + AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, TYPE_CASE); +#undef TYPE_CASE + throw unsupported_dtype(); + } + #define VALUE_CTOR(Type, Name) \ Value(Type v) : dtype_(k##Name) { \ Name##values.push_back(v); \ @@ -50,6 +62,8 @@ class Value { template const std::vector& as_vec() const; + int64_t intValue() const; + Dtype dtype() const { return dtype_; } diff --git a/torch/csrc/jit/tensorexpr/expr.h b/torch/csrc/jit/tensorexpr/expr.h index a4f317f..fbbea12 100644 --- a/torch/csrc/jit/tensorexpr/expr.h +++ b/torch/csrc/jit/tensorexpr/expr.h @@ -319,7 +319,7 @@ class TORCH_API BufHandle : public ExprHandle { // object. 
For example: VarHandle x('x'); ExprHandle x2 = x; class TORCH_API VarHandle : public ExprHandle { public: - VarHandle() : ExprHandle(nullptr) {} + VarHandle() : ExprHandle() {} explicit VarHandle(Dtype dtype) : ExprHandle(Var::make(dtype)) {} VarHandle(const std::string& name_hint, Dtype dtype) : ExprHandle(Var::make(name_hint, dtype)) {} diff --git a/torch/csrc/jit/tensorexpr/ir.cpp b/torch/csrc/jit/tensorexpr/ir.cpp index f66c0c5..2680f53 100644 --- a/torch/csrc/jit/tensorexpr/ir.cpp +++ b/torch/csrc/jit/tensorexpr/ir.cpp @@ -88,17 +88,17 @@ ExprPtr flatten_index( throw malformed_input("dimensions mismatch in flatten_index"); } if (ndim == 0) { - return alloc(0); + return alloc(0); } std::vector strides(ndim); // stride[i] = stride[i+1]*dims[i+1], i < ndim-1 // stride[i] = 1, i = ndim-1 - strides[ndim - 1] = alloc(1); + strides[ndim - 1] = immLike(dims[ndim - 1], 1); for (size_t i = 1; i < ndim; i++) { strides[ndim - 1 - i] = alloc(strides[ndim - i], dims[ndim - i]); } - ExprPtr total_index = alloc(0); + ExprPtr total_index = immLike(indices[0], 0); for (const auto i : c10::irange(ndim)) { total_index = alloc(total_index, alloc(indices[i], strides[i])); } diff --git a/torch/csrc/jit/tensorexpr/ir.h b/torch/csrc/jit/tensorexpr/ir.h index 7fe1fd1..1218082 100644 --- a/torch/csrc/jit/tensorexpr/ir.h +++ b/torch/csrc/jit/tensorexpr/ir.h @@ -345,6 +345,30 @@ ExprPtr getImmediateByType(Dtype dtype, T initialVal) { } template +ExprPtr immLike(ExprPtr e, T v) { + return getImmediateByType(e->dtype(), v); +} + +template +ExprPtr immLike(ExprHandle e, T v) { + return immLike(e.node(), v); +} + +inline c10::optional intValue(ExprPtr e) { +#define TYPE_CASE(Type, Name) \ + if (auto v = to(e)) { \ + return v->value(); \ + } + AT_FORALL_INT_TYPES(TYPE_CASE); +#undef TYPE_CASE + return c10::nullopt; +} + +inline c10::optional intValue(ExprHandle e) { + return intValue(e.node()); +} + +template T immediateAs(ExprPtr e) { #define TYPE_CASE(Type, Name) \ if (Name##ImmPtr imm = to(e)) { \ diff --git a/torch/csrc/jit/tensorexpr/ir_printer.cpp b/torch/csrc/jit/tensorexpr/ir_printer.cpp index 2e1fc6e..ca90d99 100644 --- a/torch/csrc/jit/tensorexpr/ir_printer.cpp +++ b/torch/csrc/jit/tensorexpr/ir_printer.cpp @@ -206,11 +206,19 @@ static void formatImm(std::ostream& os, T v) { } } +static void formatIntSuffix(std::ostream& os, int64_t v) { + os << "ll"; +} + +template +static void formatIntSuffix(std::ostream& os, T v) {} + template < typename T, std::enable_if_t::value>* = nullptr> static void formatImm(std::ostream& os, T v) { os << +v; + formatIntSuffix(os, v); } // NOLINTNEXTLINE diff --git a/torch/csrc/jit/tensorexpr/ir_simplifier.cpp b/torch/csrc/jit/tensorexpr/ir_simplifier.cpp index 23216dd..6820bbb 100644 --- a/torch/csrc/jit/tensorexpr/ir_simplifier.cpp +++ b/torch/csrc/jit/tensorexpr/ir_simplifier.cpp @@ -430,8 +430,7 @@ ExprPtr PolynomialTransformer::mutate(AddPtr v) { // Otherwise this is a new polynomial with no scalar and two variable // terms. - return alloc( - hasher_, getImmediateByType(v->dtype(), 0), lhsTerm, rhsTerm); + return alloc(hasher_, immLike(v, 0), lhsTerm, rhsTerm); } // Adds are commutative. @@ -452,19 +451,17 @@ ExprPtr PolynomialTransformer::mutate(AddPtr v) { // Simple Term with a scalar and variable type. if (scalar) { return alloc( - hasher_, - scalar, - alloc(hasher_, getImmediateByType(v->dtype(), 1), variable)); + hasher_, scalar, alloc(hasher_, immLike(v, 1), variable)); } // If LHS is neither Term not Polynomial, wrap it in a Term. 
if (!lhsTerm && !lhsPoly) { - lhsTerm = alloc(hasher_, getImmediateByType(v->dtype(), 1), lhs_new); + lhsTerm = alloc(hasher_, immLike(v, 1), lhs_new); } // Same for RHS. if (!rhsTerm && !rhsPoly) { - rhsTerm = alloc(hasher_, getImmediateByType(v->dtype(), 1), rhs_new); + rhsTerm = alloc(hasher_, immLike(v, 1), rhs_new); } // If we now have a poly and a term, we can insert. @@ -480,8 +477,7 @@ ExprPtr PolynomialTransformer::mutate(AddPtr v) { } // If all else fails we have a new Polynomial with two new variable Terms. - return alloc( - hasher_, getImmediateByType(v->dtype(), 0), lhsTerm, rhsTerm); + return alloc(hasher_, immLike(v, 0), lhsTerm, rhsTerm); } ExprPtr PolynomialTransformer::subTerms( @@ -490,7 +486,7 @@ ExprPtr PolynomialTransformer::subTerms( bool negated) { // If RHS not already negated, negate it. if (!negated) { - ExprPtr minusOne = getImmediateByType(rhs->dtype(), -1); + ExprPtr minusOne = immLike(rhs, -1); ExprPtr negateScalar = evaluateOp(alloc(minusOne, rhs->scalar())); rhs = alloc(hasher_, negateScalar, rhs->variables()); } @@ -529,8 +525,7 @@ ExprPtr PolynomialTransformer::subPolynomials( for (auto rt : rhs->variables()) { // Polynomials add their terms, so negate the RHS's Terms. - ExprPtr negated = evaluateOp( - alloc(getImmediateByType(rt->dtype(), -1), rt->scalar())); + ExprPtr negated = evaluateOp(alloc(immLike(rt, -1), rt->scalar())); TermPtr newRHS = alloc(hasher_, negated, rt->variables()); addOrUpdateTerm(varmap, newRHS); } @@ -594,7 +589,7 @@ ExprPtr PolynomialTransformer::mutate(SubPtr v) { auto ret = subPolynomials(lhsPoly, rhsPoly); if (!ret) { // Cancelled out completely. - return getImmediateByType(v->dtype(), 0); + return immLike(v, 0); } return ret; } @@ -605,8 +600,8 @@ ExprPtr PolynomialTransformer::mutate(SubPtr v) { // Polynomial - Term. if (lhsPoly && rhsTerm) { // Negate the term. - ExprPtr negate = evaluateOp(alloc( - getImmediateByType(rhsTerm->dtype(), -1), rhsTerm->scalar())); + ExprPtr negate = + evaluateOp(alloc(immLike(rhsTerm, -1), rhsTerm->scalar())); TermPtr newTerm = alloc(hasher_, negate, rhsTerm->variables()); return insertTerm(lhsPoly, newTerm); } @@ -614,7 +609,7 @@ ExprPtr PolynomialTransformer::mutate(SubPtr v) { // Term - Polynomial. if (rhsPoly && lhsTerm) { // Negate every part of the Polynomial. - ExprPtr minusOne = getImmediateByType(lhsTerm->dtype(), -1); + ExprPtr minusOne = immLike(lhsTerm, -1); ExprPtr negateScalar = evaluateOp(alloc(minusOne, rhsPoly->scalar())); std::vector variables; @@ -645,7 +640,7 @@ ExprPtr PolynomialTransformer::mutate(SubPtr v) { ExprPtr newScalar = evaluateOp(alloc(lhs_new, rhsPoly->scalar())); // Negate each term in the Polynomial RHS. - ExprPtr minusOne = getImmediateByType(rhsPoly->dtype(), -1); + ExprPtr minusOne = immLike(rhsPoly, -1); std::vector variables; for (auto t : rhsPoly->variables()) { ExprPtr negate = evaluateOp(alloc(minusOne, t->scalar())); @@ -657,15 +652,14 @@ ExprPtr PolynomialTransformer::mutate(SubPtr v) { if (lhsTerm && rhsScalar) { // Negate the constant. - ExprPtr negate = evaluateOp( - alloc(getImmediateByType(rhs_new->dtype(), -1), rhs_new)); + ExprPtr negate = evaluateOp(alloc(immLike(rhs_new, -1), rhs_new)); return alloc(hasher_, negate, lhsTerm); } if (lhsScalar && rhsTerm) { // Negate the RHS Term. 
- ExprPtr negate = evaluateOp(alloc( - getImmediateByType(rhsTerm->scalar()->dtype(), -1), rhsTerm->scalar())); + ExprPtr negate = evaluateOp( + alloc(immLike(rhsTerm->scalar(), -1), rhsTerm->scalar())); return alloc( hasher_, lhs_new, alloc(hasher_, negate, rhsTerm->variables())); @@ -675,29 +669,24 @@ ExprPtr PolynomialTransformer::mutate(SubPtr v) { if (lhsScalar) { // Create a negated term. return alloc( - hasher_, - lhs_new, - alloc(hasher_, getImmediateByType(v->dtype(), -1), rhs_new)); + hasher_, lhs_new, alloc(hasher_, immLike(v, -1), rhs_new)); } if (rhsScalar) { // Negate the scalar. - ExprPtr negate = evaluateOp( - alloc(getImmediateByType(rhs_new->dtype(), -1), rhs_new)); + ExprPtr negate = evaluateOp(alloc(immLike(rhs_new, -1), rhs_new)); return alloc( - hasher_, - negate, - alloc(hasher_, getImmediateByType(v->dtype(), 1), lhs_new)); + hasher_, negate, alloc(hasher_, immLike(v, 1), lhs_new)); } // no scalar... if (!lhsTerm && !lhsPoly) { - lhsTerm = alloc(hasher_, getImmediateByType(v->dtype(), 1), lhs_new); + lhsTerm = alloc(hasher_, immLike(v, 1), lhs_new); } bool createdRHSnegated = false; if (!rhsTerm && !rhsPoly) { - rhsTerm = alloc(hasher_, getImmediateByType(v->dtype(), -1), rhs_new); + rhsTerm = alloc(hasher_, immLike(v, -1), rhs_new); createdRHSnegated = true; } @@ -714,7 +703,7 @@ ExprPtr PolynomialTransformer::mutate(SubPtr v) { // Insert wrapper Term into negated RHS Poly. if (rhsPoly) { CHECK(lhsTerm); - ExprPtr minusOne = getImmediateByType(rhsPoly->dtype(), -1); + ExprPtr minusOne = immLike(rhsPoly, -1); ExprPtr newScalar = evaluateOp(alloc(minusOne, rhsPoly->scalar())); // Negate each term in the Polynomial RHS. @@ -728,8 +717,7 @@ ExprPtr PolynomialTransformer::mutate(SubPtr v) { return insertTerm(poly, lhsTerm); } - return alloc( - hasher_, getImmediateByType(v->dtype(), 0), lhsTerm, rhsTerm); + return alloc(hasher_, immLike(v, 0), lhsTerm, rhsTerm); } // Multiply two terms together, usually creating a new term with the variable @@ -930,7 +918,7 @@ ExprPtr PolynomialTransformer::mutate(MulPtr v) { // Handle special case mul by 0. if (scalar && immediateEquals(scalar, 0)) { - return getImmediateByType(v->dtype(), 0); + return immLike(v, 0); } // Catch cases of rounding (Div(A/B) * B). @@ -994,13 +982,11 @@ ExprPtr PolynomialTransformer::mutate(MulPtr v) { // Multiplying Polynomial by variable can be wrapped in a term and handled // by polyByTerm also. if (lhsPoly) { - auto term = - alloc(hasher_, getImmediateByType(rhs_new->dtype(), 1), rhs_new); + auto term = alloc(hasher_, immLike(rhs_new, 1), rhs_new); return polyByTerm(lhsPoly, term); } if (rhsPoly) { - auto term = - alloc(hasher_, getImmediateByType(lhs_new->dtype(), 1), lhs_new); + auto term = alloc(hasher_, immLike(lhs_new, 1), lhs_new); return polyByTerm(rhsPoly, term); } @@ -1014,8 +1000,7 @@ ExprPtr PolynomialTransformer::mutate(MulPtr v) { } // Two variables, create a new Term. - return alloc( - hasher_, getImmediateByType(v->dtype(), 1), lhs_new, rhs_new); + return alloc(hasher_, immLike(v, 1), lhs_new, rhs_new); } ExprPtr factorizeDivision(ExprPtr lhs_new, ExprPtr rhs_new) { @@ -1048,10 +1033,8 @@ ExprPtr factorizeDivision(ExprPtr lhs_new, ExprPtr rhs_new) { return nullptr; } - leftScalar = evaluateOp( - alloc
<Div>(leftScalar, getImmediateByType(leftScalar->dtype(), GCD))); - rightScalar = evaluateOp( - alloc<Div>
<Div>(rightScalar, getImmediateByType(rightScalar->dtype(), GCD))); + leftScalar = evaluateOp(alloc<Div>
<Div>(leftScalar, immLike(leftScalar, GCD))); + rightScalar = evaluateOp(alloc<Div>
(rightScalar, immLike(rightScalar, GCD))); if (lhsTerm) { lhs_new = alloc(lhsTerm->hasher(), leftScalar, lhsTerm->variables()); @@ -1127,12 +1110,12 @@ ExprPtr PolynomialTransformer::mutate(ModPtr v) { // x % 1 == 0. if (rhs_new->isConstant() && immediateEquals(rhs_new, 1)) { // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - return getImmediateByType(v->dtype(), 0); + return immLike(v, 0); } // x % x => 0. if (hasher_.hash(lhs_new) == hasher_.hash(rhs_new)) { - return getImmediateByType(v->dtype(), 0); + return immLike(v, 0); } TermPtr lhsTerm = to(lhs_new); @@ -1149,13 +1132,13 @@ ExprPtr PolynomialTransformer::mutate(ModPtr v) { if (rhs_new->isConstant() && immediateEquals( evaluateOp(alloc(lhsTerm->scalar(), rhs_new)), 0)) { - return getImmediateByType(v->dtype(), 0); + return immLike(v, 0); } // (x * y * z) % x => 0. for (auto component : lhsTerm->variables()) { if (hasher_.hash(component) == hasher_.hash(rhs_new)) { - return getImmediateByType(v->dtype(), 0); + return immLike(v, 0); } } @@ -1189,7 +1172,7 @@ ExprPtr PolynomialTransformer::mutate(ModPtr v) { immediateEquals( evaluateOp(alloc(lhsTerm->scalar(), rhsTerm->scalar())), 0)) { - return getImmediateByType(v->dtype(), 0); + return immLike(v, 0); } } } @@ -1862,7 +1845,7 @@ ExprPtr polyGCD(PolynomialPtr poly) { return nullptr; } - return getImmediateByType(poly->dtype(), GCD); + return immLike(poly, GCD); } // A ModRound is a div-mod-mul in which the divisor in div and multiplier in mul @@ -1981,7 +1964,7 @@ c10::optional isModRound(TermPtr e) { } if (!scalar) { - scalar = getImmediateByType(multiplier->dtype(), 1); + scalar = immLike(multiplier, 1); } // TODO: this leaks memory! @@ -2261,23 +2244,23 @@ ExprPtr TermExpander::mutate(PolynomialPtr v) { } // Negate the term back to positive since we'll be subtracting it. - ExprPtr negated = evaluateOp(alloc( - getImmediateByType(node->scalar()->dtype(), -1), node->scalar())); + ExprPtr negated = + evaluateOp(alloc(immLike(node->scalar(), -1), node->scalar())); TermPtr newRHS = alloc(node->hasher(), negated, node->variables()); lastNode = alloc(lastNode, newRHS->accept_mutator(this)); } if (scalarWritten || immediateEquals(v->scalar(), 0)) { if (!lastNode) { - return getImmediateByType(v->dtype(), 0); + return immLike(v, 0); } return lastNode; } if (immediateIsNegative(v->scalar())) { // Negate the scalar and subtract. - ExprPtr negated = evaluateOp( - alloc(getImmediateByType(lastNode->dtype(), -1), v->scalar())); + ExprPtr negated = + evaluateOp(alloc(immLike(lastNode, -1), v->scalar())); lastNode = alloc(lastNode, evaluateOp(negated)); } else { // we want to avoid a cast to the scalar if it would happen. @@ -2344,7 +2327,7 @@ ExprPtr TermExpander::mutate(MinTermPtr v) { ExprPtr TermExpander::mutate(RoundOffPtr v) { TermPtr term = alloc( simplifier_->hasher(), - getImmediateByType(v->dtype(), 1), + immLike(v, 1), alloc
(v->lhs(), v->rhs()), v->rhs()); return term->accept_mutator(this); @@ -2352,8 +2335,10 @@ ExprPtr TermExpander::mutate(RoundOffPtr v) { ExprPtr buf_flat_size(BufPtr v) { std::vector dims = v->dims(); - - ExprPtr flattened = getImmediateByType(kInt, 1); + if (dims.size() == 0) { + return alloc(1); + } + ExprPtr flattened = immLike(dims[0], 1); for (auto& dim : dims) { flattened = alloc(flattened, dim); } @@ -2684,7 +2669,7 @@ ExprPtr distributeDiv(ExprPtr lhs, ExprPtr rhs, VarBoundInfo var_bound_info) { return nullptr; } ExprPtr check_n_value = IRSimplifier::simplify( - alloc(rhsScalar, alloc(0), kGT)); + alloc(rhsScalar, immLike(rhsScalar, 0), kGT)); if (!immediateEquals(check_n_value, 1)) { return nullptr; } @@ -2719,7 +2704,7 @@ ExprPtr distributeDiv(ExprPtr lhs, ExprPtr rhs, VarBoundInfo var_bound_info) { // range auto end = got->second.second; ExprPtr check_start = IRSimplifier::simplify( - alloc(start, alloc(0), kGE)); + alloc(start, immLike(start, 0), kGE)); ExprPtr check_end = IRSimplifier::simplify(alloc(end, rhsScalar, kLE)); if (!check_start->isConstant() || !check_end->isConstant() || @@ -2731,7 +2716,7 @@ ExprPtr distributeDiv(ExprPtr lhs, ExprPtr rhs, VarBoundInfo var_bound_info) { // simplify type 1) exprs: '(i+x)/n' => 'x/n' ExprPtr sign_check = - IRSimplifier::simplify(alloc(main, alloc(0), kGE)); + IRSimplifier::simplify(alloc(main, immLike(main, 0), kGE)); ExprPtr main_mod = IRSimplifier::simplify(alloc(main, rhsScalar)); ExprPtr mod_check = IRSimplifier::simplify( alloc(alloc(main_mod, end), rhsScalar, kLE)); @@ -2742,6 +2727,7 @@ ExprPtr distributeDiv(ExprPtr lhs, ExprPtr rhs, VarBoundInfo var_bound_info) { // simplify type 2 exprs: '(i+j*n)/n' => 'j' auto ret_var = to(ret); + // FIXME: Allow any integral type. if (ret_var && ret_var->dtype() == kInt) { // retrieve j's range info auto got = var_bound_info.find(ret_var); @@ -2750,8 +2736,8 @@ ExprPtr distributeDiv(ExprPtr lhs, ExprPtr rhs, VarBoundInfo var_bound_info) { } // check if j is not negative - sign_check = IRSimplifier::simplify( - alloc(got->second.first, alloc(0), kGE)); + sign_check = IRSimplifier::simplify(alloc( + got->second.first, immLike(got->second.first, 0), kGE)); if (sign_check->isConstant() && immediateEquals(sign_check, 1)) { return ret_var; } @@ -2801,7 +2787,7 @@ ExprPtr distributeMod(ExprPtr lhs, ExprPtr rhs, VarBoundInfo var_bound_info) { return nullptr; } ExprPtr check_n_value = IRSimplifier::simplify( - alloc(rhsScalar, alloc(0), kGT)); + alloc(rhsScalar, immLike(rhsScalar, 0), kGT)); if (!immediateEquals(check_n_value, 1)) { return nullptr; } @@ -2838,7 +2824,7 @@ ExprPtr distributeMod(ExprPtr lhs, ExprPtr rhs, VarBoundInfo var_bound_info) { // range auto end = got->second.second; ExprPtr check_start = IRSimplifier::simplify( - alloc(start, alloc(0), kGE)); + alloc(start, immLike(start, 0), kGE)); ExprPtr check_end = IRSimplifier::simplify(alloc(end, rhsScalar, kLE)); if (!check_start->isConstant() || !check_end->isConstant() || @@ -2848,7 +2834,7 @@ ExprPtr distributeMod(ExprPtr lhs, ExprPtr rhs, VarBoundInfo var_bound_info) { // simplify type 1) exprs: '(i+x)%n' => 'i+x%n' ExprPtr sign_check = - IRSimplifier::simplify(alloc(main, alloc(0), kGE)); + IRSimplifier::simplify(alloc(main, immLike(main, 0), kGE)); ExprPtr main_mod = IRSimplifier::simplify(alloc(main, rhsScalar)); ExprPtr mod_check = IRSimplifier::simplify( alloc(alloc(main_mod, end), rhsScalar, kLE)); @@ -2860,6 +2846,7 @@ ExprPtr distributeMod(ExprPtr lhs, ExprPtr rhs, VarBoundInfo var_bound_info) { // simplify type 2) exprs: 
'(i+j*n)%n' => 'i' ExprPtr main_div = IRSimplifier::simplify(alloc<Div>
(main, rhsScalar)); auto j_var = to(main_div); + // FIXME: Allow any integral type. if (j_var && j_var->dtype() == kInt) { // retrieve j's range info auto got = var_bound_info.find(j_var); @@ -2868,8 +2855,8 @@ ExprPtr distributeMod(ExprPtr lhs, ExprPtr rhs, VarBoundInfo var_bound_info) { } // check if j is not negative - sign_check = IRSimplifier::simplify( - alloc(got->second.first, alloc(0), kGE)); + sign_check = IRSimplifier::simplify(alloc( + got->second.first, immLike(got->second.first, 0), kGE)); if (sign_check->isConstant() && immediateEquals(sign_check, 1)) { return var_key; } @@ -2920,7 +2907,7 @@ ExprPtr SimplifierUnderContext::mutate(ModPtr v) { auto start = got->second.first; auto end = got->second.second; ExprPtr check_start = IRSimplifier::simplify( - alloc(start, alloc(0), kGE)); + alloc(start, immLike(start, 0), kGE)); ExprPtr check_end = IRSimplifier::simplify(alloc(end, rhsScalar, kLE)); if (check_start->isConstant() && check_end->isConstant() && diff --git a/torch/csrc/jit/tensorexpr/kernel.cpp b/torch/csrc/jit/tensorexpr/kernel.cpp index f72fbf7..0d0d19e 100644 --- a/torch/csrc/jit/tensorexpr/kernel.cpp +++ b/torch/csrc/jit/tensorexpr/kernel.cpp @@ -202,11 +202,11 @@ c10::optional getTensorInfoJit(torch::jit::Value* v) { c10::optional getTensorInfo(BufHandle b) { std::vector dims; for (auto dim : b.dims()) { - auto val = to(dim.node()); + auto val = intValue(dim.node()); if (!val) { return c10::nullopt; } - dims.push_back(val->value()); + dims.push_back(*val); } return TensorInfo{dims, static_cast(b.dtype().scalar_type())}; } @@ -396,7 +396,7 @@ ExprHandle tensorOrConstant( return constant(v); } -size_t normalizeAndCheckIndex(int64_t idx, int64_t list_size) { +int64_t normalizeAndCheckIndex(int64_t idx, int64_t list_size) { if (idx < 0) { // Handle negative indexing idx = list_size + idx; @@ -405,7 +405,7 @@ size_t normalizeAndCheckIndex(int64_t idx, int64_t list_size) { if (idx < 0 || idx >= list_size) { AT_ERROR("Invalid index ", idx, " for list_size", list_size); } - return static_cast(idx); + return idx; } ExprHandle broadcast(BufHandle b, const std::vector& axes) { @@ -441,8 +441,8 @@ std::vector computeIndicesToBroadcast( auto axisIt = outputAxes.rbegin(); auto sizeIt = inputSizes.rbegin(); while (sizeIt != inputSizes.rend()) { - auto const& size = sizeIt->AsNode(); - if (size && size->value() == 1) { + auto const& size = intValue(*sizeIt); + if (size && *size == 1) { bcast.emplace_back(0); } else { bcast.emplace_back(*axisIt); @@ -525,7 +525,9 @@ static at::ScalarType tensorType(BufPtr b) { std::vector bufferSizes(BufPtr b) { std::vector sizes; for (size_t i = 0; i < b->ndim(); i++) { - sizes.push_back(to(b->dim(i))->value()); + auto dim = intValue(b->dim(i)); + TORCH_INTERNAL_ASSERT(dim); + sizes.push_back(*dim); } return sizes; } @@ -543,7 +545,8 @@ ExprHandle TensorExprKernel::chunk( std::vector indices; for (size_t i = 0; i < axes.size(); ++i) { if (i == norm_dim) { - indices.push_back(axes[i] + IntImm::make((int)chunkIdx * (int)step)); + indices.push_back( + axes[i] + ExprHandle(immLike(axes[i], chunkIdx * step))); } else { indices.push_back(axes[i]); } @@ -642,7 +645,7 @@ std::vector TensorExprKernel::sizesFromVaryingShape( const c10::VaryingShape& shape) { std::vector dims; for (const auto i : c10::irange(*shape.size())) { - dims.push_back(IntImm::make(*shape[i])); + dims.push_back(*shape[i]); } return dims; } @@ -664,7 +667,7 @@ std::vector TensorExprKernel::sizesForValue( if (v->type()->isSubtypeOf(FloatType::get()) || 
v->type()->isSubtypeOf(IntType::get())) { - return {1}; + return {int64_t{1}}; } if (v->type()->isSubtypeOf(NoneType::get())) { return {}; @@ -820,7 +823,7 @@ std::vector TensorExprKernel::inferSizesForValue( TORCH_INTERNAL_ASSERT(n->input(1)->node()->kind() == prim::Constant); int64_t dim = n->input(1)->node()->i(attr::value); auto shape = sizesForValue(inputs[0]); - size_t norm_dim = normalizeAndCheckIndex(dim, shape.size()); + auto norm_dim = normalizeAndCheckIndex(dim, shape.size()); ExprHandle concat_dim_size = 0; for (auto input : inputs) { concat_dim_size = concat_dim_size + sizesForValue(input)[norm_dim]; @@ -889,11 +892,11 @@ ExprHandle clamp( } static bool isOne(ExprHandle e) { - auto const& n = e.AsNode(); + auto const& n = intValue(e); if (!n) { return false; } - return n->value() == 1; + return *n == 1; } std::pair, bool> broadcastShapesImpl( @@ -1150,6 +1153,7 @@ std::pair> processCatList( } return {highType, nonEmptyInputs}; } + Tensor computeCatWoConditionals( const std::vector& inputs, const std::vector& outputShape) { @@ -1184,8 +1188,7 @@ Tensor computeCatWoConditionals( } int64_t concat_dim = c10::get(arg_dim); - size_t norm_concat_dim = - normalizeAndCheckIndex(concat_dim, outputShape.size()); + auto norm_concat_dim = normalizeAndCheckIndex(concat_dim, outputShape.size()); auto gen_code_for_input = [&](const BufHandle& inp, size_t inp_pos, @@ -1196,7 +1199,8 @@ Tensor computeCatWoConditionals( std::vector store_indices(dims.size()); for (size_t i = 0; i < dims.size(); ++i) { for_vars[i] = alloc( - "i" + c10::to_string(inp_pos) + "_" + c10::to_string(i), kInt); + "i" + c10::to_string(inp_pos) + "_" + c10::to_string(i), + dims[i].dtype()); load_indices[i] = for_vars[i]; if (i == norm_concat_dim) { store_indices[i] = alloc(for_vars[i], concat_dim_size); @@ -1209,8 +1213,8 @@ Tensor computeCatWoConditionals( auto load_promoted = promoteToDtype(ExprHandle(load_expr), high_type); StmtPtr st = alloc(output_buf, store_indices, load_promoted.node()); for (size_t i = dims.size(); i > 0; --i) { - st = - alloc(for_vars[i - 1], alloc(0), dims[i - 1].node(), st); + st = alloc( + for_vars[i - 1], immLike(dims[i - 1], 0), dims[i - 1].node(), st); } return st; }; @@ -1221,7 +1225,7 @@ Tensor computeCatWoConditionals( auto input_dims = ExprVectorToExprHandleVector(non_empty_inputs[i].node()->dims()); if (concat_dim_size == nullptr) { - concat_dim_size = alloc(0); + concat_dim_size = immLike(input_dims[norm_concat_dim], 0); } block->append_stmt(gen_code_for_input( non_empty_inputs[i], i, concat_dim_size, input_dims)); @@ -1253,7 +1257,7 @@ Tensor computeCat( } int64_t dim_ = c10::get(argDim); - size_t dim = normalizeAndCheckIndex(dim_, axes.size()); + auto dim = normalizeAndCheckIndex(dim_, axes.size()); // Promote input types. // Note that we need to consider all inputs, including empty - they // also affect the resultant dtype. 
@@ -1273,18 +1277,18 @@ Tensor computeCat( std::vector newAxes(axes.begin(), axes.end()); ExprHandle load = promoteToDtype( tensorOrConstant(nonEmptyInputs[0], newAxes), highType); - size_t offset = to(nonEmptyInputs[0].node()->dim(dim))->value(); - newAxes[dim] = newAxes[dim] - IntImm::make(offset); + auto offset = *intValue(nonEmptyInputs[0].node()->dim(dim)); + newAxes[dim] = newAxes[dim] - ExprHandle(immLike(newAxes[dim], offset)); for (size_t ii = 1; ii < nonEmptyInputs.size(); ++ii) { auto input = nonEmptyInputs[ii]; load = ifThenElse( - CompareSelect::make(axes[dim], IntImm::make(offset), kLT), + CompareSelect::make(axes[dim], offset, kLT), load, promoteToDtype(tensorOrConstant(input, newAxes), highType)); - offset += to(input.node()->dim(dim))->value(); - newAxes[dim] = axes[dim] - IntImm::make(offset); + offset += *intValue(input.node()->dim(dim)); + newAxes[dim] = axes[dim] - ExprHandle(immLike(axes[dim], offset)); } return load; @@ -2334,12 +2338,12 @@ Tensor tensorexpr::computeOperandValue( ExprHandle cur_stride = 1; std::vector dims, indices; for (size_t idx = 0; idx < view_dims.size(); idx++) { - dims.push_back(alloc(view_dims[idx])); + dims.push_back(alloc(view_dims[idx])); indices.push_back(axes[idx].node()); } ExprHandle flat_idx = ExprHandle(flatten_index(dims, indices)); std::vector orig_buf_indexes(A.ndim(), ExprHandle(0)); - ExprHandle stride = IntImm::make(1); + ExprHandle stride = ExprHandle(immLike(flat_idx, 1)); for (size_t idx = 0; idx < A.ndim(); idx++) { size_t dim_idx = A.ndim() - idx - 1; // We don't need to generate mod-div for the first dimension - @@ -2799,7 +2803,7 @@ static std::vector toExprHandles(const std::vector& sizes) { std::vector dims; dims.reserve(sizes.size()); for (auto const& size : sizes) { - dims.emplace_back(IntImm::make(size)); + dims.emplace_back(size); } return dims; } @@ -2831,8 +2835,7 @@ Tensor TensorExprKernel::bindInput(const torch::jit::Value* input) { std::vector inputTensorDims; for (size_t i = 0; i < *tt->sizes().size(); i++) { auto const size = *tt->sizes()[i]; - inputTensorDims.emplace_back( - DimArg(IntImm::make(size), "i" + c10::to_string(i))); + inputTensorDims.emplace_back(DimArg(size, "i" + c10::to_string(i))); } auto const strides = tt->strides(); result = Compute( @@ -2841,12 +2844,11 @@ Tensor TensorExprKernel::bindInput(const torch::jit::Value* input) { [&](const std::vector& axes) { ExprHandle idx = 0; for (size_t i = 0; i < axes.size(); i++) { - idx = idx + axes[i] * IntImm::make(*strides[i]); + idx = idx + axes[i] * *strides[i]; } return inBuffer.load(idx); }); bufs_.emplace(input, result.buf()); - bufferArgs_.emplace_back(inBuffer); break; } @@ -2956,10 +2958,10 @@ Tensor TensorExprKernel::convertOutputToCorrectStrides(torch::jit::Value* v) { return Compute( "output_1", dims, [&](const std::vector& axes_input) { std::vector axes(axes_input.begin(), axes_input.end()); - auto absolute_position = IntImm::make(0); + auto absolute_position = ExprHandle(immLike(axes[0], 0)); for (size_t i = 0; i < axes.size(); ++i) { - absolute_position = - absolute_position + (IntImm::make(default_strides[i]) * axes[i]); + absolute_position = absolute_position + + (ExprHandle(immLike(axes[i], default_strides[i])) * axes[i]); } std::vector sorted_stride_indices = reverse_sort_indices(strides); @@ -2967,10 +2969,11 @@ Tensor TensorExprKernel::convertOutputToCorrectStrides(torch::jit::Value* v) { for (size_t stride_index : sorted_stride_indices) { auto stride = strides[stride_index]; auto size = sizes[stride_index]; - auto index = 
Div::make(absolute_position, IntImm::make(stride)); + auto index = absolute_position / + ExprHandle(immLike(absolute_position, stride)); if (size != 1) { - absolute_position = - Mod::make(absolute_position, IntImm::make(stride)); + absolute_position = absolute_position % + ExprHandle(immLike(absolute_position, stride)); } new_axes[stride_index] = index; } @@ -2992,7 +2995,7 @@ void TensorExprKernel::bindConstant(const torch::jit::Value* v) { std::vector te_sizes; te_sizes.reserve(sizes.size()); for (auto s : sizes) { - te_sizes.push_back(IntImm::make(s)); + te_sizes.push_back(s); } BufPtr buf = alloc( diff --git a/torch/csrc/jit/tensorexpr/kernel.h b/torch/csrc/jit/tensorexpr/kernel.h index 99a3b12..4b92b02 100644 --- a/torch/csrc/jit/tensorexpr/kernel.h +++ b/torch/csrc/jit/tensorexpr/kernel.h @@ -19,7 +19,7 @@ template inline std::vector bufferSizes(const T& t) { std::vector sizes; for (size_t i = 0; i < t->ndim(); i++) { - sizes.push_back(to(t->dim(i))->value()); + sizes.push_back(*intValue(t->dim(i))); } return sizes; } @@ -62,7 +62,7 @@ ExprHandle tensorOrConstant( const ArgValue& v, const std::vector& axes); -size_t normalizeAndCheckIndex(int64_t idx, int64_t list_size); +int64_t normalizeAndCheckIndex(int64_t idx, int64_t list_size); ExprHandle broadcast(BufHandle b, const std::vector& axes); diff --git a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp index a93fd64..026d52b 100644 --- a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp +++ b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp @@ -275,17 +275,17 @@ class LLVMCodeGenImpl : public IRVisitor { }; extern "C" { -typedef void (*ParallelCallee)(int index, int8_t* packed_data); +typedef void (*ParallelCallee)(int64_t index, int8_t* packed_data); void DispatchParallel( int8_t* func, - int start, - int stop, + int64_t start, + int64_t stop, int8_t* packed_data) noexcept { // TODO: preserve the func type. try { ParallelCallee callee = reinterpret_cast(func); at::parallel_for(start, stop, 1, [&](int64_t f_begin, int64_t f_end) { - for (int index = f_begin; index < f_end; index++) { + for (int64_t index = f_begin; index < f_end; index++) { callee(index, packed_data); } }); @@ -537,10 +537,6 @@ void LLVMCodeGenImpl::emitKernel( irb_.CreateRet(value_); - if (llvm::verifyFunction(*fn_, &llvm::outs())) { - throw std::runtime_error("Function verification failed"); - } - // print graph debug info before optimization llvm::SmallVector asmBuffer; llvm::raw_svector_ostream asmStream(asmBuffer); @@ -550,6 +546,10 @@ void LLVMCodeGenImpl::emitKernel( GRAPH_DEBUG( "\nLLVM module before optimizations\n\n", asmStream.str().str(), "\n"); + if (llvm::verifyFunction(*fn_, &llvm::outs())) { + throw std::runtime_error("Function verification failed"); + } + optimize(*module_); asmBuffer.set_size(0); @@ -1144,8 +1144,8 @@ void LLVMCodeGenImpl::visit(LoadPtr v) { // Handle the case where the load is contiguous and unmasked efficiently auto idx_ramp = to(v->flat_index()); if (idx_ramp) { - auto stride_imm = to(idx_ramp->stride()); - if (stride_imm && stride_imm->value() == 1) { + auto stride_imm = intValue(idx_ramp->stride()); + if (stride_imm && *stride_imm == 1) { v->base_handle()->accept(this); auto base = this->value_; idx_ramp->base()->accept(this); @@ -1256,7 +1256,7 @@ void LLVMCodeGenImpl::processParallelFor(ForPtr v) { // Create the new body closure code. 
auto func_type = - llvm::FunctionType::get(VoidTy_, {IntTy_, Int8PtrTy_}, false); + llvm::FunctionType::get(VoidTy_, {LongTy_, Int8PtrTy_}, false); llvm::Function* func = llvm::Function::Create( func_type, llvm::Function::PrivateLinkage, "func", module_.get()); auto func_body = llvm::BasicBlock::Create(getContext(), "func_body", func); @@ -1268,6 +1268,10 @@ void LLVMCodeGenImpl::processParallelFor(ForPtr v) { packed_func_args_raw, packed_caller_args->getType()); // Unpack the arguments from the opaque buffer. + if (v->var()->dtype().scalar_type() != c10::kLong) { + index = irb_.CreateIntCast( + index, dtypeToLLVM(v->var()->dtype()), v->var()->dtype().is_signed()); + } body_closure_args = unpackFuncArgs(packed_func_args, body_arg_vars.size()); // Set the codegen to the new func. // TODO: this should be replaced by RAII wrappers. @@ -1290,12 +1294,14 @@ void LLVMCodeGenImpl::processParallelFor(ForPtr v) { irb_.CreatePointerCast(packed_caller_args, Int8PtrTy_); llvm::Value* func_value = irb_.CreatePointerCast(func, Int8PtrTy_); llvm::FunctionType* dispatcher_fntype = llvm::FunctionType::get( - VoidTy_, {Int8PtrTy_, IntTy_, IntTy_, Int8PtrTy_}, false); + VoidTy_, {Int8PtrTy_, LongTy_, LongTy_, Int8PtrTy_}, false); FunctionCallee dispatcher_callee = module_->getOrInsertFunction("DispatchParallel", dispatcher_fntype); llvm::Function* dispatcher = llvm::cast(dispatcher_callee.getCallee()); dispatcher->addFnAttr(llvm::Attribute::NoUnwind); + start = irb_.CreateIntCast(start, LongTy_, true); + stop = irb_.CreateIntCast(stop, LongTy_, true); irb_.CreateCall( dispatcher, {func_value, start, stop, packed_caller_args_ptr}); value_ = llvm::ConstantInt::get(IntTy_, 0); @@ -1320,7 +1326,7 @@ void LLVMCodeGenImpl::visit(ForPtr v) { irb_.SetInsertPoint(condBlock); // Set up phi node for index variable. - auto idx = irb_.CreatePHI(IntTy_, 2); + auto idx = irb_.CreatePHI(start->getType(), 2); idx->addIncoming(start, preheader); if (!varToVal_.count(v->var())) { varToVal_.emplace(v->var(), idx); @@ -1345,7 +1351,8 @@ void LLVMCodeGenImpl::visit(ForPtr v) { body = irb_.GetInsertBlock(); // Increment the index variable and branch back to loop test. 
- auto inc = irb_.CreateAdd(idx, llvm::ConstantInt::getSigned(IntTy_, 1)); + auto inc = + irb_.CreateAdd(idx, llvm::ConstantInt::getSigned(start->getType(), 1)); irb_.CreateBr(condBlock); idx->addIncoming(inc, body); @@ -1430,8 +1437,8 @@ void LLVMCodeGenImpl::visit(StorePtr v) { // Handle the case where the store is contiguous and unmasked efficiently auto idx_ramp = to(v->flat_index()); if (idx_ramp) { - auto stride_imm = to(idx_ramp->stride()); - if (stride_imm && stride_imm->value() == 1) { + auto stride_imm = intValue(idx_ramp->stride()); + if (stride_imm && *stride_imm == 1) { idx_ramp->base()->accept(this); auto first_idx = value_; diff --git a/torch/csrc/jit/tensorexpr/llvm_jit.h b/torch/csrc/jit/tensorexpr/llvm_jit.h index 8585900..a837899 100644 --- a/torch/csrc/jit/tensorexpr/llvm_jit.h +++ b/torch/csrc/jit/tensorexpr/llvm_jit.h @@ -20,8 +20,8 @@ namespace tensorexpr { extern "C" { void DispatchParallel( int8_t* func, - int start, - int stop, + int64_t start, + int64_t stop, int8_t* packed_data) noexcept; } diff --git a/torch/csrc/jit/tensorexpr/loopnest.cpp b/torch/csrc/jit/tensorexpr/loopnest.cpp index d3a4b91..11020cc 100644 --- a/torch/csrc/jit/tensorexpr/loopnest.cpp +++ b/torch/csrc/jit/tensorexpr/loopnest.cpp @@ -127,8 +127,8 @@ class Vectorizer : public IRMutator { ExprPtr start = v->start(); ExprPtr stop = v->stop(); - IntImmPtr start_imm = to(start); - IntImmPtr stop_imm = to(stop); + auto start_imm = intValue(start); + auto stop_imm = intValue(stop); if (!start_imm) { throw std::runtime_error( "Can't vectorize due to non-constant loop start!"); @@ -140,8 +140,8 @@ class Vectorizer : public IRMutator { } var_ = var; - start_ = start_imm; - lanes_ = stop_imm->value(); + start_ = immLike(start, *start_imm); + lanes_ = *stop_imm; StmtPtr new_body = body->accept_mutator(this); if (new_body == body) { @@ -531,11 +531,11 @@ class FunctionInliner : public IRMutator { if (auto index_var = to(i)) { index_vars_.insert(index_var); producer_index_vars_.push_back(index_var); - } else if (to(i) != nullptr) { + } else if (intValue(i)) { // If the index can be a constant, then that dimension must have size 1 // (since we don't support in-place writes). Resolves issue 52581. TORCH_INTERNAL_ASSERT( - to(i)->value() == 0, + *intValue(i) == 0, "Constant index impression should always be zero"); producer_index_vars_.push_back(nullptr); } else { @@ -553,8 +553,7 @@ class FunctionInliner : public IRMutator { ExprPtr func_caller_param = dims.at(i); if (func_callee_arg == nullptr) { TORCH_INTERNAL_ASSERT( - to(func_caller_param) != nullptr && - to(func_caller_param)->value() == 0, + intValue(func_caller_param) && *intValue(func_caller_param) == 0, "We are implicitly assuming that if you have an index of 0, that must also be inlined into an index of 0"); continue; } @@ -1140,7 +1139,7 @@ bool LoopNest::optimizeConditionals() { // only include the RHS of the conditions in the if-then-else expressions // we need to start with `0` which is the initial bound, given that we // only handle normalized loops (check for this is done below). 
-  std::vector<ExprPtr> comp_values = {alloc<IntImm>(0)};
+  std::vector<ExprPtr> comp_values;
   std::vector<ExprPtr> sub_exprs;
   auto ifthenelse_exprs = NodeFinder<IfThenElse>::find(store);
   if (ifthenelse_exprs.empty()) {
@@ -1155,6 +1154,8 @@ bool LoopNest::optimizeConditionals() {
             ifthenelse_exprs.front(), &cond_var, &comp_values, &sub_exprs)) {
       continue;
     }
+    TORCH_INTERNAL_ASSERT(comp_values.size() >= 1);
+    comp_values.insert(comp_values.begin(), immLike(comp_values[0], 0));

     auto fors = getLoopStmtsFor(store);
     if (cond_var != fors.back()->var()) {
@@ -1290,10 +1291,10 @@ void LoopNest::vectorizeInnerLoops() {
 }

 void LoopNest::sliceHead(ForPtr f, int factor, ForPtr* head, ForPtr* tail) {
-  if (to<IntImm>(f->start()) && to<IntImm>(f->stop())) {
-    int start_val = to<IntImm>(f->start())->value();
-    int stop_val = to<IntImm>(f->stop())->value();
-    int size_val = stop_val - start_val;
+  if (intValue(f->start()) && intValue(f->stop())) {
+    auto start_val = *intValue(f->start());
+    auto stop_val = *intValue(f->stop());
+    auto size_val = stop_val - start_val;
     if (factor >= size_val) {
       *head = f;
       *tail = nullptr;
@@ -1311,7 +1312,7 @@ void LoopNest::sliceHead(ForPtr f, int factor, ForPtr* head, ForPtr* tail) {
   }

   ExprPtr head_end = alloc<Min>(
-      alloc<Add>(f->start(), alloc<IntImm>(factor)), f->stop(), true);
+      alloc<Add>(f->start(), immLike(f->stop(), factor)), f->stop(), true);
   *head = alloc<For>(f->var(), f->start(), head_end, Stmt::clone(f->body()));
   p->insert_stmt_before(*head, f);
@@ -1330,10 +1331,10 @@ void LoopNest::sliceHead(ForPtr f, int factor) {
 }

 void LoopNest::sliceTail(ForPtr f, int factor, ForPtr* head, ForPtr* tail) {
-  if (to<IntImm>(f->start()) && to<IntImm>(f->stop())) {
-    int start_val = to<IntImm>(f->start())->value();
-    int stop_val = to<IntImm>(f->stop())->value();
-    int size_val = stop_val - start_val;
+  if (intValue(f->start()) && intValue(f->stop())) {
+    auto start_val = *intValue(f->start());
+    auto stop_val = *intValue(f->stop());
+    auto size_val = stop_val - start_val;
     if (factor >= size_val) {
       *head = nullptr;
       *tail = f;
@@ -1351,7 +1352,7 @@ void LoopNest::sliceTail(ForPtr f, int factor, ForPtr* head, ForPtr* tail) {
   }

   ExprPtr tail_start = alloc<Max>(
-      f->start(), alloc<Sub>(f->stop(), alloc<IntImm>(factor)), true);
+      f->start(), alloc<Sub>(f->stop(), immLike(f->stop(), factor)), true);
   *tail = alloc<For>(f->var(), tail_start, f->stop(), Stmt::clone(f->body()));
   p->insert_stmt_after(*tail, f);
@@ -1390,17 +1391,17 @@ void LoopNest::splitWithTail(
   }

   bool tail_is_needed = true;
-  if (to<IntImm>(f->start()) && to<IntImm>(f->stop())) {
-    int start_val = to<IntImm>(f->start())->value();
-    int stop_val = to<IntImm>(f->stop())->value();
-    int size_val = stop_val - start_val;
-    int tail_size = size_val % factor;
+  if (intValue(f->start()) && intValue(f->stop())) {
+    auto const start_val = *intValue(f->start());
+    auto const stop_val = *intValue(f->stop());
+    auto const size_val = stop_val - start_val;
+    auto const tail_size = size_val % factor;
     if (tail_size == 0) {
       tail_is_needed = false;
     }
   }

-  IntImmPtr factor_expr = alloc<IntImm>(factor);
+  ExprPtr factor_expr = immLike(f->stop(), factor);
   ExprPtr size = alloc<Sub>(f->stop(), f->start());
   ExprPtr split_count = alloc<Div>(size, factor_expr);
   ExprPtr tail_size = alloc<Mod>(size, factor_expr);
@@ -1423,7 +1424,7 @@ void LoopNest::splitWithTail(
     StmtPtr body_tail =
         SubstituteInClone(f->body(), {{f->var(), combined_index2}});
-    *tail = alloc<For>(i_tail, alloc<IntImm>(0), tail_size, body_tail);
+    *tail = alloc<For>(i_tail, immLike(tail_size, 0), tail_size, body_tail);

     p->insert_stmt_after(*tail, f);
   } else {
@@ -1433,10 +1434,11 @@ void LoopNest::splitWithTail(
   StmtPtr body_inner =
       Substitute(f->removeBody(), {{f->var(), combined_index1}});

-  *inner = alloc<For>(i_inner, alloc<IntImm>(0), factor_expr, body_inner);
+  *inner =
+      alloc<For>(i_inner, immLike(factor_expr, 0), factor_expr, body_inner);
   // The input loop `f` will be the outer loop after split.
   f->set_var(i_outer);
-  f->set_start(alloc<IntImm>(0));
+  f->set_start(immLike(split_count, 0));
   f->set_stop(split_count);
   f->set_body(*inner);
 }
@@ -1458,20 +1460,20 @@ void LoopNest::splitWithMask(ForPtr f, int factor, ForPtr* inner) {
   ExprPtr start = IRSimplifier::simplify(f->start());
   ExprPtr stop = IRSimplifier::simplify(f->stop());
   if (start->isConstant() && stop->isConstant()) {
-    int start_val = immediateAs<int>(start);
-    int stop_val = immediateAs<int>(stop);
-    int size_val = stop_val - start_val;
-    int tail_size = size_val % factor;
+    auto start_val = *intValue(start);
+    auto stop_val = *intValue(stop);
+    auto size_val = stop_val - start_val;
+    auto tail_size = size_val % factor;
     if (tail_size == 0) {
       tail_is_needed = false;
     }
   }

-  IntImmPtr factor_expr = alloc<IntImm>(factor);
+  auto factor_expr = immLike(f->stop(), factor);
   ExprPtr size = alloc<Sub>(f->stop(), f->start());
   // split_count = (size + factor - 1) / factor
   ExprPtr split_count = alloc<Div>(
-      alloc<Sub>(alloc<Add>(size, factor_expr), alloc<IntImm>(1)), factor_expr);
+      alloc<Sub>(alloc<Add>(size, factor_expr), immLike(size, 1)), factor_expr);

   const std::string& loop_var_name = f->var()->name_hint();
   Dtype loop_var_dtype = f->var()->dtype();
@@ -1487,8 +1489,8 @@ void LoopNest::splitWithMask(ForPtr f, int factor, ForPtr* inner) {
   // TODO: is it ok that we're doing it eagerly? In the other implementation we
   // are only materializing predicates at the last, lowering, step.
   if (tail_is_needed) {
-    IntImmPtr start = to<IntImm>(f->start());
-    if (!start || start->value() != 0) {
+    auto start = intValue(f->start());
+    if (!start || *start != 0) {
       throw unimplemented_lowering();
     }
@@ -1499,10 +1501,11 @@ void LoopNest::splitWithMask(ForPtr f, int factor, ForPtr* inner) {
   }
   body_inner = Substitute(body_inner, {{f->var(), combined_index}});

-  *inner = alloc<For>(i_inner, alloc<IntImm>(0), factor_expr, body_inner);
+  *inner =
+      alloc<For>(i_inner, immLike(factor_expr, 0), factor_expr, body_inner);
   // The input loop `f` will be the outer loop after split.
   f->set_var(i_outer);
-  f->set_start(alloc<IntImm>(0));
+  f->set_start(immLike(split_count, 0));
   f->set_stop(split_count);
   f->set_body(*inner);
 }
@@ -2177,7 +2180,7 @@ bool LoopNest::normalize(ForPtr f) {
       {{f->var(), (VarHandle(f->var()) + ExprHandle(f->start())).node()}});
   f->set_body(IRSimplifier::simplify(for_body_normalized));
   f->set_stop(IRSimplifier::simplify(alloc<Sub>(f->stop(), f->start())));
-  f->set_start(alloc<IntImm>(0));
+  f->set_start(immLike(f->stop(), 0));
   return true;
 }
@@ -2242,7 +2245,7 @@ bool LoopNest::flatten(const std::vector<ForPtr>& loops, ForPtr* flattened) {
       normalized_loops[0]->var()->name_hint() + "_flat",
       normalized_loops[0]->var()->dtype());
   VarMapping var_mapping;
-  ExprPtr stop = alloc<IntImm>(1);
+  ExprPtr stop = immLike(flat_var, 1);
   for (size_t i = 0; i < normalized_loops.size(); ++i) {
     size_t idx = normalized_loops.size() - i - 1;
     auto curr_loop = normalized_loops[idx];
@@ -2255,7 +2258,7 @@ bool LoopNest::flatten(const std::vector<ForPtr>& loops, ForPtr* flattened) {
       Substitute(normalized_loops.back()->removeBody(), var_mapping);

   normalized_loops.front()->set_var(flat_var);
-  normalized_loops.front()->set_start(alloc<IntImm>(0));
+  normalized_loops.front()->set_start(immLike(stop, 0));
   normalized_loops.front()->set_stop(stop);
   normalized_loops.front()->set_body(flattened_body);
   *flattened = normalized_loops.front();
@@ -2357,7 +2360,7 @@ void LoopNest::compressBuffer(BufPtr buf, StmtPtr stmt) {
   std::vector<ExprPtr> new_dims(buf->dims());
   for (size_t i = 0; i < dims.size(); ++i) {
     if (dims[i]) {
-      new_dims[i] = alloc<IntImm>(1);
+      new_dims[i] = immLike(buf->dims()[i], 1);
     }
   }
   buf->set_dims(new_dims);
@@ -2368,7 +2371,7 @@ void LoopNest::compressBuffer(BufPtr buf, StmtPtr stmt) {
     std::vector<ExprPtr> new_indices(indices);
     for (size_t i = 0; i < dims.size(); ++i) {
       if (dims[i]) {
-        new_indices[i] = alloc<IntImm>(0);
+        new_indices[i] = immLike(indices[i], 0);
       }
     }
     return new_indices;
@@ -2652,12 +2655,13 @@ LoopNest::AccessResult LoopNest::cacheAccesses(

   // Determine the size of the cache, and create a loop var for each dimension.
   for (size_t i = 0; i < info.start.size(); ++i) {
-    ExprPtr dim = IRSimplifier::simplify(
-        alloc<Add>(alloc<Sub>(info.stop[i], info.start[i]), alloc<IntImm>(1)));
+    ExprPtr dim = IRSimplifier::simplify(alloc<Add>(
+        alloc<Sub>(info.stop[i], info.start[i]), immLike(info.stop[i], 1)));

     tmp_dims.push_back(dim);

-    new_loop_vars.push_back(alloc<Var>(var_names[i % var_names.size()], kInt));
+    new_loop_vars.push_back(
+        alloc<Var>(var_names[i % var_names.size()], info.stop[i]->dtype()));
     new_loop_vars_expr.push_back(new_loop_vars[i]);
   }
@@ -2708,8 +2712,8 @@ LoopNest::AccessResult LoopNest::cacheAccesses(
       tmp_buf, new_loop_vars_expr, getImmediateByType(tmp_buf->dtype(), 0));
   for (int64_t i = new_loop_vars.size() - 1; i >= 0; --i) {
-    tmp_init =
-        alloc<For>(new_loop_vars[i], alloc<IntImm>(0), tmp_dims[i], tmp_init);
+    tmp_init = alloc<For>(
+        new_loop_vars[i], immLike(tmp_dims[i], 0), tmp_dims[i], tmp_init);
   }

   if (is_block) {
@@ -2730,7 +2734,7 @@ LoopNest::AccessResult LoopNest::cacheAccesses(

     for (int64_t i = new_loop_vars.size() - 1; i >= 0; --i) {
       tmp_store = alloc<For>(
-          new_loop_vars[i], alloc<IntImm>(0), tmp_dims[i], tmp_store);
+          new_loop_vars[i], immLike(tmp_dims[i], 0), tmp_dims[i], tmp_store);
     }

     if (is_block) {
@@ -2749,7 +2753,7 @@ LoopNest::AccessResult LoopNest::cacheAccesses(

     for (int64_t i = new_loop_vars.size() - 1; i >= 0; --i) {
       tmp_store = alloc<For>(
-          new_loop_vars[i], alloc<IntImm>(0), tmp_dims[i], tmp_store);
+          new_loop_vars[i], immLike(tmp_dims[i], 0), tmp_dims[i], tmp_store);
     }

     if (is_block) {
@@ -2766,7 +2770,7 @@ LoopNest::AccessResult LoopNest::cacheAccesses(

     for (int64_t i = new_loop_vars.size() - 1; i >= 0; --i) {
       tmp_store = alloc<For>(
-          new_loop_vars[i], alloc<IntImm>(0), tmp_dims[i], tmp_store);
+          new_loop_vars[i], immLike(tmp_dims[i], 0), tmp_dims[i], tmp_store);
     }

     if (is_block) {
@@ -2914,7 +2918,8 @@ void LoopNest::computeAt(StmtPtr s, ForPtr f) {
   std::vector<ExprPtr> temp_indices(dims.size());
   for (const auto i : c10::irange(dims.size())) {
     // TODO: Use name-hint of the producer indices instead of 'idx'
-    temp_indices[i] = alloc<Var>(std::string("idx") + c10::to_string(i), kInt);
+    temp_indices[i] =
+        alloc<Var>(std::string("idx") + c10::to_string(i), dims[i]->dtype());
   }

   // Prepare substitute rules for constructing the temp statement from the prod
@@ -2955,7 +2960,10 @@ void LoopNest::computeAt(StmtPtr s, ForPtr f) {
     // dimensions in reversed order.
     size_t dim_idx = dims.size() - 1 - i;
     bd = alloc<For>(
-        to<Var>(temp_indices[dim_idx]), alloc<IntImm>(0), dims[dim_idx], bd);
+        to<Var>(temp_indices[dim_idx]),
+        immLike(dims[dim_idx], 0),
+        dims[dim_idx],
+        bd);
   }

   // Add constructed stmts to the consumer loop
diff --git a/torch/csrc/jit/tensorexpr/mem_dependency_checker.cpp b/torch/csrc/jit/tensorexpr/mem_dependency_checker.cpp
index 8f6f2b1..e1688e3 100644
--- a/torch/csrc/jit/tensorexpr/mem_dependency_checker.cpp
+++ b/torch/csrc/jit/tensorexpr/mem_dependency_checker.cpp
@@ -185,13 +185,13 @@ void AccessInfo::dumpDOT(std::ostream& os) const {
     if (bounds_.size() > 0) {
       for (size_t i = 0; i < bounds_.size() - 1; ++i) {
         os << *IRSimplifier::simplify(
-                  alloc<Sub>(bounds_[i].end, alloc<IntImm>(1)))
+                  alloc<Sub>(bounds_[i].end, immLike(bounds_[i].end, 1)))
            << ", ";
       }
       size_t i = bounds_.size() - 1;
       os << *IRSimplifier::simplify(
-                alloc<Sub>(bounds_[i].end, alloc<IntImm>(1)));
+                alloc<Sub>(bounds_[i].end, immLike(bounds_[i].end, 1)));
       os << "]\"\n ";
     }
     if (isWrite()) {
@@ -632,7 +632,7 @@ bool executionSafetyCheck(

   // Invert the startDiff so mod works.
   if (diffNegative != strideNegative) {
     startDiff =
-        IRSimplifier::simplify(alloc<Sub>(alloc<IntImm>(0), startDiff));
+        IRSimplifier::simplify(alloc<Sub>(immLike(startDiff, 0), startDiff));
   }

   // If both accesses have the same stride, and the difference in start
@@ -650,7 +650,7 @@ bool executionSafetyCheck(
   CompareSelectOperation op = strideNegative ? kLT : kGT;

   ExprPtr check = IRSimplifier::simplify(
-      alloc<CompareSelect>(startDiff, alloc<IntImm>(0), op));
+      alloc<CompareSelect>(startDiff, immLike(startDiff, 0), op));

   // If the start difference modulo the minimum stride is offset from that
   // stride, then the ranges have distinct strides.
@@ -731,7 +731,7 @@ void MemDependencyChecker::visit(ForPtr v) {
     for (const auto i : c10::irange(indices.size())) {
       VarFinder vf;
       if (vf.find(indices[i]).count(var) == 0) {
-        loopIndicesStride[i] = alloc<IntImm>(0);
+        loopIndicesStride[i] = immLike(indices[i], 0);
       } else {
         // If we've previously swapped the start and end of this bound, we
         // should apply the substitution to the reverse of the bounds.
@@ -740,19 +740,19 @@
           info->bounds()[i].end = IRSimplifier::simplify(
               SubstituteInClone(info->bounds()[i].end, {{var, v->start()}}));
           info->bounds()[i].start = IRSimplifier::simplify(SubstituteInClone(
               info->bounds()[i].start,
-              {{var, alloc<Sub>(v->stop(), alloc<IntImm>(1))}}));
+              {{var, alloc<Sub>(v->stop(), immLike(v->stop(), 1))}}));
         } else {
           info->bounds()[i].start = IRSimplifier::simplify(
               SubstituteInClone(info->bounds()[i].start, {{var, v->start()}}));
           info->bounds()[i].end = IRSimplifier::simplify(SubstituteInClone(
               info->bounds()[i].end,
-              {{var, alloc<Sub>(v->stop(), alloc<IntImm>(1))}}));
+              {{var, alloc<Sub>(v->stop(), immLike(v->stop(), 1))}}));
         }

         ExprPtr zeroStep = indices[i];
         ExprPtr oneStep = SubstituteInClone(
-            indices[i], {{var, alloc<Add>(var, alloc<IntImm>(1))}});
+            indices[i], {{var, alloc<Add>(var, immLike(var, 1))}});
         loopIndicesStride[i] =
             IRSimplifier::simplify(alloc<Sub>(oneStep, zeroStep));
@@ -785,7 +785,7 @@ void MemDependencyChecker::visit(ForPtr v) {
       bound.start = IRSimplifier::simplify(
          SubstituteInClone(bound.start, {{var, v->start()}}));
       bound.end = IRSimplifier::simplify(SubstituteInClone(
-          bound.end, {{var, alloc<Sub>(v->stop(), alloc<IntImm>(1))}}));
+          bound.end, {{var, alloc<Sub>(v->stop(), immLike(v->stop(), 1))}}));

       // If the start < end then swap the order of the bound.
       ExprPtr diff =
@@ -1037,8 +1037,8 @@ void MemDependencyChecker::insertBuffers(
     IndexBounds bounds;
     for (auto d : b->dims()) {
       bounds.push_back(
-          {alloc<IntImm>(0),
-           IRSimplifier::simplify(alloc<Sub>(d, alloc<IntImm>(1)))});
+          {immLike(d, 0),
+           IRSimplifier::simplify(alloc<Sub>(d, immLike(d, 1)))});
     }
     auto info =
         std::make_shared<AccessInfo>(nextAccess_++, type, nullptr, var, bounds);
@@ -1126,8 +1126,9 @@ void MemDependencyChecker::visit(AllocatePtr v) {
   // avoid failing the bound check. But this is not the correct approach and
   // should be fixed.
   ExprPtr flat_size = buf_flat_size(v->buf());
-  flat_size = IRSimplifier::simplify(alloc<Sub>(flat_size, alloc<IntImm>(1)));
-  bounds.push_back({alloc<IntImm>(0), flat_size});
+  flat_size =
+      IRSimplifier::simplify(alloc<Sub>(flat_size, immLike(flat_size, 1)));
+  bounds.push_back({immLike(flat_size, 0), flat_size});
   auto info = std::make_shared<AccessInfo>(
       nextAccess_++, AccessType::Alloc, nullptr, var, bounds);
diff --git a/torch/csrc/jit/tensorexpr/registerizer.cpp b/torch/csrc/jit/tensorexpr/registerizer.cpp
index bc26581..8684f2a 100644
--- a/torch/csrc/jit/tensorexpr/registerizer.cpp
+++ b/torch/csrc/jit/tensorexpr/registerizer.cpp
@@ -18,7 +18,7 @@ void AccessInfo::addStore(StorePtr store, const std::shared_ptr<Scope>& scope) {
   last_usage_ = store;

   store_cost_ =
-      IRSimplifier::simplify(alloc<Add>(store_cost_, alloc<IntImm>(1)));
+      IRSimplifier::simplify(alloc<Add>(store_cost_, immLike(store_cost_, 1)));
   stores_.push_back(store);

   conditionId_ = scope->conditionId();
@@ -34,7 +34,8 @@ void AccessInfo::addLoad(
   first_usage_ = first_usage_ ? block_->getEnclosedRoot(first_usage_) : usage;
   last_usage_ = usage;

-  load_cost_ = IRSimplifier::simplify(alloc<Add>(load_cost_, alloc<IntImm>(1)));
+  load_cost_ =
+      IRSimplifier::simplify(alloc<Add>(load_cost_, immLike(load_cost_, 1)));
   loads_.push_back(load);

   conditionId_ = scope->conditionId();
diff --git a/torch/csrc/jit/tensorexpr/tensor.cpp b/torch/csrc/jit/tensorexpr/tensor.cpp
index ea3902d..7a219fe 100644
--- a/torch/csrc/jit/tensorexpr/tensor.cpp
+++ b/torch/csrc/jit/tensorexpr/tensor.cpp
@@ -31,8 +31,8 @@ StmtPtr Tensor::constructStmt(
     for (const auto i : c10::irange(reduce_ndim)) {
       // Going in reverse order: from innermost loop to the outermost
       size_t dim_index = reduce_ndim - i - 1;
-      s = alloc<For>(
-          reduce_args[dim_index], alloc<IntImm>(0), reduce_dims[dim_index], s);
+      auto const& dim = reduce_dims[dim_index];
+      s = alloc<For>(reduce_args[dim_index], immLike(dim, 0), dim, s);
     }
     if (init_expr) {
       StorePtr init_stmt = alloc<Store>(buf(), indices, init_expr);
@@ -43,7 +43,8 @@ StmtPtr Tensor::constructStmt(
   for (const auto i : c10::irange(ndim)) {
     // Going in reverse order: from innermost loop to the outermost
     size_t dim_index = ndim - i - 1;
-    s = alloc<For>(args[dim_index], alloc<IntImm>(0), buf()->dim(dim_index), s);
+    auto const& dim = buf()->dim(dim_index);
+    s = alloc<For>(args[dim_index], immLike(dim, 0), dim, s);
   }
   return s;
 }
-- 
2.7.4
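
The recurring change in these hunks replaces hard-coded 32-bit immediates (alloc<IntImm>(...)) with immLike(expr, value), which builds a constant whose dtype matches the expression it is paired with, so loop bounds and indices survive extents that do not fit in 32 bits. The standalone C++ sketch below illustrates only that dtype-matching idea under stated assumptions; Extent and makeImmLike are hypothetical stand-ins written for this example, not the NNC helpers themselves.

// Standalone sketch (not NNC code): why loop-bound immediates should match
// the width of the expressions they pair with. Extent and makeImmLike are
// hypothetical stand-ins for the patch's IR nodes and immLike helper.
#include <cstdint>
#include <iostream>
#include <variant>

// A toy "immediate": either a 32-bit or a 64-bit constant.
using Imm = std::variant<int32_t, int64_t>;

struct Extent {
  bool is64;      // dtype of the loop-extent expression
  int64_t value;  // its constant value, when known
};

// Build a constant with the SAME width as the reference extent, mirroring
// what immLike(expr, v) does in the patch.
Imm makeImmLike(const Extent& ref, int64_t v) {
  if (ref.is64) {
    return Imm{static_cast<int64_t>(v)};
  }
  return Imm{static_cast<int32_t>(v)};  // would truncate a large v
}

int main() {
  // A dimension of 4,000,000,000 elements does not fit in int32_t.
  Extent huge{/*is64=*/true, 4000000000LL};

  Imm start = makeImmLike(huge, 0);
  std::cout << "start holds int64: " << std::boolalpha
            << std::holds_alternative<int64_t>(start) << "\n";

  // Hard-coding a 32-bit view (the old behavior) silently overflows.
  int32_t truncated = static_cast<int32_t>(huge.value);
  std::cout << "int32 view of the extent: " << truncated << "\n";  // negative
  return 0;
}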