Add `SkipVectorize` pass (#3222)
authorLogan Weber <36520469+weberlo@users.noreply.github.com>
Tue, 21 May 2019 23:34:35 +0000 (16:34 -0700)
committerTianqi Chen <tqchen@users.noreply.github.com>
Tue, 21 May 2019 23:34:35 +0000 (16:34 -0700)
docs/api/python/dev.rst
include/tvm/build_module.h
include/tvm/ir_pass.h
python/tvm/build_module.py
src/codegen/build_module.cc
src/pass/vectorize_loop.cc

index e4b207bf4cbc52f912fc35c12c2e0c4bf4cb3d13..7bb938ca751732b87c6b20d4a21855e61474e2dc 100644 (file)
@@ -61,6 +61,7 @@ tvm.ir_pass
    tvm.ir_pass.CanonicalSimplify
    tvm.ir_pass.StorageFlatten
    tvm.ir_pass.VectorizeLoop
+   tvm.ir_pass.SkipVectorize
    tvm.ir_pass.UnrollLoop
    tvm.ir_pass.ThreadSync
    tvm.ir_pass.StorageRewrite
index 208f086f86c0bfc079080842ff6973477e27f66f..7fb456c823a7df74db0b92f9a78dbcde9113836c 100644 (file)
@@ -246,6 +246,9 @@ class BuildConfigNode : public Node {
   /*! \brief Whether to disable select rewriting. */
   bool disable_select_rewriting = false;
 
+  /*! \brief Whether to disable loop vectorization. */
+  bool disable_vectorize = false;
+
   void VisitAttrs(AttrVisitor* v) final {
     v->Visit("data_alignment", &data_alignment);
     v->Visit("offset_factor", &offset_factor);
@@ -260,6 +263,7 @@ class BuildConfigNode : public Node {
     v->Visit("dump_pass_ir", &dump_pass_ir);
     v->Visit("instrument_bound_checkers", &instrument_bound_checkers);
     v->Visit("disable_select_rewriting", &disable_select_rewriting);
+    v->Visit("disable_vectorize", &disable_vectorize);
   }
 
   static constexpr const char* _type_key = "BuildConfig";
index 5ef4dc4ed9d704565ab04ad50b0e117e7639c57f..e1c92e50e6ad15c173eff8088773dd33dadc5d78 100644 (file)
@@ -250,35 +250,42 @@ Stmt UnrollLoop(Stmt stmt,
 
 /*!
  * \brief vectorize the constant loops
- * \param stmt The statment to be vectorized.
+ * \param stmt The statement to be vectorized.
  * \return Transformed stmt.
  */
 Stmt VectorizeLoop(Stmt stmt);
 
+/*!
+ * \brief convert vectorized loops into serialized loops
+ * \param stmt The statement to skip vectorization on.
+ * \return Transformed stmt.
+ */
+Stmt SkipVectorize(Stmt stmt);
+
 /*!
 * \brief instruments bound checkers.
-* \param stmt The statment to be instrumented.
-* \return Instrumented Stmt.
+* \param stmt The statement to be instrumented.
+* \return Instrumented stmt.
 */
 Stmt InstrumentBoundCheckers(Stmt stmt);
 
 /*!
  * \brief Inject virtual thread loops into stmt.
- * \param stmt The statment to be transformed.
+ * \param stmt The statement to be transformed.
  * \return Transformed stmt.
  */
 Stmt InjectVirtualThread(Stmt stmt);
 
 /*!
  * \brief Inject prefetch instructions into stmt.
- * \param stmt The statment to be transformed.
+ * \param stmt The statement to be transformed.
  * \return Transformed stmt.
  */
 Stmt InjectPrefetch(Stmt stmt);
 
 /*!
  * \brief Inject double buffer into stmt.
- * \param stmt The statment to be transformed.
+ * \param stmt The statement to be transformed.
  * \param split_loop Loop splitting factor.
  * \return Transformed stmt.
  */
@@ -287,7 +294,7 @@ Stmt InjectDoubleBuffer(Stmt stmt, int split_loop);
 /*!
  * \brief Inject copy intrinsics with optional pad.
  *
- * \param stmt The statment to be transformed.
+ * \param stmt The statement to be transformed.
  * \param pragma_key The pragma key for hint of copy.
  * \param fintrin The function with signature
  *
@@ -308,7 +315,7 @@ Stmt InjectCopyIntrin(Stmt stmt,
  *  Trying to share space between allocations to make
  *  a static allocation plan when possible.
  *
- * \param stmt The stmt to be trasnformed
+ * \param stmt The stmt to be transformed
  * \return Transformed stmt.
  */
 Stmt StorageRewrite(Stmt stmt);
@@ -324,7 +331,7 @@ Stmt LoopPartition(Stmt stmt, bool split_const_loop);
 /*!
  * \brief Detect and insert sync points to co-processor.
  *
- * \param stmt The stmt to be trasnformed
+ * \param stmt The stmt to be transformed
  * \return Transformed stmt.
  */
 Stmt CoProcSync(Stmt stmt);
@@ -332,7 +339,7 @@ Stmt CoProcSync(Stmt stmt);
 /*!
  * \brief Lift common attrs with attr_key to outer scope.
  *
- * \param stmt The stmt to be trasnformed
+ * \param stmt The stmt to be transformed
  * \param attr_key The attribute key to be checked.
  * \return Transformed stmt.
  */
@@ -340,7 +347,7 @@ Stmt LiftAttrScope(Stmt stmt, std::string attr_key);
 
 /*!
  * \brief Detect and rewrite unsafe select that contains memory access.
- * \param stmt The statment to be rewritten.
+ * \param stmt The statement to be rewritten.
  * \return Transformed stmt.
  */
 Stmt RewriteUnsafeSelect(Stmt stmt);
@@ -349,7 +356,7 @@ Stmt RewriteUnsafeSelect(Stmt stmt);
  * \brief Lower attached storage access information.
  * Do this pass after all storage access analysis finish.
  *
- * \param stmt The stmt to be trasnformed
+ * \param stmt The stmt to be transformed
  * \return Transformed stmt.
  */
 Stmt LowerStorageAccessInfo(Stmt stmt);
@@ -358,7 +365,7 @@ Stmt LowerStorageAccessInfo(Stmt stmt);
  * \brief Decorate the stmt with a device scope, this is helpful for
  * hardware accelerator without thread blocks.
  *
- * \param stmt The stmt to be trasnformed
+ * \param stmt The stmt to be transformed
  * \return Transformed stmt.
  */
 Stmt DecorateDeviceScope(Stmt stmt);
@@ -381,7 +388,7 @@ Stmt DecorateDeviceScope(Stmt stmt);
  * \return a LoweredFunc with the specified signiture.
  *
  * \note
- *  The function signiture have two cases
+ *  The function signature have two cases
  *
  *  let num_packed_args = len(api_args) - num_unpacked_args;
  *
index 120bf629a9592bdd8a1b82b5431de0df9eac40bf..a28ab98fb60e7964ea902253f10d6be561cd34ec 100644 (file)
@@ -143,7 +143,8 @@ class BuildConfig(NodeBase):
         "double_buffer_split_loop": 1,
         "dump_pass_ir": False,
         "instrument_bound_checkers": False,
-        "disable_select_rewriting": False
+        "disable_select_rewriting": False,
+        "disable_vectorize": False
     }
     _dump_ir = DumpIR()
 
@@ -384,7 +385,10 @@ def lower(sch,
     # Phase 2
     if not simple_mode:
         stmt = ir_pass.LoopPartition(stmt, cfg.partition_const_loop)
-    stmt = ir_pass.VectorizeLoop(stmt)
+    if cfg.disable_vectorize:
+        stmt = ir_pass.SkipVectorize(stmt)
+    else:
+        stmt = ir_pass.VectorizeLoop(stmt)
     stmt = ir_pass.InjectVirtualThread(stmt)
     stmt = ir_pass.InjectDoubleBuffer(stmt, cfg.double_buffer_split_loop)
     stmt = ir_pass.StorageRewrite(stmt)
index 9b30ced90c4fffacdb01ee01cf2a9c9bfd2572d6..ac6b797d9683660a2a295c6665e8cc758ed8f058 100644 (file)
@@ -392,7 +392,11 @@ Stmt BuildStmt(Schedule sch,
   if (loop_partition) {
     stmt = ir::LoopPartition(stmt, config->partition_const_loop);
   }
-  stmt = ir::VectorizeLoop(stmt);
+  if (config->disable_vectorize) {
+    stmt = ir::SkipVectorize(stmt);
+  } else {
+    stmt = ir::VectorizeLoop(stmt);
+  }
   stmt = ir::InjectVirtualThread(stmt);
   stmt = ir::InjectDoubleBuffer(stmt, config->double_buffer_split_loop);
   stmt = ir::StorageRewrite(stmt);
@@ -642,6 +646,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
   p->stream << "dump_pass_ir=" << op->dump_pass_ir << ", ";
   p->stream << "instrument_bound_checkers=" << op->instrument_bound_checkers << ", ";
   p->stream << "disable_select_rewriting=" << op->disable_select_rewriting;
+  p->stream << "disable_vectorize=" << op->disable_vectorize;
   p->stream << ")";
 });
 
index f87e80c2d0303e625059b03854b84ca237098fdf..8c3d383c1529af22d66eba0898b36fbbf85237d8 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -519,5 +519,23 @@ Stmt VectorizeLoop(Stmt stmt) {
   return LoopVectorizer().Mutate(stmt);
 }
 
+class VectorizeSkipper : public IRMutator {
+ public:
+  Stmt Mutate_(const For* op, const Stmt& s) final {
+    Stmt stmt = IRMutator::Mutate_(op, s);
+    op = stmt.as<For>();
+    if (op->for_type == ForType::Vectorized) {
+      return For::make(op->loop_var, op->min, op->extent, ForType::Serial, op->device_api,
+                       op->body);
+    } else {
+       return stmt;
+    }
+  }
+};
+
+Stmt SkipVectorize(Stmt stmt) {
+  return VectorizeSkipper().Mutate(stmt);
+}
+
 }  // namespace ir
 }  // namespace tvm