[TIR][Bugfix] Improved massive build times caused by tir.floormod and tir.floordiv...
authordprankratz <65860457+dprankratz@users.noreply.github.com>
Tue, 28 Jul 2020 16:06:06 +0000 (10:06 -0600)
committerGitHub <noreply@github.com>
Tue, 28 Jul 2020 16:06:06 +0000 (09:06 -0700)
* Improved uncommon case of floormod and floordiv. Removed dependence on np floor_div and fmod.

* Fixed clang-format complaints

* Streamlined floormod and floordiv lowering logic

* Improved build times by expressing int64 case of tir FloorMod and FloorDiv using let nodes

* Updated use-def analysis and llvm codegen to support duplicated letnodes.

* Corrected misuse of var_map_ in llvm codegen

* Updated backends that support LetNode

* Changed floormod and div lowering logic to avoid using FP on systems that don't support it.

* Fixed formatting

Co-authored-by: pankratz <pankratz@ualberta.ca>
src/target/llvm/codegen_llvm.cc
src/target/llvm/codegen_llvm.h
src/target/source/codegen_c.cc
src/target/source/codegen_c.h
src/target/spirv/codegen_spirv.cc
src/target/spirv/codegen_spirv.h
src/tir/transforms/lower_intrin.cc
src/tir/transforms/split_host_device.cc
topi/tests/python/test_topi_broadcast.py

index cb28b81..225d225 100644 (file)
@@ -1018,7 +1018,13 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const SelectNode* op) {
 }
 
 llvm::Value* CodeGenLLVM::VisitExpr_(const LetNode* op) {
-  CHECK(!var_map_.count(op->var.get()));
+  auto it = let_binding_.find(op->var);
+  if (it != let_binding_.end()) {
+    CHECK(deep_equal_(it->second->value, op->value))
+        << "Let cannot bind the same var to two different values";
+  } else {
+    let_binding_[op->var] = op;
+  }
   var_map_[op->var.get()] = MakeValue(op->value);
   analyzer_->Bind(op->var, op->value);
   return MakeValue(op->body);
index e20c8e1..ce5baba 100644 (file)
@@ -29,6 +29,7 @@
 #include <tvm/ir/module.h>
 #include <tvm/runtime/container.h>
 #include <tvm/target/codegen.h>
+#include <tvm/tir/analysis.h>
 #include <tvm/tir/expr.h>
 #include <tvm/tir/function.h>
 #include <tvm/tir/op.h>
@@ -321,6 +322,10 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const PrimExpr&)>,
   std::unordered_set<const VarNode*> alias_var_set_;
   // set of volatile buffer.
   std::unordered_set<const VarNode*> volatile_buf_;
+  // deep comparison of PrimExpr
+  ExprDeepEqual deep_equal_;
+  // binding of let variables. Enables duplicate var defs that map to same value
+  std::unordered_map<Var, const LetNode*, ObjectPtrHash, ObjectPtrEqual> let_binding_;
   // Cache potential common path ops to slightly improve lookup time.
   // global symbol table.
   OpAttrMap<TGlobalSymbol> op_attr_global_symbol_ = Op::GetAttrMap<TGlobalSymbol>("TGlobalSymbol");
index 1530892..3e6838c 100644 (file)
@@ -761,8 +761,14 @@ void CodeGenC::VisitStmt_(const StoreNode* op) {
 }
 
 void CodeGenC::VisitExpr_(const LetNode* op, std::ostream& os) {  // NOLINT(*)
+  auto it = let_binding_.find(op->var);
+  if (it != let_binding_.end()) {
+    CHECK(deep_equal_(it->second->value, op->value))
+        << "Let cannot bind the same var to two different values";
+  } else {
+    let_binding_[op->var] = op;
+  }
   std::string value = PrintExpr(op->value);
-  CHECK(!var_idmap_.count(op->var.get()));
   var_idmap_[op->var.get()] = value;
   os << PrintExpr(op->body);
 }
index 87a4a29..c1b566c 100644 (file)
@@ -27,6 +27,7 @@
 #include <tvm/ir/op.h>
 #include <tvm/runtime/container.h>
 #include <tvm/target/codegen.h>
+#include <tvm/tir/analysis.h>
 #include <tvm/tir/builtin.h>
 #include <tvm/tir/expr.h>
 #include <tvm/tir/function.h>
@@ -269,6 +270,10 @@ class CodeGenC : public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
   bool print_ssa_form_{false};
   /*! \brief set of volatile buf access */
   std::unordered_set<const VarNode*> volatile_buf_;
+  // deep comparison of PrimExpr
+  ExprDeepEqual deep_equal_;
+  // binding of let variables. Enables duplicate var defs that map to same value
+  std::unordered_map<Var, const LetNode*, ObjectPtrHash, ObjectPtrEqual> let_binding_;
 };
 
 }  // namespace codegen
index 7ff0c55..2a67d95 100644 (file)
@@ -230,7 +230,13 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const SelectNode* op) {
 }
 
 spirv::Value CodeGenSPIRV::VisitExpr_(const LetNode* op) {
-  CHECK(!var_map_.count(op->var.get()));
+  auto it = let_binding_.find(op->var);
+  if (it != let_binding_.end()) {
+    CHECK(deep_equal_(it->second->value, op->value))
+        << "Let cannot bind the same var to two different values";
+  } else {
+    let_binding_[op->var] = op;
+  }
   var_map_[op->var.get()] = MakeValue(op->value);
   analyzer_->Bind(op->var, op->value);
   return MakeValue(op->body);
index a8af29a..9bf8109 100644 (file)
@@ -25,6 +25,7 @@
 #define TVM_TARGET_SPIRV_CODEGEN_SPIRV_H_
 
 #include <tvm/arith/analyzer.h>
+#include <tvm/tir/analysis.h>
 #include <tvm/tir/expr.h>
 #include <tvm/tir/function.h>
 #include <tvm/tir/stmt_functor.h>
@@ -140,6 +141,10 @@ class CodeGenSPIRV : public ExprFunctor<spirv::Value(const PrimExpr&)>,
   std::unordered_map<const VarNode*, spirv::Value> var_map_;
   // The analyzer.
   std::unique_ptr<arith::Analyzer> analyzer_;
+  // deep comparison of PrimExpr
+  ExprDeepEqual deep_equal_;
+  // binding of let variables. Enables duplicate var defs that map to same value
+  std::unordered_map<Var, const LetNode*, ObjectPtrHash, ObjectPtrEqual> let_binding_;
 };
 
 }  // namespace codegen
index 1c529d8..f3fe945 100644 (file)
@@ -112,14 +112,22 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
         }
       }
     } else {
-      // uncommon case
-      DLOG(INFO) << "LowerFloorDiv: Cannot decide the sign of divisor";
-      // b >= 0 => (rmod >=0 ? rdiv : rdiv - 1)
-      // b < 0  => (rmod <= 0 ? rdiv : rdiv - 1)
-      PrimExpr rdiv = truncdiv(op->a, op->b);
-      PrimExpr rmod = truncmod(op->a, op->b);
-      return tir::Select((op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0), rdiv,
-                         rdiv - make_const(dtype, 1));
+      if (dtype.is_float()) {
+        // floor(a / b)
+        return VisitExpr_(tvm::floor(op->a / op->b).as<CallNode>());
+      } else {
+        // uncommon case
+        DLOG(INFO) << "LowerFloorDiv: Cannot decide the sign of divisor";
+        auto rmod = tir::Var("rmod", dtype);
+        auto rdiv = tir::Var("rdiv", dtype);
+        // b >= 0 => (rmod >=0 ? rdiv : rdiv - 1)
+        // b < 0  => (rmod <= 0 ? rdiv : rdiv - 1)
+        PrimExpr let_rdiv =
+            tir::Let(rdiv, truncdiv(op->a, op->b),
+                     tir::Select((op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0), rdiv,
+                                 rdiv - make_const(dtype, 1)));
+        return Let(rmod, truncmod(op->a, op->b), let_rdiv);
+      }
     }
   }
 
@@ -158,14 +166,21 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
         }
       }
     } else {
-      // uncommon case
-      DLOG(INFO) << "LowerFloorMod: Cannot decide the sign of divsor and divident";
-      PrimExpr rmod = truncmod(op->a, op->b);
-      // b > 0 && rmod >= 0 -> rmod
-      // b > 0 && rmod < 0  -> rmod + b
-      // b < 0 && rmod < 0 -> rmod
-      // b < 0 && rmod > 0 -> rmod + b
-      return tir::Select((op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0), rmod, rmod + op->b);
+      if (dtype.is_float()) {
+        // a - floor(a / b) * b
+        return op->a - (VisitExpr_(tvm::floor(op->a / op->b).as<CallNode>()) * op->b);
+      } else {
+        // uncommon case
+        DLOG(INFO) << "LowerFloorMod: Cannot decide the sign of divsor and divident";
+        auto rmod = tir::Var("rmod", dtype);
+        // b > 0 && rmod >= 0 -> rmod
+        // b > 0 && rmod < 0  -> rmod + b
+        // b < 0 && rmod < 0 -> rmod
+        // b < 0 && rmod > 0 -> rmod + b
+        return Let(
+            rmod, truncmod(op->a, op->b),
+            Select((op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0), rmod, rmod + op->b));
+      }
     }
   }
 
index 169ac14..d5b51cb 100644 (file)
@@ -99,14 +99,28 @@ class VarUseDefAnalysis : public StmtExprMutator {
   }
 
   PrimExpr VisitExpr_(const LetNode* op) final {
-    this->HandleDef(op->var.get());
+    // Weaker SSA condition
+    // A single var can be binded in multiple lets
+    // but they have to bind to the same value.
+    // This is used to allow cases when we reuse a single let
+    // expression to construct a nested expr.
+    // (let x = 1 in x + 1) * (let x = 1 in x + 1)
+    auto it = let_binding_.find(op->var);
+    PrimExpr value = this->VisitExpr(op->value);
+    if (it != let_binding_.end()) {
+      CHECK(deep_equal_(it->second->value, value))
+          << "Let cannot bind the same var to two different values";
+      return GetRef<PrimExpr>(it->second);
+    } else {
+      this->HandleDef(op->var.get());
+      let_binding_[op->var] = op;
+    }
     PrimExpr body = this->VisitExpr(op->body);
     // eliminate unreferenced let
     if (use_count_.at(op->var.get()) == 0 && SideEffect(op->value) <= CallEffectKind::kReadState &&
         simplify_let_) {
       return body;
     } else {
-      PrimExpr value = this->VisitExpr(op->value);
       if (body.same_as(op->body) && value.same_as(op->value)) {
         return GetRef<PrimExpr>(op);
       } else {
@@ -157,6 +171,10 @@ class VarUseDefAnalysis : public StmtExprMutator {
   Array<PrimExpr> thread_extent_;
   std::unordered_map<const VarNode*, int> use_count_;
   std::unordered_map<const VarNode*, int> def_count_;
+
+ private:
+  ExprDeepEqual deep_equal_;
+  std::unordered_map<Var, const LetNode*, ObjectPtrHash, ObjectPtrEqual> let_binding_;
 };
 
 Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& args) {
index 27b66e0..f3e0300 100644 (file)
@@ -90,19 +90,6 @@ def verify_broadcast_binary_ele(lhs_shape, rhs_shape,
         rhs_npy, rhs_nd = gen_operand(rhs_shape, rhs_min, rhs_max, ctx)
         out_npy = fnumpy(lhs_npy, rhs_npy)
 
-        if fnumpy == np.floor_divide:
-            # avoid check too close to X.5 and X.0
-            # FIXME: floor_divide(94.90735, 0.6731018) behaves as floor(div(94.90735, 0.6731018))
-            # However the result is somehow incorrect - need to further investigate.
-            # And looks like numpy's floor_div(a,b) is implemented different from floor(div(a,b))
-            mask = np.logical_or(np.abs(np.abs(np.fmod(lhs_npy / rhs_npy, 1)) - 0.5) < 1e-6,
-                                 np.abs(np.fmod(lhs_npy / rhs_npy, 1)) < 1e-6)
-            if mask.any():
-                lhs_npy = lhs_npy + mask * 1e-3  * rhs_npy
-                lhs_npy = lhs_npy.astype(dtype)
-                lhs_nd = tvm.nd.array(lhs_npy, ctx) if lhs_shape is not None else lhs_npy.item()
-                out_npy = fnumpy(lhs_npy, rhs_npy)
-
         out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(C.dtype), ctx)
         foo(lhs_nd, rhs_nd, out_nd)
         tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy, rtol=1E-4, atol=1E-4)
@@ -151,12 +138,14 @@ def test_divide():
         (2, 3, 1, 32), (64, 32), topi.divide, np.divide, rhs_min=0.0001)
 
 def test_floor_divide():
+    def _canonical_floor_div(a,b):
+        return np.floor(a / b)
     verify_broadcast_binary_ele(
-        None, (10,), topi.floor_divide, np.floor_divide, rhs_min=0.0001)
+        None, (10,), topi.floor_divide, _canonical_floor_div, rhs_min=0.0001)
     verify_broadcast_binary_ele(
-        (), None, topi.floor_divide, np.floor_divide, rhs_min=0.0001)
+        (), None, topi.floor_divide, _canonical_floor_div, rhs_min=0.0001)
     verify_broadcast_binary_ele(
-        (2, 3, 64, 32), (64, 32), topi.floor_divide, np.floor_divide, rhs_min=0.0001)
+        (2, 3, 64, 32), (64, 32), topi.floor_divide, _canonical_floor_div, rhs_min=0.0001)
 
 def test_maximum_minmum():
     verify_broadcast_binary_ele(
@@ -175,10 +164,12 @@ def test_mod():
         (1, 2, 2), (2,), topi.mod, np.mod, lhs_min=0.001, rhs_min=1, dtype="int32")
 
 def test_floor_mod():
+    def _canonical_floor_mod(a,b):
+        return a - np.floor(a / b) * b
     verify_broadcast_binary_ele(
-        (1, 2, 2), (2,), topi.floor_mod, np.fmod, lhs_min=0.001, rhs_min=1, dtype="int32")
+        (1, 2, 2), (2,), topi.floor_mod, _canonical_floor_mod, lhs_min=0.001, rhs_min=1, dtype="int32")
     verify_broadcast_binary_ele(
-        (3, 4, 5), (3, 4, 5), topi.floor_mod, np.fmod, lhs_min=0.001, rhs_min=1, dtype="float32")
+        (3, 4, 5), (3, 4, 5), topi.floor_mod, _canonical_floor_mod, lhs_min=0.001, rhs_min=1, dtype="float32")
 
 def test_cmp():
     # explicit specify the output type