[ARITH] Refactor: Remove un-necessary usage of ComputeExpr (#3503)
authorTianqi Chen <tqchen@users.noreply.github.com>
Sat, 6 Jul 2019 21:44:46 +0000 (14:44 -0700)
committerGitHub <noreply@github.com>
Sat, 6 Jul 2019 21:44:46 +0000 (14:44 -0700)
16 files changed:
src/arithmetic/compute_expr.h
src/arithmetic/detect_linear_equation.cc
src/codegen/codegen_cuda.cc
src/codegen/llvm/codegen_llvm.cc
src/codegen/spirv/codegen_spirv.cc
src/lang/buffer.cc
src/pass/arg_binder.cc
src/pass/inject_double_buffer.cc
src/pass/inject_virtual_thread.cc
src/pass/lower_thread_allreduce.cc
src/pass/lower_warp_memory.cc
src/pass/make_api.cc
src/pass/storage_flatten.cc
src/pass/unroll_loop.cc
src/pass/vectorize_loop.cc
src/schedule/message_passing.cc

index cc54bff..4fa5fe9 100644 (file)
  */
 
 /*!
- *  Copyright (c) 2017 by Contributors
  * \file compute_expr.h
- * \brief Utility integer expression with quick eager simplification.
- *  This is weaker than Simplify but can be done Eagerly.
+ * \brief Utility to invoke certan compute operations.
  */
 #ifndef TVM_ARITHMETIC_COMPUTE_EXPR_H_
 #define TVM_ARITHMETIC_COMPUTE_EXPR_H_
@@ -41,7 +39,7 @@ namespace arith {
  * \return The result.
  */
 template<typename OP>
-inline Expr ComputeExpr(Expr lhs, Expr rhs) {
+inline Expr Compute(Expr lhs, Expr rhs) {
   return OP::make(lhs, rhs);
 }
 
@@ -79,37 +77,37 @@ inline bool GetConstInt(Expr e, int* out) {
 }
 
 template<>
-inline Expr ComputeExpr<ir::Add>(Expr a, Expr b) {
+inline Expr Compute<ir::Add>(Expr a, Expr b) {
   return a + b;
 }
 
 template<>
-inline Expr ComputeExpr<ir::Sub>(Expr a, Expr b) {
+inline Expr Compute<ir::Sub>(Expr a, Expr b) {
   return a - b;
 }
 
 template<>
-inline Expr ComputeExpr<ir::Mul>(Expr a, Expr b) {
+inline Expr Compute<ir::Mul>(Expr a, Expr b) {
   return a * b;
 }
 
 template<>
-inline Expr ComputeExpr<ir::Div>(Expr a, Expr b) {
+inline Expr Compute<ir::Div>(Expr a, Expr b) {
   return a / b;
 }
 
 template<>
-inline Expr ComputeExpr<ir::Mod>(Expr a, Expr b) {
+inline Expr Compute<ir::Mod>(Expr a, Expr b) {
   return a % b;
 }
 
 template<>
-inline Expr ComputeExpr<ir::Max>(Expr a, Expr b) {
+inline Expr Compute<ir::Max>(Expr a, Expr b) {
   return max(a, b);
 }
 
 template<>
-inline Expr ComputeExpr<ir::Min>(Expr a, Expr b) {
+inline Expr Compute<ir::Min>(Expr a, Expr b) {
   return min(a, b);
 }
 
@@ -121,7 +119,7 @@ inline Expr ComputeReduce(const Array<Expr>& values, Expr empty_value) {
   }
   Expr res = values[0];
   for (size_t i = 1; i < values.size(); ++i) {
-    res = ComputeExpr<Op>(res, values[i]);
+    res = Compute<Op>(res, values[i]);
   }
   return res;
 }
index e584c8b..3c5f12a 100644 (file)
@@ -27,7 +27,6 @@
 #include <tvm/ir_visitor.h>
 #include <tvm/ir_functor_ext.h>
 #include <tvm/arithmetic.h>
-#include "compute_expr.h"
 
 namespace tvm {
 namespace arith {
@@ -127,18 +126,18 @@ class LinearEqDetector
   Expr AddCombine(Expr a, Expr b) {
     if (!a.defined()) return b;
     if (!b.defined()) return a;
-    return ComputeExpr<Add>(a, b);
+    return a + b;
   }
   Expr SubCombine(Expr a, Expr b) {
     // Check b first in case they are both undefined
     if (!b.defined()) return a;
     if (!a.defined()) return -b;
-    return ComputeExpr<Sub>(a, b);
+    return a - b;
   }
   Expr MulCombine(Expr a, Expr b) {
     if (!a.defined()) return a;
     if (!b.defined()) return b;
-    return ComputeExpr<Mul>(a, b);
+    return a * b;
   }
 };
 
index 22dde1c..a324731 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -27,7 +27,6 @@
 #include <vector>
 #include <string>
 #include "codegen_cuda.h"
-#include "../arithmetic/compute_expr.h"
 
 namespace tvm {
 namespace codegen {
index 1e56583..fde0486 100644 (file)
@@ -748,9 +748,7 @@ void CodeGenLLVM::Scalarize(const Expr& e,
                             std::function<void(int i, llvm::Value* v)> f) {
   if (const Ramp* ramp = e.as<Ramp>()) {
     for (int i = 0; i < ramp->type.lanes(); ++i) {
-      Expr offset = arith::ComputeExpr<Add>(
-          ramp->base,
-          arith::ComputeExpr<Mul>(ramp->stride, i));
+      Expr offset = ramp->base + (ramp->stride * i);
       f(i, MakeValue(offset));
     }
   } else {
index fd113ca..7686250 100644 (file)
@@ -25,8 +25,8 @@
 #include <tvm/ir.h>
 #include <tvm/ir_pass.h>
 #include <string>
-#include "../../arithmetic/compute_expr.h"
 #include "codegen_spirv.h"
+#include "../../arithmetic/compute_expr.h"
 
 namespace tvm {
 namespace codegen {
@@ -339,7 +339,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const Ramp* op) {
     spirv::Value v = base;
     if (i != 0) {
       spirv::Value offset = MakeValue(
-          arith::ComputeExpr<Mul>(make_const(op->stride.type(), i), op->stride));
+          make_const(op->stride.type(), i) * op->stride);
       v = builder_->Add(v, offset);
     }
     values.push_back(v);
@@ -419,9 +419,7 @@ void CodeGenSPIRV::Scalarize(const Expr& e,
                              std::function<void(int i, spirv::Value v)> f) {
   if (const Ramp* ramp = e.as<Ramp>()) {
     for (int i = 0; i < ramp->type.lanes(); ++i) {
-      Expr offset = arith::ComputeExpr<Add>(
-          ramp->base,
-          arith::ComputeExpr<Mul>(ramp->stride, i));
+      Expr offset = ramp->base + ramp->stride * i;
       f(i, MakeValue(offset));
     }
   } else {
index 573ecff..cb5c867 100644 (file)
@@ -378,8 +378,7 @@ Expr Buffer::access_ptr(int access_mask, Type ptr_type, int content_lanes, Expr
     extent = make_const(self->DefaultIndexType(), 1);
   } else if (self->strides.size() == self->shape.size()) {
     int highest_dim = 0;
-    extent = arith::ComputeExpr<ir::Mul>(
-        self->strides[highest_dim], self->shape[highest_dim]) - offset;
+    extent = self->strides[highest_dim] * self->shape[highest_dim] - offset;
   } else {
     extent = arith::ComputeReduce<ir::Mul>(self->shape, Expr()) - offset;
   }
index d93d088..ff4c77a 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -18,7 +18,6 @@
  */
 
 /*!
- *  Copyright (c) 2017 by Contributors
  * \file arg_binder.cc
  * \brief Helper utility to match and bind arguments.
  */
index 94b4ab3..027639c 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -26,6 +26,7 @@
 #include <tvm/ir_pass.h>
 #include <tvm/ir_visitor.h>
 #include <tvm/ir_mutator.h>
+#include <tvm/expr_operator.h>
 #include "ir_util.h"
 #include "../arithmetic/compute_expr.h"
 
@@ -100,8 +101,8 @@ class DoubleBufferInjector : public IRMutator {
   Stmt Mutate_(const Allocate* op, const Stmt& s) final {
     auto it = dbuffer_info_.find(op->buffer_var.get());
     if (it != dbuffer_info_.end()) {
-      it->second.stride = arith::ComputeReduce<Mul>
-          (op->extents, Expr()) * op->type.lanes();
+      it->second.stride = arith::ComputeReduce<Mul>(
+          op->extents, Expr()) * op->type.lanes();
       Stmt stmt = IRMutator::Mutate_(op, s);
       op = stmt.as<Allocate>();
       Array<Expr> new_extents{make_const(op->extents[0].type(), 2)};
@@ -135,11 +136,11 @@ class DoubleBufferInjector : public IRMutator {
             << "It is better to split with multiple of 2";
         CHECK(is_zero(old_loop->min));
         Expr zero = old_loop->min;
-        Expr new_ext = arith::ComputeExpr<Sub>(
-            old_loop->extent, make_const(old_loop->loop_var.type(), 1));
+        Expr new_ext =
+            old_loop->extent - make_const(old_loop->loop_var.type(), 1);
         Expr factor = make_const(new_ext.type(), split_loop_);
-        Expr outer_ext = arith::ComputeExpr<Div>(new_ext, factor);
-        Expr tail_base = arith::ComputeExpr<Mul>(outer_ext, factor);
+        Expr outer_ext = new_ext / factor;
+        Expr tail_base = outer_ext * factor;
         Var outer_var(old_loop->loop_var->name_hint + ".outer", old_loop->loop_var.type());
         std::unordered_map<const Variable*, Expr> vmap;
         std::vector<Stmt> loop_seq;
index 9009416..88e7f43 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -18,7 +18,6 @@
  */
 
 /*!
- *  Copyright (c) 2017 by Contributors
  * \file inject_virtual_thread.cc
  */
 #include <tvm/ir.h>
@@ -37,6 +36,7 @@ class ExprTouched final : public IRVisitor {
   explicit ExprTouched(const std::unordered_set<const Variable*> &touched,
                        bool check_write)
       : touched_var_(touched), check_write_(check_write) {}
+
   void Visit(const NodeRef& n) final {
     // early stopping
     if (expr_touched_ && !check_write_) return;
@@ -241,8 +241,8 @@ class VTInjector : public IRMutator {
       visit_touched_var_ = true;
       Expr offset = Mutate(op->args[2]);
       Expr extent = Mutate(op->args[3]);
-      Expr stride = arith::ComputeExpr<Div>(
-          it->second, make_const(offset.type(), dtype.lanes()));
+      Expr stride =
+          it->second / make_const(offset.type(), dtype.lanes());
       offset = stride * var_ + offset;
       return Call::make(
           op->type, op->name,
index d0490b2..02c72d0 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
index 7d9d486..bb7260f 100644 (file)
@@ -18,8 +18,6 @@
  */
 
 /*!
- *  Copyright (c) 2018 by Contributors
- *
  * Lower warp memory to use local memory
  * and shuffle intrinsics.
  *
index 13f46ec..0109ad1 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -33,7 +33,6 @@
 
 #include "ir_util.h"
 #include "arg_binder.h"
-#include "../arithmetic/compute_expr.h"
 
 namespace tvm {
 namespace ir {
index ff6b416..19e7a32 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -211,7 +211,7 @@ class StorageFlattener : public IRMutator {
             stride = ir::Simplify(stride);
           }
           rstrides.push_back(stride);
-          stride = arith::ComputeExpr<Mul>(stride, shape[dim]);
+          stride = stride * shape[dim];
         }
         strides = Array<Expr>(rstrides.rbegin(), rstrides.rend());
       }
@@ -237,7 +237,7 @@ class StorageFlattener : public IRMutator {
         int first_dim = 0;
         ret = Allocate::make(
             e.buffer->data, storage_type,
-            {arith::ComputeExpr<Mul>(e.buffer->strides[first_dim], e.buffer->shape[first_dim])},
+            {e.buffer->strides[first_dim] * e.buffer->shape[first_dim]},
             make_const(Bool(e.buffer->dtype.lanes()), true), body);
       } else {
         shape = e.buffer->shape;
@@ -414,8 +414,7 @@ class StorageFlattener : public IRMutator {
     if (be.bounds.size() != 0) {
       CHECK_EQ(tuple->args.size(), be.bounds.size() * 2);
       for (size_t i = 0; i < be.buffer->shape.size(); ++i) {
-        begins.push_back(
-            arith::ComputeExpr<Sub>(tuple->args[2 * i], be.bounds[i]->min));
+        begins.push_back(tuple->args[2 * i] - be.bounds[i]->min);
         extents.push_back(tuple->args[2 * i + 1]);
       }
     } else {
index ead234e..7561308 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -18,7 +18,6 @@
  */
 
 /*!
- *  Copyright (c) 2017 by Contributors
  *  Loop unrolling as in Halide pipeline.
  * \file unroll_loop.cc
  */
@@ -144,7 +143,6 @@ class LoopUnroller : public IRMutator {
   }
 
   Stmt Unroll(const For* op) {
-    using arith::ComputeExpr;
     int value = GetExtent(op);
     // For loop must have a constant integer extent
     CHECK_NE(value, -1) << "loop doesn't have a constant integer extent";
@@ -154,9 +152,7 @@ class LoopUnroller : public IRMutator {
     Stmt unrolled;
     for (int i = 0; i < value; ++i) {
       Var lv(op->loop_var.node_);
-      vmap.Set(lv,
-               ComputeExpr<Add>(
-                       op->min, make_const(op->loop_var.type(), i)));
+      vmap.Set(lv, op->min + make_const(op->loop_var.type(), i));
       Stmt step = Substitute(body, vmap);
       if (unrolled.defined()) {
         unrolled = Block::make(unrolled, step);
index a48e8b4..2d8416e 100644 (file)
@@ -18,7 +18,6 @@
  */
 
 /*!
- *  Copyright (c) 2017 by Contributors
  * \file vectorize_loop.cc
  */
 // Loop vectorizer as in Halide pipeline.
@@ -486,13 +485,13 @@ class Vectorizer : public IRMutator {
         const Ramp* a_ramp = a.as<Ramp>();
         if (a.type().lanes() == 1 && b_ramp) {
           return Ramp::make(
-              arith::ComputeExpr<T>(a, b_ramp->base),
-              arith::ComputeExpr<T>(make_zero(b_ramp->stride.type()), b_ramp->stride),
+              arith::Compute<T>(a, b_ramp->base),
+              arith::Compute<T>(make_zero(b_ramp->stride.type()), b_ramp->stride),
               b_ramp->lanes);
         }
         if (b.type().lanes() == 1 && a_ramp) {
           return Ramp::make(
-              arith::ComputeExpr<T>(a_ramp->base, b), a_ramp->stride, a_ramp->lanes);
+              arith::Compute<T>(a_ramp->base, b), a_ramp->stride, a_ramp->lanes);
         }
       }
       return T::make(BroadcastTo(a, lanes), BroadcastTo(b, lanes));
index 0dc82ab..12c5db7 100644 (file)
@@ -18,7 +18,6 @@
  */
 
 /*!
- *  Copyright (c) 2017 by Contributors
  * \file message_passing.cc
  * \brief The message passing domain.
  */
@@ -32,12 +31,11 @@ namespace tvm {
 namespace schedule {
 
 using namespace ir;
-using namespace arith;
 
 void Update(std::unordered_map<IterVar, Range>* p_state,
             const IterVar& iv,
             Range r,
-            Analyzer* analyzer) {
+            arith::Analyzer* analyzer) {
   auto it = p_state->find(iv);
   if (it == p_state->end()) {
     (*p_state)[iv] = r;
@@ -145,8 +143,8 @@ void PassUpIndex(const Stage& stage,
       Expr factor = dom_map.at(s->inner)->extent;
       Expr outer_min = dom_map.at(s->outer)->min;
       Expr inner_min = dom_map.at(s->inner)->min;
-      state[s->outer] = ComputeExpr<Div>(value, factor);
-      state[s->inner] = ComputeExpr<Mod>(value, factor);
+      state[s->outer] = value / factor;
+      state[s->inner] = value % factor;
       // add min if they exist
       if (!is_zero(outer_min)) {
         state[s->outer] = state[s->outer] + outer_min;
@@ -189,8 +187,8 @@ void PassDownIndex(const Stage& stage,
       CHECK(is_zero(r->min));
       Expr parent = state.at(s->parent);
       Expr factor = r->extent;
-      state[s->outer] = ComputeExpr<Div>(parent, factor);
-      state[s->inner] = ComputeExpr<Mod>(parent, factor);
+      state[s->outer] = parent / factor;
+      state[s->inner] = parent % factor;
     } else if (const FuseNode* s = rel.as<FuseNode>()) {
       if (!state.count(s->inner) && !state.count(s->outer)) {
         CHECK(allow_missing);
@@ -240,7 +238,7 @@ void PassUpDomain(const SplitNode* s,
   CHECK(outer.defined());
   CHECK(inner.defined());
   CHECK(factor.defined());
-  *parent = EvalSet(
+  *parent = arith::EvalSet(
       s->outer->var * factor + s->inner->var + parent_min,
       {{s->outer, outer}, {s->inner, inner}});
 }
@@ -290,8 +288,8 @@ void PassUpDomain(const RebaseNode* s,
     return;
   }
   Expr parent_min = dom_map.at(s->parent)->min;
-  *parent = EvalSet(s->rebased->var + parent_min,
-                    {{s->rebased, rebased}});
+  *parent = arith::EvalSet(s->rebased->var + parent_min,
+                           {{s->rebased, rebased}});
 }
 
 void PassUpDomain(const Stage& stage,
@@ -476,7 +474,7 @@ std::vector<Expr> MakeBoundCheck(
     const std::unordered_map<IterVar, Expr>& value_map,
     bool skip_ivar_domain,
     const std::unordered_set<IterVar>& skip_iter) {
-  Analyzer analyzer;
+  arith::Analyzer analyzer;
 
   std::unordered_map<IterVar, bool> bound_state;
   for (IterVar iv : stage->leaf_iter_vars) {
@@ -496,7 +494,7 @@ std::vector<Expr> MakeBoundCheck(
     if (skip_iter.count(iv) || iv->iter_type == kOpaque) continue;
     if (bound_state.at(iv)) {
       Range dom = dom_map.at(iv);
-      Expr value = ComputeExpr<Sub>(value_map.at(iv), dom->min);
+      Expr value = value_map.at(iv) - dom->min;
       Expr vmax = EvalSet(value, iset_dmap).max();
       if (vmax.type() != value.type() || !analyzer.CanProve(vmax < dom->extent)) {
         preds.emplace_back(value < dom->extent);
@@ -508,7 +506,7 @@ std::vector<Expr> MakeBoundCheck(
     Range dom = dom_map.at(iv);
     CHECK(iv->dom.defined());
     if (!skip_ivar_domain && !iv->dom.same_as(dom)) {
-      Expr value = ComputeExpr<Sub>(value_map.at(iv), iv->dom->min);
+      Expr value = value_map.at(iv) - iv->dom->min;
       IntSet s = EvalSet(value, iset_dmap);
       Expr vmin = s.min();
       Expr vmax = s.max();