[CodeGen] Add build config option disable_assert to control whether to generate asser...
authorZhao Wu <wuzhaozju@gmail.com>
Fri, 15 Nov 2019 18:05:26 +0000 (02:05 +0800)
committerTianqi Chen <tqchen@users.noreply.github.com>
Fri, 15 Nov 2019 18:05:26 +0000 (10:05 -0800)
include/tvm/build_module.h
include/tvm/ir_pass.h
python/tvm/build_module.py
src/codegen/build_module.cc
src/codegen/codegen.cc
src/pass/skip_assert.cc [new file with mode: 0644]

index 7114a45..a83288c 100644 (file)
@@ -229,6 +229,9 @@ class BuildConfigNode : public Node {
   /*! \brief Whether to disable loop vectorization. */
   bool disable_vectorize = false;
 
+  /*! \brief Whether to disable assert stmt generation. */
+  bool disable_assert = false;
+
   void VisitAttrs(AttrVisitor* v) {
     v->Visit("data_alignment", &data_alignment);
     v->Visit("offset_factor", &offset_factor);
@@ -244,6 +247,7 @@ class BuildConfigNode : public Node {
     v->Visit("instrument_bound_checkers", &instrument_bound_checkers);
     v->Visit("disable_select_rewriting", &disable_select_rewriting);
     v->Visit("disable_vectorize", &disable_vectorize);
+    v->Visit("disable_assert", &disable_assert);
   }
 
   static constexpr const char* _type_key = "BuildConfig";
index 76d7d61..5c5c4bb 100644 (file)
@@ -564,6 +564,13 @@ LoweredFunc LowerCustomDatatypes(LoweredFunc f, const std::string& target);
 LoweredFunc InferFragment(LoweredFunc f);
 
 /*!
+ * \brief skip assert stmt generation
+ * \param f The function to be transformed.
+ * \return Transformed function.
+ */
+LoweredFunc SkipAssert(LoweredFunc f);
+
+/*!
  * \brief Verify if memory accesses are legal for a specific target device type.
  *
  *  In the case that tgt is cuda, if not all workload is bound with
index 217318e..f96e283 100644 (file)
@@ -144,7 +144,8 @@ class BuildConfig(NodeBase):
         "dump_pass_ir": False,
         "instrument_bound_checkers": False,
         "disable_select_rewriting": False,
-        "disable_vectorize": False
+        "disable_vectorize": False,
+        "disable_assert": False
     }
     _dump_ir = DumpIR()
 
index 3f279f8..ac991d4 100644 (file)
@@ -672,6 +672,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
   p->stream << "instrument_bound_checkers=" << op->instrument_bound_checkers << ", ";
   p->stream << "disable_select_rewriting=" << op->disable_select_rewriting;
   p->stream << "disable_vectorize=" << op->disable_vectorize;
+  p->stream << "disable_assert=" << op->disable_assert;
   p->stream << ")";
 });
 
index ed9484b..4ea37ba 100644 (file)
@@ -26,6 +26,7 @@
 #include <tvm/ir_pass.h>
 #include <tvm/runtime/registry.h>
 #include <tvm/runtime/module.h>
+#include <tvm/build_module.h>
 #include <dmlc/memory_io.h>
 #include <sstream>
 #include <iostream>
@@ -40,12 +41,21 @@ runtime::Module Build(const Array<LoweredFunc>& funcs,
   if (pos != std::string::npos) {
     mode = mode.substr(0, pos);
   }
+  Array<LoweredFunc> transformed_funcs;
+  for (const auto& x : funcs) {
+    if (BuildConfig::Current()->disable_assert) {
+      auto func = ir::SkipAssert(x);
+      transformed_funcs.push_back(func);
+    }
+  }
   std::string build_f_name = "codegen.build_" + mode;
   // the build function.
   const PackedFunc* bf = runtime::Registry::Get(build_f_name);
   CHECK(bf != nullptr)
       << "Target " << target << " is not enabled";
-  runtime::Module m = (*bf)(funcs, target);
+  runtime::Module m = transformed_funcs.empty() ?
+                      (*bf)(funcs, target) :
+                      (*bf)(transformed_funcs, target);
   return m;
 }
 
diff --git a/src/pass/skip_assert.cc b/src/pass/skip_assert.cc
new file mode 100644 (file)
index 0000000..5f310a6
--- /dev/null
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <tvm/ir.h>
+#include <tvm/ir_pass.h>
+#include <tvm/ir_mutator.h>
+
+namespace tvm {
+namespace ir {
+
+class AssertSkipper : public IRMutator {
+ public:
+  Stmt Mutate_(const AssertStmt* op, const Stmt& s) final {
+    Stmt stmt = IRMutator::Mutate_(op, s);
+    op = stmt.as<AssertStmt>();
+    return op->body;
+  }
+};
+
+Stmt SkipAssert(Stmt stmt) {
+  return AssertSkipper().Mutate(stmt);
+}
+
+LoweredFunc SkipAssert(LoweredFunc f) {
+  auto n = make_node<LoweredFuncNode>(*f.operator->());
+  n->body = SkipAssert(f->body);
+  return LoweredFunc(n);
+}
+
+}  // namespace ir
+}  // namespace tvm