[REFACTOR][ARITH] Remove legacy compute_expr.h (#5738)
authorTianqi Chen <tqchen@users.noreply.github.com>
Sat, 6 Jun 2020 20:23:31 +0000 (13:23 -0700)
committerGitHub <noreply@github.com>
Sat, 6 Jun 2020 20:23:31 +0000 (13:23 -0700)
Replaces most of the ComptuteReduce using foldl.

23 files changed:
include/tvm/tir/op.h
src/arith/compute_expr.h [deleted file]
src/relay/op/tensor/transform.cc
src/target/llvm/codegen_llvm.h
src/target/source/codegen_c.cc
src/target/spirv/codegen_spirv.cc
src/te/operation/compute_op.cc
src/te/operation/op_util.cc
src/te/operation/tensor_compute_op.cc
src/te/schedule/message_passing.cc
src/te/schedule/schedule_dataflow_rewrite.cc
src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc
src/tir/ir/buffer.cc
src/tir/transforms/arg_binder.cc
src/tir/transforms/inject_double_buffer.cc
src/tir/transforms/inject_virtual_thread.cc
src/tir/transforms/lower_thread_allreduce.cc
src/tir/transforms/lower_warp_memory.cc
src/tir/transforms/storage_access.cc
src/tir/transforms/storage_flatten.cc
src/tir/transforms/storage_rewrite.cc
src/tir/transforms/unroll_loop.cc
src/tir/transforms/vectorize_loop.cc

index a4748d5..8d2add2 100644 (file)
@@ -463,6 +463,7 @@ TVM_DLL PrimExpr isinf(PrimExpr x);
  * \brief sum of of source expression over axis
  * \param source The source expression.
  * \param axis List of iteration variables that will be used for reduction.
+ * \return The result.
  */
 TVM_DLL PrimExpr sum(PrimExpr source, Array<tir::IterVar> axis);
 
@@ -477,6 +478,7 @@ TVM_DLL PrimExpr all(PrimExpr source, Array<tir::IterVar> axis);
  * \brief logical Or of of source expression over axis
  * \param source The source expression.
  * \param axis List of iteration variables that will be used for reduction.
+ * \return The result.
  */
 TVM_DLL PrimExpr any(PrimExpr source, Array<tir::IterVar> axis);
 
@@ -484,6 +486,7 @@ TVM_DLL PrimExpr any(PrimExpr source, Array<tir::IterVar> axis);
  * \brief max of of source expression over axis
  * \param source The source expression.
  * \param axis List of iteration variables that will be used for reduction.
+ * \return The result.
  */
 TVM_DLL PrimExpr max(PrimExpr source, Array<tir::IterVar> axis);
 
@@ -491,6 +494,7 @@ TVM_DLL PrimExpr max(PrimExpr source, Array<tir::IterVar> axis);
  * \brief max of of source expression over axis
  * \param source The source expression.
  * \param axis List of iteration variables that will be used for reduction.
+ * \return The result.
  */
 TVM_DLL PrimExpr min(PrimExpr source, Array<tir::IterVar> axis);
 
@@ -498,6 +502,7 @@ TVM_DLL PrimExpr min(PrimExpr source, Array<tir::IterVar> axis);
  * \brief product of of source expression over axis
  * \param source The source expression.
  * \param axis List of iteration variables that will be used for reduction.
+ * \return The result.
  */
 TVM_DLL PrimExpr prod(PrimExpr source, Array<tir::IterVar> axis);
 
@@ -659,6 +664,17 @@ inline bool is_zero(const PrimExpr& x) { return is_const_int(x, 0); }
 inline bool is_const(const PrimExpr& x);
 
 /*!
+ * \brief Left fold.
+ * \param freduce The reduction function.
+ * \param init_value The initial value.
+ * \param values The values to be folded.
+ * \return The result.
+ * \tparam FReduce The type of the reduction.
+ */
+template <typename FReduce>
+inline PrimExpr foldl(FReduce freduce, PrimExpr init_value, const Array<PrimExpr>& values);
+
+/*!
  * \brief Check whether x is a constant power of two
  * If x is power of two, write the power to the shift.
  *
@@ -762,6 +778,15 @@ inline PrimExpr make_zero(DataType t) {
   }
   return make_const(t, 0);
 }
+
+template <typename FReduce>
+inline PrimExpr foldl(FReduce freduce, PrimExpr init_value, const Array<PrimExpr>& values) {
+  for (PrimExpr val : values) {
+    init_value = freduce(init_value, val);
+  }
+  return init_value;
+}
+
 }  // namespace tir
 
 // additional const expression overloading
diff --git a/src/arith/compute_expr.h b/src/arith/compute_expr.h
deleted file mode 100644 (file)
index 39530ff..0000000
+++ /dev/null
@@ -1,109 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file compute_expr.h
- * \brief Utility to invoke certan compute operations.
- */
-#ifndef TVM_ARITH_COMPUTE_EXPR_H_
-#define TVM_ARITH_COMPUTE_EXPR_H_
-
-#include <tvm/tir/expr.h>
-#include <tvm/tir/op.h>
-
-#include <algorithm>
-#include <limits>
-
-namespace tvm {
-namespace arith {
-
-/*!
- * \brief Compute the expression with the given binary op.
- * \param lhs The left operand
- * \param rhs The right operand
- * \tparam Op the computation operator
- * \return The result.
- */
-template <typename OP>
-inline PrimExpr Compute(PrimExpr lhs, PrimExpr rhs) {
-  return OP::make(lhs, rhs);
-}
-
-/*!
- * \brief Compute an reduction with Op
- * \param values The input values.
- * \param empty_value The value when return if it is empty, can be Expr()
- *        which will cause an error to be rasied.
- * \tparam Op The computation operator
- * \return The result.
- */
-template <typename Op>
-inline PrimExpr ComputeReduce(const Array<PrimExpr>& values, PrimExpr empty_value);
-
-template <>
-inline PrimExpr Compute<tir::AddNode>(PrimExpr a, PrimExpr b) {
-  return a + b;
-}
-
-template <>
-inline PrimExpr Compute<tir::SubNode>(PrimExpr a, PrimExpr b) {
-  return a - b;
-}
-
-template <>
-inline PrimExpr Compute<tir::MulNode>(PrimExpr a, PrimExpr b) {
-  return a * b;
-}
-
-template <>
-inline PrimExpr Compute<tir::DivNode>(PrimExpr a, PrimExpr b) {
-  return truncdiv(a, b);
-}
-
-template <>
-inline PrimExpr Compute<tir::ModNode>(PrimExpr a, PrimExpr b) {
-  return truncmod(a, b);
-}
-
-template <>
-inline PrimExpr Compute<tir::MaxNode>(PrimExpr a, PrimExpr b) {
-  return max(a, b);
-}
-
-template <>
-inline PrimExpr Compute<tir::MinNode>(PrimExpr a, PrimExpr b) {
-  return min(a, b);
-}
-
-template <typename Op>
-inline PrimExpr ComputeReduce(const Array<PrimExpr>& values, PrimExpr empty_value) {
-  if (values.size() == 0U) {
-    CHECK(empty_value.defined());
-    return empty_value;
-  }
-  PrimExpr res = values[0];
-  for (size_t i = 1; i < values.size(); ++i) {
-    res = Compute<Op>(res, values[i]);
-  }
-  return res;
-}
-
-}  // namespace arith
-}  // namespace tvm
-#endif  // TVM_ARITH_COMPUTE_EXPR_H_
index a80bb31..0275a89 100644 (file)
@@ -38,7 +38,6 @@
 
 #include <vector>
 
-#include "../../../arith/compute_expr.h"
 #include "../../transforms/infer_layout_util.h"
 #include "../../transforms/pattern_util.h"
 #include "../op_common.h"
index 4522c15..0bca2a1 100644 (file)
@@ -42,7 +42,6 @@
 #include <utility>
 #include <vector>
 
-#include "../../arith/compute_expr.h"
 #include "../../runtime/thread_storage_scope.h"
 #include "../../tir/transforms/ir_util.h"
 #include "llvm_common.h"
index 69dab62..9255d7c 100644 (file)
@@ -25,7 +25,6 @@
 #include <cctype>
 #include <iomanip>
 
-#include "../../arith/compute_expr.h"
 #include "../../arith/pattern_match.h"
 
 namespace tvm {
index e76e8be..364a62f 100644 (file)
 
 #include <tvm/runtime/container.h>
 #include <tvm/tir/expr.h>
+#include <tvm/tir/op.h>
 
 #include <string>
 
-#include "../../arith/compute_expr.h"
-
 namespace tvm {
 namespace codegen {
 
index 048285d..cc843a4 100644 (file)
@@ -34,7 +34,6 @@
 #include <unordered_set>
 #include <utility>
 
-#include "../../arith/compute_expr.h"
 #include "../../arith/interval_set.h"
 #include "../schedule/message_passing.h"
 #include "op_util.h"
@@ -593,8 +592,8 @@ Stmt TransformUpdate(const Stage& stage, const std::unordered_map<IterVar, Range
     }
   }
 
-  return IfThenElseNode::make(arith::ComputeReduce<tir::OrNode>(conds, const_true(1)), update,
-                              body);
+  auto cond = foldl([](PrimExpr a, PrimExpr b) { return a || b; }, const_false(1), conds);
+  return IfThenElseNode::make(cond, update, body);
 }
 
 }  // namespace te
index 5b200ac..341e761 100644 (file)
@@ -29,7 +29,6 @@
 
 #include <string>
 
-#include "../../arith/compute_expr.h"
 #include "../../runtime/thread_storage_scope.h"
 #include "../schedule/message_passing.h"
 
index 236aff6..96ddb36 100644 (file)
@@ -29,7 +29,6 @@
 
 #include <unordered_set>
 
-#include "../../arith/compute_expr.h"
 #include "./compute_op.h"
 #include "./op_util.h"
 
index 4f0e982..55593be 100644 (file)
@@ -26,8 +26,6 @@
 #include <tvm/arith/analyzer.h>
 #include <tvm/tir/expr.h>
 
-#include "../../arith/compute_expr.h"
-
 namespace tvm {
 namespace te {
 
index 9561263..cfd8b26 100644 (file)
  */
 #include <tvm/te/operation.h>
 #include <tvm/te/schedule.h>
+#include <tvm/tir/op.h>
 #include <tvm/tir/stmt_functor.h>
 
 #include <unordered_set>
 
-#include "../../arith/compute_expr.h"
 #include "../../tir/transforms/ir_util.h"
 #include "message_passing.h"
 #include "operation_inline.h"
@@ -89,13 +89,14 @@ PrimExpr InjectPredicate(const Array<PrimExpr>& predicates, PrimExpr body) {
   using tir::SelectNode;
   if (predicates.size() == 0) return body;
   const ReduceNode* reduce = body.as<ReduceNode>();
+  auto fand = [](PrimExpr a, PrimExpr b) { return a && b; };
+
   if (reduce) {
     auto n = make_object<ReduceNode>(*reduce);
-    n->condition = n->condition && arith::ComputeReduce<tir::AndNode>(predicates, PrimExpr());
+    n->condition = foldl(fand, n->condition, predicates);
     return PrimExpr(n);
   }
-  return SelectNode::make(arith::ComputeReduce<tir::AndNode>(predicates, PrimExpr()), body,
-                          make_zero(body.dtype()));
+  return SelectNode::make(foldl(fand, const_true(1), predicates), body, make_zero(body.dtype()));
 }
 
 // Replace data flow appears in all stages given the tensor change.
@@ -707,7 +708,9 @@ Array<Tensor> Schedule::rfactor(const Tensor& tensor, const IterVar& axis, int f
   const ReduceNode* reduce = compute_op->body[idx].as<ReduceNode>();
   CHECK(reduce) << "Can only rfactor non-inline reductions";
   predicates.push_back(reduce->condition);
-  PrimExpr predicate = likely(arith::ComputeReduce<tir::AndNode>(predicates, PrimExpr()));
+  auto fand = [](PrimExpr a, PrimExpr b) { return a && b; };
+
+  PrimExpr predicate = likely(foldl(fand, const_true(1), predicates));
 
   std::unordered_map<const VarNode*, PrimExpr> vsub;
 
index 84166d1..46fc91b 100644 (file)
@@ -36,7 +36,6 @@
 
 #include <unordered_map>
 
-#include "../../arith/compute_expr.h"
 #include "../../runtime/thread_storage_scope.h"
 
 namespace tvm {
index 8b98ed9..3a60521 100644 (file)
 #include <tvm/tir/analysis.h>
 #include <tvm/tir/buffer.h>
 #include <tvm/tir/expr.h>
+#include <tvm/tir/op.h>
 
 #include <iterator>
 #include <stack>
 
-#include "../../arith/compute_expr.h"
-
 namespace tvm {
 namespace tir {
 
@@ -367,7 +366,8 @@ PrimExpr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lane
     int highest_dim = 0;
     extent = self->strides[highest_dim] * self->shape[highest_dim] - offset;
   } else {
-    extent = arith::ComputeReduce<tir::MulNode>(self->shape, PrimExpr()) - offset;
+    auto fmul = [](PrimExpr a, PrimExpr b) { return a * b; };
+    extent = foldl(fmul, make_const(DataType::Int(32), 1), self->shape) - offset;
   }
   PrimExpr elem_offset = self->elem_offset + offset;
   if (content_lanes > 1) {
index 01a6996..384d459 100644 (file)
@@ -25,8 +25,8 @@
 
 #include <tvm/runtime/device_api.h>
 #include <tvm/tir/expr.h>
+#include <tvm/tir/op.h>
 
-#include "../../arith/compute_expr.h"
 #include "ir_util.h"
 
 namespace tvm {
@@ -225,8 +225,9 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type,
                    << " expected to be compact array";
     if (conds.size() != 0) {
       auto stride_msg = tvm::tir::StringImmNode::make(stride_err_msg.str());
-      Stmt check = AssertStmtNode::make(arith::ComputeReduce<tir::AndNode>(conds, PrimExpr()),
-                                        stride_msg, EvaluateNode::make(0));
+      auto fand = [](PrimExpr a, PrimExpr b) { return a && b; };
+      Stmt check = AssertStmtNode::make(foldl(fand, const_true(1), conds), stride_msg,
+                                        EvaluateNode::make(0));
       check = IfThenElseNode::make(NotNode::make(is_null), check, Stmt());
       asserts_.emplace_back(SeqStmt({check, EvaluateNode::make(0)}));
     }
index ae5e673..c405b1f 100644 (file)
@@ -26,7 +26,6 @@
 #include <tvm/tir/stmt_functor.h>
 #include <tvm/tir/transform.h>
 
-#include "../../arith/compute_expr.h"
 #include "ir_util.h"
 
 namespace tvm {
@@ -115,8 +114,9 @@ class DoubleBufferInjector : public StmtExprMutator {
   Stmt VisitStmt_(const AllocateNode* op) final {
     auto it = dbuffer_info_.find(op->buffer_var.get());
     if (it != dbuffer_info_.end()) {
+      auto fmul = [](PrimExpr a, PrimExpr b) { return a * b; };
       it->second.stride =
-          arith::ComputeReduce<MulNode>(op->extents, PrimExpr()) * op->dtype.lanes();
+          foldl(fmul, make_const(DataType::Int(32), 1), op->extents) * op->dtype.lanes();
       Stmt stmt = StmtExprMutator::VisitStmt_(op);
       op = stmt.as<AllocateNode>();
       Array<PrimExpr> new_extents{make_const(op->extents[0].dtype(), 2)};
index 834a7e9..e2a027d 100644 (file)
@@ -27,7 +27,6 @@
 
 #include <unordered_set>
 
-#include "../../arith/compute_expr.h"
 #include "ir_util.h"
 
 namespace tvm {
@@ -368,7 +367,9 @@ class VTInjector : public StmtExprMutator {
     // always rewrite if not allow sharing.
     if (touched_var_.count(op->buffer_var.get()) || !allow_share_) {
       // place v on highest dimension.
-      PrimExpr stride = arith::ComputeReduce<MulNode>(op->extents, PrimExpr()) * op->dtype.lanes();
+      auto fmul = [](PrimExpr a, PrimExpr b) { return a * b; };
+      PrimExpr stride =
+          foldl(fmul, make_const(DataType::Int(32), 1), op->extents) * op->dtype.lanes();
       Array<PrimExpr> other;
       other.push_back(make_const(op->extents[0].dtype(), num_threads_));
       for (PrimExpr e : extents) {
index de86647..7f9a329 100644 (file)
@@ -30,7 +30,6 @@
 
 #include <unordered_set>
 
-#include "../../arith/compute_expr.h"
 #include "../../runtime/thread_storage_scope.h"
 #include "ir_util.h"
 
index 91879b6..fb86bc2 100644 (file)
 #include <tvm/target/target.h>
 #include <tvm/tir/analysis.h>
 #include <tvm/tir/expr.h>
+#include <tvm/tir/op.h>
 #include <tvm/tir/stmt_functor.h>
 #include <tvm/tir/transform.h>
 
 #include <unordered_set>
 
-#include "../../arith/compute_expr.h"
 #include "../../arith/pattern_match.h"
 #include "../../runtime/thread_storage_scope.h"
 
index cd749b9..3a42137 100644 (file)
 #include "storage_access.h"
 
 #include <tvm/target/target_info.h>
+#include <tvm/tir/op.h>
 
 #include <string>
 #include <utility>
 
-#include "../../arith/compute_expr.h"
 #include "ir_util.h"
 
 namespace tvm {
index 646e008..1e656ce 100644 (file)
@@ -37,7 +37,6 @@
 
 #include <unordered_map>
 
-#include "../../arith/compute_expr.h"
 #include "../../arith/ir_visitor_with_analyzer.h"
 #include "../../runtime/thread_storage_scope.h"
 #include "arg_binder.h"
index fc86f2b..365ff75 100644 (file)
@@ -34,7 +34,6 @@
 #include <unordered_map>
 #include <unordered_set>
 
-#include "../../arith/compute_expr.h"
 #include "../../runtime/thread_storage_scope.h"
 #include "ir_util.h"
 
@@ -555,10 +554,12 @@ class StoragePlanRewriter : public StmtExprMutator {
             alloc_type = op->dtype;
           }
         }
+
+        auto fmul = [](PrimExpr a, PrimExpr b) { return a * b; };
+
         if (e->allocs.size() == 1) {
           // simply use the original allocation.
-          PrimExpr sz = arith::ComputeReduce<MulNode>(e->allocs[0]->extents,
-                                                      make_const(DataType::Int(32), 1));
+          PrimExpr sz = foldl(fmul, make_const(DataType::Int(32), 1), e->allocs[0]->extents);
           e->new_alloc = AllocateNode::make(e->alloc_var, alloc_type, {sz}, e->allocs[0]->condition,
                                             EvaluateNode::make(0));
           if (e->scope.tag.length() != 0) {
@@ -571,8 +572,7 @@ class StoragePlanRewriter : public StmtExprMutator {
           // Build a merged allocation
           PrimExpr combo_size;
           for (const AllocateNode* op : e->allocs) {
-            PrimExpr sz =
-                arith::ComputeReduce<MulNode>(op->extents, make_const(DataType::Int(32), 1));
+            PrimExpr sz = foldl(fmul, make_const(DataType::Int(32), 1), op->extents);
             auto nbits = op->dtype.bits() * op->dtype.lanes();
             if (const auto* imm = sz.as<IntImmNode>()) {
               if (imm->value > std::numeric_limits<int>::max() / nbits) {
index 4ccfbc3..fd1a92a 100644 (file)
@@ -33,7 +33,6 @@
 #include <unordered_set>
 #include <vector>
 
-#include "../../arith/compute_expr.h"
 #include "ir_util.h"
 
 namespace tvm {
index 9e553cb..91993ac 100644 (file)
@@ -24,6 +24,7 @@
 #include <tvm/arith/analyzer.h>
 #include <tvm/runtime/registry.h>
 #include <tvm/tir/expr.h>
+#include <tvm/tir/op.h>
 #include <tvm/tir/stmt_functor.h>
 #include <tvm/tir/transform.h>
 
@@ -31,8 +32,6 @@
 #include <unordered_set>
 #include <vector>
 
-#include "../../arith/compute_expr.h"
-
 namespace tvm {
 namespace tir {
 
@@ -109,8 +108,14 @@ class Vectorizer : public StmtExprMutator {
     }
   }
 
-  PrimExpr VisitExpr_(const AddNode* op) final { return AddSubVec(op); }
-  PrimExpr VisitExpr_(const SubNode* op) final { return AddSubVec(op); }
+  PrimExpr VisitExpr_(const AddNode* op) final {
+    return AddSubVec(op, [](PrimExpr a, PrimExpr b) { return a + b; });
+  }
+
+  PrimExpr VisitExpr_(const SubNode* op) final {
+    return AddSubVec(op, [](PrimExpr a, PrimExpr b) { return a - b; });
+  }
+
   PrimExpr VisitExpr_(const MulNode* op) final {
     PrimExpr a = this->VisitExpr(op->a);
     PrimExpr b = this->VisitExpr(op->b);
@@ -423,8 +428,8 @@ class Vectorizer : public StmtExprMutator {
       return T::make(BroadcastTo(a, lanes), BroadcastTo(b, lanes));
     }
   }
-  template <typename T>
-  PrimExpr AddSubVec(const T* op) {
+  template <typename T, typename FCompute>
+  PrimExpr AddSubVec(const T* op, FCompute fcompute) {
     PrimExpr a = this->VisitExpr(op->a);
     PrimExpr b = this->VisitExpr(op->b);
     if (a.same_as(op->a) && b.same_as(op->b)) {
@@ -435,12 +440,12 @@ class Vectorizer : public StmtExprMutator {
         const RampNode* b_ramp = b.as<RampNode>();
         const RampNode* a_ramp = a.as<RampNode>();
         if (a.dtype().lanes() == 1 && b_ramp) {
-          return RampNode::make(
-              arith::Compute<T>(a, b_ramp->base),
-              arith::Compute<T>(make_zero(b_ramp->stride.dtype()), b_ramp->stride), b_ramp->lanes);
+          return RampNode::make(fcompute(a, b_ramp->base),
+                                fcompute(make_zero(b_ramp->stride.dtype()), b_ramp->stride),
+                                b_ramp->lanes);
         }
         if (b.dtype().lanes() == 1 && a_ramp) {
-          return RampNode::make(arith::Compute<T>(a_ramp->base, b), a_ramp->stride, a_ramp->lanes);
+          return RampNode::make(fcompute(a_ramp->base, b), a_ramp->stride, a_ramp->lanes);
         }
       }
       return T::make(BroadcastTo(a, lanes), BroadcastTo(b, lanes));