[Ansor][AutoTVM v2.0] Phase 1: Access Analyzer (#6103)
authorLianmin Zheng <lianminzheng@gmail.com>
Sat, 25 Jul 2020 12:07:17 +0000 (05:07 -0700)
committerGitHub <noreply@github.com>
Sat, 25 Jul 2020 12:07:17 +0000 (05:07 -0700)
* add access analyzer

* add test cases

* move header files and polish comments

* fix lint

* update

* fix lint

* address comments

* fix lint

24 files changed:
include/tvm/auto_scheduler/auto_schedule.h [moved from src/auto_scheduler/auto_schedule.h with 81% similarity]
include/tvm/auto_scheduler/compute_dag.h [new file with mode: 0644]
include/tvm/auto_scheduler/loop_state.h [moved from src/auto_scheduler/loop_state.h with 91% similarity]
include/tvm/auto_scheduler/measure.h [moved from src/auto_scheduler/measure.h with 93% similarity]
include/tvm/auto_scheduler/measure_record.h [moved from src/auto_scheduler/measure_record.h with 83% similarity]
include/tvm/auto_scheduler/search_policy.h [moved from src/auto_scheduler/search_policy/search_policy.h with 79% similarity]
include/tvm/auto_scheduler/search_task.h [moved from src/auto_scheduler/search_task.h with 97% similarity]
include/tvm/auto_scheduler/transform_step.h [moved from src/auto_scheduler/transform_step.h with 98% similarity]
python/tvm/auto_scheduler/auto_schedule.py
python/tvm/auto_scheduler/workload_registry.py
src/auto_scheduler/auto_schedule.cc
src/auto_scheduler/compute_dag.cc
src/auto_scheduler/compute_dag.h [deleted file]
src/auto_scheduler/loop_state.cc
src/auto_scheduler/measure.cc
src/auto_scheduler/measure_record.cc
src/auto_scheduler/search_policy/empty_policy.cc
src/auto_scheduler/search_policy/empty_policy.h
src/auto_scheduler/search_policy/search_policy.cc
src/auto_scheduler/search_task.cc
src/auto_scheduler/transform_step.cc
src/auto_scheduler/utils.h
tests/cpp/auto_scheduler_test.cc [new file with mode: 0644]
tests/python/unittest/test_auto_scheduler_compute_dag.py

similarity index 81%
rename from src/auto_scheduler/auto_schedule.h
rename to include/tvm/auto_scheduler/auto_schedule.h
index 55c6992..8477966 100644 (file)
  */
 
 /*!
- * \file auto_scheduler/auto_schedule.h
- * \brief The user interface of the TVM Auto-scheduler. This is the entry structure to get
- * schedule search requirements from upper level (Python API), and returns a high performance
- * schedule after search process.
+ * \file tvm/auto_scheduler/auto_schedule.h
+ * \brief The user interface of the auto scheduler.
  */
 
 #ifndef TVM_AUTO_SCHEDULER_AUTO_SCHEDULE_H_
 #define TVM_AUTO_SCHEDULER_AUTO_SCHEDULE_H_
 
-#include <utility>
+#include <tvm/auto_scheduler/measure.h>
+#include <tvm/auto_scheduler/search_policy.h>
 
-#include "measure.h"
-#include "search_policy/search_policy.h"
+#include <utility>
 
 namespace tvm {
 namespace auto_scheduler {
@@ -38,9 +36,9 @@ namespace auto_scheduler {
 /*! \brief Tuning and measurement options. */
 class TuningOptionsNode : public Object {
  public:
-  /*! \brief Number of total measurement trials. */
+  /*! \brief The number of total measurement trials. */
   int num_measure_trials;
-  /*! \brief Stops early the tuning if no improvement after n measurements. */
+  /*! \brief Stops the tuning early if no improvement after n measurements. */
   int early_stopping;
   /*! \brief The number of programs to be measured at each search round. */
   int num_measures_per_round;
@@ -51,7 +49,7 @@ class TuningOptionsNode : public Object {
   int verbose;
   /*! \brief ProgramBuilder which builds the program */
   ProgramBuilder builder;
-  /*! \brief ProgramRunner which runs the program and measure time costs */
+  /*! \brief ProgramRunner which runs the program and measures time costs */
   ProgramRunner runner;
   /*! \brief MeasureCallback functions to be called after each measure batch */
   Optional<Array<MeasureCallback>> measure_callbacks;
@@ -81,8 +79,8 @@ class TuningOptions : public ObjectRef {
  public:
   /*!
    * \brief The constructor
-   * \param num_measure_trials Number of total measurement trials.
-   * \param early_stopping Stops early the tuning if no improvement after n measurements.
+   * \param num_measure_trials The number of total measurement trials.
+   * \param early_stopping Stops the tuning early if no improvement after n measurements.
    * \param num_measures_per_round The number of programs to be measured at each search round.
    * \param verbose Verbosity level. 0 for silent, 1 to output information during schedule
    * search.
@@ -100,11 +98,11 @@ class TuningOptions : public ObjectRef {
 };
 
 /*!
- * \brief Auto schedule search for a given compute declaration.
+ * \brief Run schedule search for a given compute declaration.
  * \param task The search task of the compute declaration.
- * \param search_policy The search policy to be used for schedule search.
+ * \param search_policy The search policy to be used.
  * \param tuning_options Tuning and measurement options.
- * \return A `te::schedule` and the a Array of `te::Tensor` to be used in `tvm.lower` or
+ * \return A `te::schedule` and the an Array of `te::Tensor` to be used in `tvm.lower` or
  * `tvm.build`.
  */
 TVM_DLL std::pair<te::Schedule, Array<te::Tensor>> AutoSchedule(SearchTask task,
diff --git a/include/tvm/auto_scheduler/compute_dag.h b/include/tvm/auto_scheduler/compute_dag.h
new file mode 100644 (file)
index 0000000..71652fd
--- /dev/null
@@ -0,0 +1,248 @@
+/*r
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/auto_scheduler/compute_dag.h
+ * \brief The auto-scheduler's computational graph and related program analyses.
+ *
+ * We convert a compute declaration described by `tvm.compute` (could be a single operator or a
+ * subgraph) to a ComputeDAG. It keeps the input/output tensors of the compute declaration,
+ * a list of all operations in the DAG as well as static analysis results for the DAG (e.g. the
+ * total float operation count, consumer/producer relations of each operation stage, whether an
+ * operation stage should be tiled/compute inlined ...). These analyses can help the search policy
+ * to make decisions during search process.
+ * ComputeDAG is also responsible for the interaction between TVM Auto-scheduler `LoopState` and
+ * TVM schedule (e.g. applying the `LoopState` transform steps to TVM schedule, providing
+ * `LoopState` with extra information got from TVM schedule ...).
+ */
+
+#ifndef TVM_AUTO_SCHEDULER_COMPUTE_DAG_H_
+#define TVM_AUTO_SCHEDULER_COMPUTE_DAG_H_
+
+#include <tvm/auto_scheduler/loop_state.h>
+#include <tvm/runtime/c_runtime_api.h>
+#include <tvm/te/schedule.h>
+
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace auto_scheduler {
+
+/*! \brief Static analysis result for a ComputeDAG */
+class AccessAnalyzerNode : public Object {
+ public:
+  template <class T>
+  using OperationMap = std::unordered_map<te::Operation, T, ObjectPtrHash, ObjectPtrEqual>;
+
+  /*! \brief Map an operation to all operations it reads from.
+   * For each operation pair, use a two-dimentional array to multiple multi-dimentional accesses
+   * The inner vector represents the indices of multi-dimensional access.*/
+  OperationMap<OperationMap<std::vector<std::vector<PrimExpr>>>> read_from;
+  /*! \brief Map an operation to all operations it is read by.
+   * For each operation pair, use a two-dimentional array to multiple multi-dimentional accesses
+   * The inner vector represents the indices of multi-dimensional access.*/
+  OperationMap<OperationMap<std::vector<std::vector<PrimExpr>>>> read_by;
+  /*! \brief Store the number of common outer iterators for operation pairs that have
+   * read-write relations. */
+  OperationMap<OperationMap<int>> num_common_outer_iterators;
+  /*! \brief Store whether the operation is an op with only simple access.
+   *  (e.g., injective, broadcast and elementwise ops without reduction) */
+  OperationMap<bool> is_simple_access;
+  /*! \brief Store whether the operation is strictly-inlineable
+   * (e.g., injective, broadcast and elementwise without reduction, branch or expenive operations)
+   */
+  OperationMap<bool> is_strict_inlineable;
+  /*! \brief Store whether the operation needs multi-level tiling
+   * (e.g., computation-intensive ops with data reuse opportunity like matmul, conv2d) */
+  OperationMap<bool> needs_multi_level_tiling;
+  /*! \brief Store whether the operation is an output operation */
+  OperationMap<bool> is_output;
+  /*! \brief Store the topological order of operations */
+  Array<te::Operation> ops_topo_order;
+
+  static constexpr const char* _type_key = "auto_scheduler.AccessAnalyzer";
+  TVM_DECLARE_FINAL_OBJECT_INFO(AccessAnalyzerNode, Object);
+};
+
+/*!
+ * \brief Managed reference to AccessAnalyzerNode.
+ * \sa AccessAnalyzerNode
+ */
+class AccessAnalyzer : public ObjectRef {
+ public:
+  explicit AccessAnalyzer(const Array<te::Tensor>& tensors);
+
+  /*!
+   * \brief Return whether this operation is an injective operation
+   * (e.g., injective, broadcast and elementwise ops without reduction)
+   * \param op The operation
+   */
+  TVM_DLL bool IsSimpleAccess(const te::Operation& op) const;
+
+  /*!
+   * \brief Return whether this operation is strictly inlinable
+   * (e.g., injective, broadcast and elementwise without reduction, branch or expenive operations)
+   * \param op The operation
+   */
+  TVM_DLL bool IsStrictInlineable(const te::Operation& op) const;
+
+  /*!
+   * \brief Return whether this operation needs multi-level tiling
+   * (e.g., computation-intensive ops with data reuse opportunity like matmul, conv2d)
+   * \param op The operation
+   */
+  TVM_DLL bool NeedsMultiLevelTiling(const te::Operation& op) const;
+
+  /*!
+   * \brief Return whether this operation is an output op
+   * \param op The operation
+   */
+  TVM_DLL bool IsOutput(const te::Operation& op) const;
+
+  /*!
+   * \brief Get all consumers of on operation
+   * \param state The current loop state
+   * \param op The operation
+   * \return The set of consumers
+   * \note This function propagates the relation for inlined ops
+   */
+  TVM_DLL std::unordered_set<te::Operation, ObjectHash, ObjectEqual> GetConsumers(
+      const State& state, const te::Operation& op) const;
+
+  /*!
+   * \brief Get all producers of on operation
+   * \param state The current loop state
+   * \param op The operation
+   * \return The set of producers
+   * \note This function propagates the relation for inlined ops
+   */
+  TVM_DLL std::unordered_set<te::Operation, ObjectHash, ObjectEqual> GetProducers(
+      const State& state, const te::Operation& op) const;
+
+  /*!
+   * \brief Get all direct producers of on operation
+   * \param op The operation
+   * \return The set of direct producers
+   * \note This function DOES NOT propagate the relation for inlined ops
+   */
+  TVM_DLL std::unordered_set<te::Operation, ObjectHash, ObjectEqual> GetDirectProducers(
+      const te::Operation& op) const;
+
+  /*!
+   * \brief Get the number of common outer iterators.
+   * \param op The operation
+   * \param target_op The target operation
+   * \note This function propagates the relation for chains with multiple ops.
+   */
+  TVM_DLL int GetNumCommonOuterIterator(const te::Operation& op,
+                                        const te::Operation& target_op) const;
+
+  /*!
+   * \brief Return whether two operations are elementwise-matched
+   *  (e.g. conv2d and relu are elementwise matched)
+   * \note This function propagates the relation for chains with multiple ops.
+   */
+  TVM_DLL bool ElementWiseMatch(const te::Operation& op, const te::Operation& target_op) const;
+
+  TVM_DEFINE_OBJECT_REF_METHODS(AccessAnalyzer, ObjectRef, AccessAnalyzerNode);
+};
+
+/*! \brief The TVM Auto-scheduler computational graph and related program analyses. */
+class ComputeDAGNode : public Object {
+ public:
+  /*!
+   * \brief Input and output tensors.
+   * This is used as the input of `tvm.lower` or `tvm.build`.
+   */
+  Array<te::Tensor> tensors;
+  /*! \brief All related operations in topo order. */
+  Array<te::Operation> ops;
+  /*! \brief The number of total float operations for this ComputeDAG. */
+  double flop_ct;
+  /*! \brief The initial state without any transform steps. */
+  State init_state;
+  /*! \brief The static read-write access analyzer */
+  AccessAnalyzer access_analyzer;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("tensors", &tensors);
+    v->Visit("ops", &ops);
+    v->Visit("flop_ct", &flop_ct);
+    v->Visit("init_state", &init_state);
+  }
+
+  static constexpr const char* _type_key = "auto_scheduler.ComputeDAG";
+  TVM_DECLARE_FINAL_OBJECT_INFO(ComputeDAGNode, Object);
+};
+
+/*!
+ * \brief Managed reference to ComputeDAGNode.
+ * \sa ComputeDAGNode
+ */
+class ComputeDAG : public ObjectRef {
+ public:
+  /*! \brief The constructor.
+   * \param tensors `te::Tensor`s for a compute declaration.
+   */
+  TVM_DLL explicit ComputeDAG(Array<te::Tensor> tensors);
+
+  /*!
+   * \brief Apply the history transform steps to get a TVM schedule.
+   * \param transform_steps Transform steps of a state.
+   * \param stages The list of stages after applying the steps.
+   * Pass a valid pointer if this information needs to be used outside this function.
+   * \param stage_to_axes The map that stores all axes for one stage.
+   * Pass a valid pointer if this information needs to be used outside this function.
+   * \return A `te.schedule` and the an Array of `te.Tensor` to be used in `tvm.lower`
+   * or `tvm.build`.
+   */
+  std::pair<te::Schedule, Array<te::Tensor>> ApplySteps(
+      const Array<Step>& transform_steps, Array<te::Stage>* stages = nullptr,
+      StageToAxesMap* stage_to_axes = nullptr) const;
+
+  /*!
+   * \brief Print transform steps as equivalent python schedule API.
+   * This can be used for debugging.
+   * \param transform_steps Transform steps of a state.
+   * \return The Python schedule code.
+   */
+  String PrintStepsAsPython(const Array<Step>& transform_steps) const;
+
+  /*!
+   * \brief Fill the correct bound information for a given state by calling ir_pass::InferBound.
+   * The states can lose complete bound information after some transform steps (e.g., compute_at).
+   * We can call this function to infer and fill all the bound information.
+   * This function calls TVM InferBound pass internally to get the bound.
+   * The returned state of this function is guaranteed to have complete bound information.
+   * \param state The input state.
+   * \return The State with complete bound information
+   */
+  State InferBound(const State& state) const;
+
+  TVM_DEFINE_OBJECT_REF_METHODS(ComputeDAG, ObjectRef, ComputeDAGNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(ComputeDAGNode);
+};
+
+}  // namespace auto_scheduler
+}  // namespace tvm
+
+#endif  // TVM_AUTO_SCHEDULER_COMPUTE_DAG_H_
similarity index 91%
rename from src/auto_scheduler/loop_state.h
rename to include/tvm/auto_scheduler/loop_state.h
index 4d6477b..4e9cb9b 100644 (file)
@@ -48,6 +48,8 @@
 #ifndef TVM_AUTO_SCHEDULER_LOOP_STATE_H_
 #define TVM_AUTO_SCHEDULER_LOOP_STATE_H_
 
+#include <dmlc/common.h>
+#include <tvm/auto_scheduler/transform_step.h>
 #include <tvm/runtime/container.h>
 
 #include <functional>
@@ -55,8 +57,6 @@
 #include <utility>
 #include <vector>
 
-#include "transform_step.h"
-
 namespace tvm {
 namespace auto_scheduler {
 
@@ -159,10 +159,16 @@ using IterKey = std::pair<int, int>;
  */
 class AttachMapNode : public Object {
  public:
+  struct IterKeyHash {
+    std::size_t operator()(const IterKey& k) const {
+      return ::dmlc::HashCombine(std::hash<int>()(k.first), std::hash<int>()(k.second));
+    }
+  };
+
   /*! \brief A Map to store the mapping of stage to its attached iterator. */
   std::unordered_map<StageKey, IterKey> stage_to_attach_iter;
   /*! \brief A Map to store the mapping of iterator to the stage attached to it. */
-  std::unordered_map<IterKey, std::vector<StageKey>> iter_to_attached_stages;
+  std::unordered_map<IterKey, std::vector<StageKey>, IterKeyHash> iter_to_attached_stages;
 
   static constexpr const char* _type_key = "auto_scheduler.AttachMap";
   TVM_DECLARE_FINAL_OBJECT_INFO(AttachMapNode, Object);
@@ -291,14 +297,14 @@ class State : public ObjectRef {
    * this input.
    * \return The iterator result after binded.
    */
-  Iterator bind(int stage_id, const Iterator& it, IteratorAnnotation thread_type);
+  TVM_DLL Iterator bind(int stage_id, const Iterator& it, IteratorAnnotation thread_type);
   /*!
    * \brief Schedule primitive corresponds to te.parallel.
    * \param stage_id The index of the stage to be paralleled.
    * \param it The iterator to be paralleled.
    * \return The iterator result after parallel.
    */
-  Iterator parallel(int stage_id, const Iterator& it);
+  TVM_DLL Iterator parallel(int stage_id, const Iterator& it);
   /*!
    * \brief Schedule primitive corresponds to te.unroll.
    * \param stage_id The index of the stage to be unrolled.
@@ -307,14 +313,14 @@ class State : public ObjectRef {
    * skipped.
    * \return The iterator result after unrolled.
    */
-  Iterator unroll(int stage_id, const Iterator& it, int max_unroll = -1);
+  TVM_DLL Iterator unroll(int stage_id, const Iterator& it, int max_unroll = -1);
   /*!
    * \brief Schedule primitive corresponds to te.vectorize.
    * \param stage_id The index of the stage to be vectorized.
    * \param it The iterator to be vectorized.
    * \return The iterator result after vectorize.
    */
-  Iterator vectorize(int stage_id, const Iterator& it);
+  TVM_DLL Iterator vectorize(int stage_id, const Iterator& it);
   /*!
    * \brief Schedule primitive corresponds to te.fuse.
    * \param stage_id The index of the stage to be fused.
@@ -323,13 +329,13 @@ class State : public ObjectRef {
    * \note If the iterators to be fused have stages attached at them(by compute_at), the fused
    * result will become the new attach point.
    */
-  Iterator fuse(int stage_id, const Array<Iterator>& iters);
+  TVM_DLL Iterator fuse(int stage_id, const Array<Iterator>& iters);
   /*!
    * \brief Schedule primitive corresponds to te.reorder.
    * \param stage_id The index of the stage to be reordered.
    * \param order The expected iterator order.
    */
-  void reorder(int stage_id, const Array<Iterator>& order);
+  TVM_DLL void reorder(int stage_id, const Array<Iterator>& order);
   /*!
    * \brief Schedule primitive corresponds to te.split.
    * \param stage_id The index of the stage to be split.
@@ -340,8 +346,9 @@ class State : public ObjectRef {
    * \note If we do split on an iterator which has stages attached at it(by compute_at), the inner
    * most iterator of split results will become the new attach point.
    */
-  Array<Iterator> split(int stage_id, const Iterator& it, const Array<Optional<Integer>>& lengths,
-                        bool inner_to_outer = true);
+  TVM_DLL Array<Iterator> split(int stage_id, const Iterator& it,
+                                const Array<Optional<Integer>>& lengths,
+                                bool inner_to_outer = true);
 
   /********** Step APIs working on multiple stages **********/
 
@@ -355,12 +362,12 @@ class State : public ObjectRef {
    * bound for the newly created iterators.
    * Call ComputeDAG::InferBound on the updated state to get the complete bound information.
    */
-  void compute_at(int stage_id, int target_stage_id, const Iterator& target_iter);
+  TVM_DLL void compute_at(int stage_id, int target_stage_id, const Iterator& target_iter);
   /*!
    * \brief Schedule primitive corresponds to te.compute_inline.
    * \param stage_id The index of the stage to be reordered.
    */
-  void compute_inline(int stage_id);
+  TVM_DLL void compute_inline(int stage_id);
   /*!
    * \brief Schedule primitive corresponds to te.compute_root.
    * \param stage_id The index of the stage to be reordered.
@@ -369,7 +376,7 @@ class State : public ObjectRef {
    * bound for the newly created iterators.
    * Call ComputeDAG::InferBound on the updated state to get the complete bound information.
    */
-  void compute_root(int stage_id);
+  TVM_DLL void compute_root(int stage_id);
 
   TVM_DEFINE_OBJECT_REF_METHODS(State, ObjectRef, StateNode);
   TVM_DEFINE_OBJECT_REF_COW_METHOD(StateNode);
@@ -381,21 +388,11 @@ class State : public ObjectRef {
 // Hash and equal function for State
 namespace std {
 
-/*! \brief The hash function for auto_scheduler::State. */
-template <>
-struct hash<::tvm::auto_scheduler::State> {
-  std::size_t operator()(const ::tvm::auto_scheduler::State& state) const {
-    return tvm::runtime::ObjectHash()(state.ToStr());
-  }
-};
-
 /*!
  * \brief The equal_to function for auto_scheduler::State.
- * We use the schedule result(its string format) of a state to check if two states are `euqal`.
- * Equal States: 1. the transform steps are totally the same; 2. even with different steps, two
- * states may still result in a same schedule. e.g. To split a axis with extent 512 to 3 parts
- * [8, 16, 4]. We can split from inner to outter by factors [16, 4], while we can get a same result
- * to split from outter to inner by factors [8, 16])
+ * This function checkes the equality by looking at the lowered string format of states.
+ * If two states with different transform history have the same lowered string format,
+ * they will be considered being equal.
  */
 template <>
 struct equal_to<::tvm::auto_scheduler::State> {
@@ -405,6 +402,14 @@ struct equal_to<::tvm::auto_scheduler::State> {
   }
 };
 
+/*! \brief The hash function for auto_scheduler::State. */
+template <>
+struct hash<::tvm::auto_scheduler::State> {
+  std::size_t operator()(const ::tvm::auto_scheduler::State& state) const {
+    return tvm::runtime::ObjectHash()(state.ToStr());
+  }
+};
+
 }  // namespace std
 
 #endif  // TVM_AUTO_SCHEDULER_LOOP_STATE_H_
similarity index 93%
rename from src/auto_scheduler/measure.h
rename to include/tvm/auto_scheduler/measure.h
index 02d6e87..83d7c8d 100644 (file)
  * These functions are responsible for building the tvm module, uploading it to remote devices,
  * recording the running time costs, and checking the correctness of the output.
  *
- * We separate the measurement into two steps: build and run.
+ * The measurement is separated into two steps: build and run.
  * A builder builds the executable binary files and a runner runs the binary files to get the
  * measurement results. The flow of data structures is
  *
  *                 `ProgramBuilder`                 `ProgramRunner`
  * `MeasureInput` -----------------> `BuildResult` ----------------> `MeasureResult`
  *
- * We implement these in python to utilize python's multiprocessing and error handling.
+ * The core functions is implemented in python to utilize python's multiprocessing
+ * and error handling (see also `python/tvm/auto_scheduler/measure.py`).
+ * This c++ file is just a wrapper for the python functions.
  */
 
 #ifndef TVM_AUTO_SCHEDULER_MEASURE_H_
 #define TVM_AUTO_SCHEDULER_MEASURE_H_
 
+#include <tvm/auto_scheduler/loop_state.h>
+#include <tvm/auto_scheduler/search_task.h>
+
 #include <string>
 #include <unordered_map>
 #include <utility>
 
-#include "loop_state.h"
-#include "search_task.h"
-
 namespace tvm {
 namespace auto_scheduler {
 
@@ -209,7 +211,7 @@ class MeasureCallbackNode : public Object {
  public:
   /*!
    * \brief Callback function that will be called on measurement input/result pairs
-   * after measurement.
+   * after each measurement batch.
    * \param policy The current search policy.
    * \param inputs An Array of MeasureInput.
    * \param results An Array of MeasureResult.
@@ -234,7 +236,7 @@ class MeasureCallback : public ObjectRef {
 /*! \brief ProgramBuilder that builds the programs */
 class ProgramBuilderNode : public Object {
  public:
-  /*! \brief The number of tasks to run in parallel */
+  /*! \brief The number of build processes to run in parallel */
   int n_parallel;
   /*! \brief Timeout of a build */
   int timeout;
@@ -323,15 +325,15 @@ class LocalBuilder : public ProgramBuilder {
    * \brief The constructor.
    * \param timeout The timeout limit (in second) for each build thread.
    * This will be used in a wrapper of the multiprocessing.Process.join().
-   * \param n_parallel Number of threads used to build in parallel.
-   * \param build_func The name of registered build function.
+   * \param n_parallel The number of threads used to build in parallel.
+   * \param build_func The name of the registered build function.
    */
   LocalBuilder(int timeout, int n_parallel, const String& build_func);
 
   TVM_DEFINE_OBJECT_REF_METHODS(LocalBuilder, ProgramBuilder, LocalBuilderNode);
 };
 
-/*! \brief LocalRunner that uses local CPU/GPU to measures the time cost of programs */
+/*! \brief LocalRunner that uses local CPU/GPU to measure the time cost of programs */
 class LocalRunnerNode : public ProgramRunnerNode {
  public:
   Array<MeasureResult> Run(const Array<MeasureInput>& inputs,
@@ -373,13 +375,12 @@ class RPCRunnerNode : public ProgramRunnerNode {
   String key;
   /*! \brief The host address of the RPC Tracker. */
   String host;
-  /*! \brief The port of RPC Tracker. */
+  /*! \brief The port of the RPC Tracker. */
   int port;
   /*! \brief The priority of this run request, larger is more prior. */
   int priority;
   /*! \brief The number of tasks run in parallel. */
   int n_parallel;
-  /*! \brief The number of times to run the generated code for taking average. */
 
   Array<MeasureResult> Run(const Array<MeasureInput>& inputs,
                            const Array<BuildResult>& build_results, int verbose) final;
@@ -395,10 +396,11 @@ class RPCRunnerNode : public ProgramRunnerNode {
 class RPCRunner : public ProgramRunner {
  public:
   /*!
-   * \brief The constructor.
+   * \brief The constructor. See the corresponding class in python/tvm/auto_scheduler/measure.py
+   * for more detailed parameter explaination.
    * \param key The key of the device registered in the RPC tracker.
    * \param host The host address of the RPC Tracker.
-   * \param prot The port of RPC Tracker.
+   * \param port The port of RPC Tracker.
    * \param priority The priority of this run request, larger is more prior.
    * \param n_parallel The number of tasks run in parallel.
    * \param timeout Timeout of a run.
@@ -415,7 +417,7 @@ class RPCRunner : public ProgramRunner {
 
 /*!
  * \brief Measurer that measures the time costs of tvm programs
- * This class combines ProgramBuilder and ProgramRunner, and provides a simpler API */
+ * This class combines ProgramBuilder and ProgramRunner and provides a simpler API */
 class ProgramMeasurerNode : public Object {
  public:
   /*! \brief Measured programs counter. */
@@ -483,7 +485,7 @@ class ProgramMeasurer : public ObjectRef {
    * \param callbacks MeasureCallback to be called after each measure batch.
    * \param verbose Verbosity level. 0 for silent, 1 to output information during program
    * measuring.
-   * \param max_continous_error The number of max continuous error.
+   * \param max_continous_error The number of allowed maximum continuous error.
    */
   ProgramMeasurer(ProgramBuilder builder, ProgramRunner runner,
                   Optional<Array<MeasureCallback>> callbacks, int verbose,
similarity index 83%
rename from src/auto_scheduler/measure_record.h
rename to include/tvm/auto_scheduler/measure_record.h
index 1cfeab0..fa8fe2b 100644 (file)
  */
 
 /*!
- * \file auto_scheduler/measure_record.h
- * \brief Json serialization format for dumping and loading tuning records.
+ * \file tvm/auto_scheduler/measure_record.h
+ * \brief Json serialization format for dumping and loading measurement records.
  */
 
 #ifndef TVM_AUTO_SCHEDULER_MEASURE_RECORD_H_
 #define TVM_AUTO_SCHEDULER_MEASURE_RECORD_H_
 
+#include <tvm/auto_scheduler/measure.h>
+
 #include <fstream>
 #include <string>
 #include <utility>
 
-#include "measure.h"
-
 namespace tvm {
 namespace auto_scheduler {
 
 /*! \brief Callback for logging the input and results of measurements to file */
 class RecordToFileNode : public MeasureCallbackNode {
  public:
-  /*! \brief File name for this callback to write log to. */
+  /*! \brief The name of output file. */
   String filename;
 
   void Callback(const SearchPolicy& policy, const Array<MeasureInput>& inputs,
@@ -55,7 +55,7 @@ class RecordToFile : public MeasureCallback {
  public:
   /*!
    * \brief The constructor.
-   * \param filename File name for this callback to write log.
+   * \param filename The name of output file
    */
   explicit RecordToFile(String filename);
 
@@ -65,7 +65,7 @@ class RecordToFile : public MeasureCallback {
 /*! \brief Log reader to load step logs from a file.*/
 class RecordReaderNode : public Object {
  public:
-  /*! \brief File name for this reader to load log from. */
+  /*! \brief The name of input file. */
   String filename;
   /*! \brief The reading file stream. */
   std::ifstream infile;
@@ -92,7 +92,7 @@ class RecordReaderNode : public Object {
   TVM_DECLARE_FINAL_OBJECT_INFO(RecordReaderNode, Object);
 
  private:
-  /*! \brief A string object to store the next line. */
+  /*! \brief A string storing the current line. */
   std::string cur_line_;
 };
 
@@ -104,7 +104,7 @@ class RecordReader : public ObjectRef {
  public:
   /*!
    * \brief The constructor.
-   * \param filename File name for this callback to write log.
+   * \param filename The name of input file
    */
   explicit RecordReader(String filename);
 
@@ -112,7 +112,7 @@ class RecordReader : public ObjectRef {
 };
 
 /*!
- * \brief Write measure records to an output stream.
+ * \brief Append measure records to an output stream.
  * \param os A pointer to a output stream.
  * \param inputs The MeasureInputs to be written.
  * \param results The MeasureResults to be written.
@@ -122,10 +122,10 @@ void WriteMeasureRecords(std::ostream* os, const Array<MeasureInput>& inputs,
 
 /*!
  * \brief Read one measure record from a string.
- * \param str The record string to be extract.
- * \param inp A pointer to a MeasureInputNode, this is used as output.
- * \param res A pointer to a MeasureResultNode, this is used as output.
- * \param log_version A pointer to a log version string.
+ * \param str The record string to be parsed.
+ * \param inp A pointer to a MeasureInputNode used to store the return value.
+ * \param res A pointer to a MeasureResultNode used to store the return value.
+ * \param log_version A pointer to a string used to store the log version.
  */
 void ReadMeasureRecord(const std::string& str, MeasureInputNode* inp, MeasureResultNode* res,
                        std::string* log_version);
  */
 
 /*!
- * \file auto_scheduler/search_policy/search_policy.h
+ * \file tvm/auto_scheduler/search_policy.h
  * \brief The base class of search policies, including the abstract definition of search policy and
  * other supporting data structures.
  *
- * The basic schedule search process for TVM Auto-scheduler is design to be:
+ * The basic schedule search process for the auto-scheduler is design to be:
  * `Program sampling` -> `Performance Tuning`.
  *
  * In `Program sampling`, we use some predefined precise or heuristic rules to generate several
@@ -31,7 +31,7 @@
  *
  * Candidate schedules are measured against the specific hardware target.
  *
- * \note Adding a new search policy.
+ * \note How to add a new search policy.
  * In design, there's no need for users to implement their own search policy, our formal search
  * policy(will be brought later) should be enough to cover most use cases. Meanwhile, a custom rule
  * mechanism will be provided to enable user-defined template search to serve the same functionality
  * during the search process.
  */
 
-#ifndef TVM_AUTO_SCHEDULER_SEARCH_POLICY_SEARCH_POLICY_H_
-#define TVM_AUTO_SCHEDULER_SEARCH_POLICY_SEARCH_POLICY_H_
+#ifndef TVM_AUTO_SCHEDULER_SEARCH_POLICY_H_
+#define TVM_AUTO_SCHEDULER_SEARCH_POLICY_H_
 
+#include <tvm/auto_scheduler/search_task.h>
 #include <tvm/node/node.h>
 
 #include <unordered_set>
 #include <vector>
 
-#include "../search_task.h"
-
 namespace tvm {
 namespace auto_scheduler {
 
@@ -110,16 +109,16 @@ class SearchPolicyNode : public Object {
 
   /*!
    * \brief Do schedule search for a task. Takes the SearchTask as input and returns the best state
-   * get during the search process.
-   * \param task  The SearchTask or workload key for the computation declaration
-   * \param num_measure_trials Total schedules to be tried during this search.
-   * \param early_stopping Early stop if no better schedule is found.
-   * \param num_measures_per_round Max measure batch in one search round.
+   * found during the search.
+   * \param task  The SearchTask for the computation declaration
+   * \param num_measure_trials The number of total measurement trials.
+   * \param early_stopping Stops the tuning early if no improvement after n measurements.
+   * \param num_measures_per_round  The number of programs to be measured at each search round.
    * \param verbose Verbose level. 0 for silent, 1 to output information during schedule
    * search.
-   * \param measurer A ProgramMeasurer which packs ProgramBuilder & ProgramRunner inside.
+   * \param measurer A ProgramMeasurer to build and measure programs
    * \param pre_search_callbacks SearchCallback to be called before schedule search.
-   * \return The best state get.
+   * \return The best state found.
    */
   virtual State Search(SearchTask task, int num_measure_trials, int early_stopping,
                        int num_measures_per_round, int verbose, ProgramMeasurer measurer,
@@ -137,16 +136,12 @@ class SearchPolicyNode : public Object {
  protected:
   /*!
    * \brief The set of already measured states.
-   * During the schedule search process, we may generate `equal states` through different search
-   * branches. (Equal States: 1. the transform steps are totally the same; 2. even with different
-   * steps, two states may still result in a same schedule. e.g. To split a axis with extent 512
-   * to 3 parts [8, 16, 4]. We can split from inner to outter by factors [16, 4], while we can
-   * get a same result to split from outter to inner by factors [8, 16])
    * We store the string format of a state for redundancy check. This is used to make sure a
    * measured state will never be measured again.
    */
   std::unordered_set<String> measured_states_set_;
-  /*! \brief The array of already measured states. This can be used in evolutionary search. */
+  /*! \brief The array of already measured states.
+   *  The good states can be used as the initial population in evolutionary search. */
   std::vector<State> measured_states_vector_;
   /*! \brief The throughputs of already measured states */
   std::vector<float> measured_states_throughputs_;
@@ -164,4 +159,4 @@ class SearchPolicy : public ObjectRef {
 }  // namespace auto_scheduler
 }  // namespace tvm
 
-#endif  // TVM_AUTO_SCHEDULER_SEARCH_POLICY_SEARCH_POLICY_H_
+#endif  // TVM_AUTO_SCHEDULER_SEARCH_POLICY_H_
similarity index 97%
rename from src/auto_scheduler/search_task.h
rename to include/tvm/auto_scheduler/search_task.h
index ca31350..85154b5 100644 (file)
 #ifndef TVM_AUTO_SCHEDULER_SEARCH_TASK_H_
 #define TVM_AUTO_SCHEDULER_SEARCH_TASK_H_
 
+#include <tvm/auto_scheduler/compute_dag.h>
 #include <tvm/target/target.h>
 
-#include "compute_dag.h"
-
 namespace tvm {
 namespace auto_scheduler {
 
 class HardwareParams;
 
-/*! \brief The parameters of target hardware used to guide the search process of SearchPolicy. */
+/*! \brief The parameters of target hardware used to guide the SearchPolicy. */
 class HardwareParamsNode : public Object {
  public:
   /*! \brief The number of cores. */
similarity index 98%
rename from src/auto_scheduler/transform_step.h
rename to include/tvm/auto_scheduler/transform_step.h
index ce3ca50..b23137a 100644 (file)
 
 /*!
  * \file auto_scheduler/transform_step.h
- * \brief Transformation steps. For each schedule primitive, there is a corresponding transform
- * step.
+ * \brief Transformation steps. These steps are used to manipulate the LoopState.
+ *        They are similar to the schedule primitives in te::Stage.
  *
- * \note To add a new transform step:
+ * \note How to add a new transform step:
  * Take fuse step for example:
  * 1. Define class `FuseStepNode`, `FuseStep` in `transform_steps.h`, and implement its first
  *    construction function `FuseStep::FuseStep()` in `transform_steps.cc`.
@@ -51,8 +51,6 @@
 #include <tvm/node/node.h>
 #include <tvm/te/schedule.h>
 
-#include "utils.h"
-
 namespace tvm {
 namespace auto_scheduler {
 
@@ -187,7 +185,6 @@ Step StepReadFromRecord(dmlc::JSONReader* reader);
  * \param step The step to be applied to State.
  * \param state A mutable pointer to State.
  * \param dag The original ComputeDAG of this state.
- * \return The iterator result after annotate.
  */
 void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag);
 
@@ -209,7 +206,7 @@ void StepApplyToSchedule(const Step& step, Array<te::Stage>* stages, StageToAxes
 String StepPrintAsPythonAPI(const Step& step, Array<te::Stage>* stages,
                             StageToAxesMap* stage_to_axes);
 
-/********** Primitives working on single stage **********/
+/********** Steps working on single stage **********/
 
 /*!
  * \brief Annotation step that corresponds to vectorize, parallel, unroll and thread binding.
@@ -478,7 +475,7 @@ class SplitStep : public Step {
   TVM_DEFINE_OBJECT_REF_METHODS(SplitStep, Step, SplitStepNode);
 };
 
-/********** Primitives working on multiple stages **********/
+/********** Steps working on multiple stages **********/
 
 /*! \brief Compute at step that corresponds to te::Stage::compute_at */
 class ComputeAtStepNode : public StepNode {
index d45dbf8..52aa62b 100644 (file)
@@ -57,7 +57,7 @@ class HardwareParams(Object):
 
 @tvm._ffi.register_object("auto_scheduler.SearchTask")
 class SearchTask(Object):
-    """ The computation information and hardware parameters for a specific schedule search task.
+    """ The computation information and hardware parameters for a schedule search task.
 
     Parameters
     ----------
@@ -158,9 +158,6 @@ class TuningOptions(Object):
 def auto_schedule(task, search_policy='default', tuning_options=None):
     """ Do auto scheduling for a computation declaration.
 
-    The task parameter can be a `string` as workload_key, or directly
-    passing a `SearchTask` as input.
-
     Parameters
     ----------
     task : SearchTask
index 36c2037..045720a 100644 (file)
@@ -95,7 +95,7 @@ def make_workload_key(func, args):
 
     Returns
     -------
-    workload_key : Str
+    workload_key : str
         The workload key of the function.
     """
     global WORKLOAD_FUNC_REGISTRY
index b515b3a..c537ca7 100644 (file)
@@ -24,8 +24,7 @@
  * schedule after search process.
  */
 
-#include "auto_schedule.h"
-
+#include <tvm/auto_scheduler/auto_schedule.h>
 #include <tvm/runtime/registry.h>
 
 namespace tvm {
index d81dff6..68d1bb4 100644 (file)
  * \brief Compute declaration graph and its related analysis tools.
  */
 
-#include "compute_dag.h"
-
+#include <tvm/auto_scheduler/compute_dag.h>
+#include <tvm/auto_scheduler/loop_state.h>
 #include <tvm/runtime/registry.h>
 #include <tvm/te/operation.h>
 #include <tvm/te/schedule.h>
 #include <tvm/te/schedule_pass.h>
+#include <tvm/tir/builtin.h>
 #include <tvm/tir/stmt_functor.h>
 
 #include <algorithm>
@@ -36,7 +37,7 @@
 #include <unordered_set>
 #include <vector>
 
-#include "loop_state.h"
+#include "../arith/pattern_match.h"
 #include "utils.h"
 
 namespace tvm {
@@ -44,6 +45,10 @@ namespace auto_scheduler {
 
 using namespace tvm::tir;
 
+template <class T>
+using OperationMap = AccessAnalyzerNode::OperationMap<T>;
+using OperationSet = std::unordered_set<te::Operation, ObjectHash, ObjectEqual>;
+
 TVM_REGISTER_NODE_TYPE(ComputeDAGNode);
 
 // Topo-sort ops from tensors according to their read-write relations.
@@ -114,7 +119,416 @@ Array<te::Operation> TopoSortOps(const Array<te::Tensor>& tensors) {
   return ops;
 }
 
-// Estimate number of float operations in an expression
+// Extract all tensor accesses in an expr
+class ReadAccessExtractor : public StmtExprVisitor {
+ public:
+  void Extract(PrimExpr expr) { this->VisitExpr(expr); }
+
+  void VisitExpr_(const CallNode* op) final {
+    if (op->op.same_as(builtin::if_then_else())) {
+      has_branch = true;
+    }
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitExpr_(const ProducerLoadNode* op) final {
+    read_access[Downcast<te::Tensor>(op->producer)->op].emplace_back(op->indices.begin(),
+                                                                     op->indices.end());
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitStmt_(const IfThenElseNode* op) final {
+    has_branch = true;
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitExpr_(const SelectNode* op) final {
+    has_branch = true;
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  // All read accesses to all operations
+  // The innermost vector stores mulit-dimentional indices.
+  // The middle vector stores possible multiple accesses
+  OperationMap<std::vector<std::vector<PrimExpr>>> read_access;
+  // Whether this expression has branch
+  bool has_branch{false};
+};
+
+// Returns whether the expr equals to the var with an optional const shift
+bool IsConstShiftEqual(const Var& var, const PrimExpr& expr) {
+  arith::PVar<PrimExpr> x;
+  arith::PVar<IntImm> c;
+
+  if (((x + c).Match(expr) || (x - c).Match(expr) || (c + x).Match(expr) || x.Match(expr)) &&
+      x.Eval().same_as(var)) {
+    return true;
+  }
+  return false;
+}
+
+// Return whether the access to an operation is a simple access
+// (i.e. all index is just a variable with an optional constant shift)
+// For example, A[i][j], A[i+1][j] are simple accesses but A[i][j+i] is not.
+bool IsSimpleAccess(const te::Operation& op, const std::vector<PrimExpr>& indices,
+                    bool* axis_missing, bool* axis_duplicated, bool* same_order) {
+  auto cop = op.as<te::ComputeOpNode>();
+  if (cop == nullptr) {
+    return false;
+  }
+
+  std::vector<int> index_to_var_idx;
+  std::vector<int> var_idx_ct(cop->axis.size(), 0);
+
+  for (const auto& expr : indices) {
+    if (!is_const_int(expr)) {
+      bool found = false;
+      for (size_t i = 0; i < cop->axis.size(); ++i) {
+        if (IsConstShiftEqual(cop->axis[i]->var, expr)) {
+          index_to_var_idx.push_back(i);
+          var_idx_ct[i]++;
+          found = true;
+          break;
+        }
+      }
+      if (!found) {
+        return false;
+      }
+    }
+  }
+
+  *axis_missing = false;     // Some axes are missing
+  *axis_duplicated = false;  // Some axes appear more than once
+  *same_order = true;        // The axis order is the same as op->axis
+  for (int ct : var_idx_ct) {
+    if (ct == 0) {
+      *axis_missing = true;
+    } else if (ct > 1) {
+      *axis_duplicated = true;
+    }
+  }
+  for (size_t i = 1; i < index_to_var_idx.size(); ++i) {
+    if (index_to_var_idx[i] < index_to_var_idx[i - 1]) {
+      *same_order = false;
+      break;
+    }
+  }
+
+  return true;
+}
+
+// Gather all VarNodes in an expr
+void GatherVars(const PrimExpr& expr, std::unordered_set<const VarNode*>* vars) {
+  PostOrderVisit(expr, [&vars](const ObjectRef& node) {
+    if (const VarNode* op = node.as<VarNode>()) {
+      vars->insert(op);
+    }
+  });
+}
+
+// Check whether an expr has expensive operations (e.g. exp)
+bool HasExpensiveOp(const PrimExpr& expr) {
+  bool found = false;
+  PostOrderVisit(expr, [&found](const ObjectRef& node) {
+    if (const CallNode* op = node.as<CallNode>()) {
+      if (op->op.as<OpNode>()->name == "tir.exp") {
+        found = true;
+      }
+    }
+  });
+  return found;
+}
+
+AccessAnalyzer::AccessAnalyzer(const Array<te::Tensor>& tensors) {
+  auto node = make_object<AccessAnalyzerNode>();
+  OperationMap<bool> has_branch;
+
+  // Get all ops in topological order
+  node->ops_topo_order = TopoSortOps(tensors);
+
+  arith::Analyzer analyzer;
+
+  // Build read & write access map
+  for (const auto& op : node->ops_topo_order) {
+    if (op->IsInstance<te::PlaceholderOpNode>()) {
+      node->read_from[op] = OperationMap<std::vector<std::vector<PrimExpr>>>();
+    } else if (auto cop = op.as<te::ComputeOpNode>()) {
+      ReadAccessExtractor extractor;
+      for (const auto& exp : cop->body) {
+        extractor.Extract(exp);
+      }
+
+      // read_by and read_from map
+      for (const auto& iter : extractor.read_access) {
+        std::vector<std::vector<PrimExpr>>& accesses = node->read_by[iter.first][op];
+        accesses.insert(accesses.begin(), iter.second.begin(), iter.second.end());
+      }
+
+      node->read_from[op] = std::move(extractor.read_access);
+      has_branch[op] = extractor.has_branch;
+
+      // compute number of common outer iterators
+      for (const auto& pair : node->read_from[op]) {
+        const te::Operation& producer = pair.first;
+        const std::vector<std::vector<PrimExpr>>& access_list = pair.second;
+        const Array<PrimExpr>& output_shape = op->output_shape(0);
+        const Array<PrimExpr>& producer_shape = producer->output_shape(0);
+
+        int n_common;
+        for (n_common = 0;
+             n_common < static_cast<int>(std::min(output_shape.size(), producer_shape.size()));
+             n_common++) {
+          if (!is_zero(analyzer.Simplify(output_shape[n_common] - producer_shape[n_common]))) {
+            break;
+          }
+
+          bool injective = true;
+          for (const auto& access : access_list) {
+            if (!IsConstShiftEqual(cop->axis[n_common]->var, access[n_common])) {
+              injective = false;
+              break;
+            }
+          }
+
+          if (!injective) {
+            break;
+          }
+        }
+
+        node->num_common_outer_iterators[op][producer] = n_common;
+        node->num_common_outer_iterators[producer][op] = n_common;
+      }
+    } else {
+      LOG(FATAL) << "Invalid op: " << op;
+    }
+  }
+
+  // Do some static analysis on ComputeOps
+  for (const auto& op : node->ops_topo_order) {
+    if (op->IsInstance<te::PlaceholderOpNode>()) {
+      node->is_simple_access[op] = true;
+      node->needs_multi_level_tiling[op] = false;
+      node->is_strict_inlineable[op] = false;
+      node->is_output[op] = false;
+    } else if (auto cop = op.as<te::ComputeOpNode>()) {
+      // check whether this op is element-wise and strict-inlineable
+      bool is_simple_access = true;
+      bool is_strict_inlineable = true;
+
+      bool axis_missing, axis_duplicated, same_order;
+      for (const auto& pair : node->read_from[op]) {
+        const std::vector<std::vector<PrimExpr>>& access_list = pair.second;
+        for (const auto& access : access_list) {
+          if (!auto_scheduler::IsSimpleAccess(op, access, &axis_missing, &axis_duplicated,
+                                              &same_order)) {
+            is_simple_access = false;
+            is_strict_inlineable = false;
+            break;
+          }
+          if (!same_order || axis_duplicated) {
+            // do not strictly inline transpose
+            is_strict_inlineable = false;
+          }
+        }
+        if (!is_simple_access) {
+          break;
+        }
+      }
+
+      // don't strictly inline expensive op (e.g. exp)
+      bool has_expensive_op = false;
+      for (const auto& expr : cop->body) {
+        has_expensive_op |= HasExpensiveOp(expr);
+      }
+      if (has_expensive_op || has_branch[op]) {
+        is_strict_inlineable = false;
+      }
+
+      node->is_simple_access[op] = is_simple_access;
+      node->is_strict_inlineable[op] = is_strict_inlineable;
+
+      // check whether the op needs multi-level tiling
+      bool needs_multi_level_tiling = false;
+      int n_missing = 0;
+
+      for (const auto& pair : node->read_from[op]) {
+        const std::vector<std::vector<PrimExpr>>& access_list = pair.second;
+        std::unordered_set<const VarNode*> vars;
+        for (const std::vector<PrimExpr>& access : access_list) {
+          for (const PrimExpr& expr : access) {
+            GatherVars(expr, &vars);
+          }
+        }
+
+        for (const auto& axis : cop->axis) {
+          if (GetIntImm(axis->dom->extent) > 1 && vars.count(axis->var.get()) == 0) {
+            n_missing++;
+            break;
+          }
+        }
+
+        if (n_missing >= 2 || (n_missing >= 1 && !cop->reduce_axis.empty())) {
+          needs_multi_level_tiling = true;
+          break;
+        }
+      }
+
+      node->needs_multi_level_tiling[op] = needs_multi_level_tiling;
+
+      // check whether the op is output
+      node->is_output[op] = node->read_by[op].empty();
+    } else {
+      LOG(FATAL) << "Invalid op" << op;
+    }
+  }
+
+  data_ = std::move(node);
+}
+
+bool AccessAnalyzer::NeedsMultiLevelTiling(const te::Operation& op) const {
+  return operator->()->needs_multi_level_tiling.at(op);
+}
+
+bool AccessAnalyzer::IsOutput(const te::Operation& op) const {
+  return operator->()->is_output.at(op);
+}
+
+bool AccessAnalyzer::IsSimpleAccess(const te::Operation& op) const {
+  return operator->()->is_simple_access.at(op);
+}
+
+bool AccessAnalyzer::IsStrictInlineable(const te::Operation& op) const {
+  return operator->()->is_strict_inlineable.at(op);
+}
+
+OperationSet AccessAnalyzer::GetConsumers(const State& state, const te::Operation& op) const {
+  OperationSet inlined_ops;
+  for (const auto& stage : state->stages) {
+    if (stage->compute_at == ComputeAtKind::kInlined) {
+      inlined_ops.insert(stage->op);
+    }
+  }
+
+  OperationSet consumers;
+  std::function<void(const te::Operation&)> collect;
+  collect = [this, &collect, &inlined_ops, &consumers](const te::Operation& op) {
+    for (const auto& iter : operator->()->read_by.at(op)) {
+      if (inlined_ops.count(iter.first)) {
+        collect(iter.first);
+      } else {
+        consumers.insert(iter.first);
+      }
+    }
+  };
+
+  collect(op);
+  return consumers;
+}
+
+OperationSet AccessAnalyzer::GetDirectProducers(const te::Operation& op) const {
+  OperationSet producers;
+  for (const auto& iter : operator->()->read_from.at(op)) {
+    producers.insert(iter.first);
+  }
+  return producers;
+}
+
+OperationSet AccessAnalyzer::GetProducers(const State& state, const te::Operation& op) const {
+  OperationSet inlined_ops;
+  for (const auto& stage : state->stages) {
+    if (stage->compute_at == ComputeAtKind::kInlined) {
+      inlined_ops.insert(stage->op);
+    }
+  }
+
+  OperationSet producers;
+  std::function<void(const te::Operation&)> collect;
+  collect = [this, &collect, &inlined_ops, &producers](const te::Operation& op) {
+    for (const auto& iter : operator->()->read_from.at(op)) {
+      if (inlined_ops.count(iter.first)) {
+        collect(iter.first);
+      } else {
+        producers.insert(iter.first);
+      }
+    }
+  };
+
+  collect(op);
+  return producers;
+}
+
+int AccessAnalyzer::GetNumCommonOuterIterator(const te::Operation& op,
+                                              const te::Operation& target_op) const {
+  int ret = INT32_MAX;
+  bool meet = false;
+
+  std::function<void(const te::Operation&, int)> traverse;
+  traverse = [this, &traverse, &target_op, &ret, &meet](const te::Operation& cur_op, int cur_num) {
+    if (cur_op == target_op) {
+      ret = std::min(ret, cur_num);
+      meet = true;
+      return;
+    }
+
+    for (const auto& iter : operator->()->read_by.at(cur_op)) {
+      traverse(
+          iter.first,
+          std::min(cur_num, operator->()->num_common_outer_iterators.at(cur_op).at(iter.first)));
+    }
+  };
+
+  traverse(op, op->output_shape(0).size());
+  return meet ? ret : 0;
+}
+
+bool AccessAnalyzer::ElementWiseMatch(const te::Operation& op,
+                                      const te::Operation& target_op) const {
+  te::Operation cur_op = op;
+  while (cur_op != target_op) {
+    const AccessAnalyzerNode::OperationMap<std::vector<std::vector<PrimExpr>>>& map =
+    operator->()->read_by.at(cur_op);
+
+    if (map.size() != 1) {
+      return false;
+    }
+    te::Operation next_op = map.begin()->first;
+
+    // Check condition 1: They have the same output size
+    auto p_cur = cur_op.as<te::ComputeOpNode>();
+    auto p_next = next_op.as<te::ComputeOpNode>();
+    if (p_cur == nullptr || p_next == nullptr) {
+      return false;
+    }
+
+    Array<PrimExpr> output_shape = p_cur->output_shape(0);
+    for (int i = 1; i < p_cur->num_outputs(); ++i) {
+      if (!IntArrayEqual(p_cur->output_shape(i), output_shape)) {
+        return false;
+      }
+    }
+    for (int i = 0; i < p_next->num_outputs(); ++i) {
+      if (!IntArrayEqual(p_next->output_shape(i), output_shape)) {
+        return false;
+      }
+    }
+
+    // Check condition 2: The read is elementwise
+    const std::vector<std::vector<PrimExpr>> reads = map.begin()->second;
+    bool is_simple_access, axis_missing, axis_duplicated, same_order;
+    for (const auto& read : reads) {
+      is_simple_access = auto_scheduler::IsSimpleAccess(next_op, read, &axis_missing,
+                                                        &axis_duplicated, &same_order);
+      if (!is_simple_access || axis_missing || axis_duplicated || !same_order) {
+        return false;
+      }
+    }
+
+    cur_op = std::move(next_op);
+  }
+  return true;
+}
+
+// Estimate the number of float operations in an expression
 class FlopEstimator : public ExprFunctor<double(const PrimExpr& n)> {
  public:
   double EstimateFlop(const Array<te::Operation>& ops) {
@@ -126,6 +540,7 @@ class FlopEstimator : public ExprFunctor<double(const PrimExpr& n)> {
           fail_ = true;
           break;
         }
+        cur_type_code_ = pop->output_dtype(0).code();
         double op_per_element = 0;
         for (const auto& x : pop->body) {
           op_per_element += VisitExpr(x);
@@ -171,10 +586,17 @@ class FlopEstimator : public ExprFunctor<double(const PrimExpr& n)> {
            std::max(VisitExpr(op->true_value), VisitExpr(op->false_value));
   }
 
-#define VisitBinary(Node) \
-  double VisitExpr_(const Node* op) final { return 1.0 + VisitExpr(op->a) + VisitExpr(op->b); }
-#define VisitUnary(Node) \
-  double VisitExpr_(const Node* op) final { return 1.0 + VisitExpr(op->a); }
+#define VisitBinary(Node)                                         \
+  double VisitExpr_(const Node* op) final {                       \
+    double base = op->dtype.code() == cur_type_code_ ? 1.0 : 0.0; \
+    return base + VisitExpr(op->a) + VisitExpr(op->b);            \
+  }
+
+#define VisitUnary(Node)                                          \
+  double VisitExpr_(const Node* op) final {                       \
+    double base = op->dtype.code() == cur_type_code_ ? 1.0 : 0.0; \
+    return base + VisitExpr(op->a);                               \
+  }
 
   VisitBinary(AddNode);
   VisitBinary(SubNode);
@@ -210,12 +632,14 @@ class FlopEstimator : public ExprFunctor<double(const PrimExpr& n)> {
 
  private:
   bool fail_{false};
+  int cur_type_code_;
 };
 
 ComputeDAG::ComputeDAG(Array<te::Tensor> tensors) {
   auto node = make_object<ComputeDAGNode>();
   node->tensors = std::move(tensors);
-  node->ops = TopoSortOps(node->tensors);
+  node->access_analyzer = AccessAnalyzer(node->tensors);
+  node->ops = node->access_analyzer->ops_topo_order;
   node->flop_ct = FlopEstimator().EstimateFlop(node->ops);
   node->init_state = State(node->ops);
   data_ = std::move(node);
diff --git a/src/auto_scheduler/compute_dag.h b/src/auto_scheduler/compute_dag.h
deleted file mode 100644 (file)
index 2417d72..0000000
+++ /dev/null
@@ -1,124 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file auto_scheduler/compute_dag.h
- * \brief The TVM Auto-scheduler computational graph and related program analyses.
- *
- * We convert a compute declaration described by `tvm.compute` (could be a single operator or a
- * subgraph) to a ComputeDAG. It keeps the input/output tensors of the compute declaration,
- * a list of all operations in the DAG as well as static analysis results for the DAG (e.g. the
- * total float operation count, consumer/producer relations of each operation stage, whether an
- * operation stage should be tiled/compute inlined ...). These analyses can help the search policy
- * to make decisions during search process.
- * ComputeDAG is also responsible for the interaction between TVM Auto-scheduler `LoopState` and
- * TVM schedule (e.g. applying the `LoopState` transform steps to TVM schedule, providing
- * `LoopState` with extra information got from TVM schedule ...).
- */
-
-#ifndef TVM_AUTO_SCHEDULER_COMPUTE_DAG_H_
-#define TVM_AUTO_SCHEDULER_COMPUTE_DAG_H_
-
-#include <tvm/te/schedule.h>
-
-#include <utility>
-
-#include "loop_state.h"
-
-namespace tvm {
-namespace auto_scheduler {
-
-/*! \brief The TVM Auto-scheduler computational graph and related program analyses. */
-class ComputeDAGNode : public Object {
- public:
-  /*!
-   * \brief Input and output tensors.
-   * This is used as the input of `tvm.lower` or `tvm.build`.
-   */
-  Array<te::Tensor> tensors;
-  /*! \brief All related operations in topo order. */
-  Array<te::Operation> ops;
-  /*! \brief Number of total float operations for this ComputeDAG. */
-  double flop_ct;
-  /*! \brief The initial state without any transform steps. */
-  State init_state;
-  // TODO(merrymercy): Add more analyses later.
-
-  void VisitAttrs(tvm::AttrVisitor* v) {
-    v->Visit("tensors", &tensors);
-    v->Visit("ops", &ops);
-    v->Visit("flop_ct", &flop_ct);
-    v->Visit("init_state", &init_state);
-  }
-
-  static constexpr const char* _type_key = "auto_scheduler.ComputeDAG";
-  TVM_DECLARE_FINAL_OBJECT_INFO(ComputeDAGNode, Object);
-};
-
-/*!
- * \brief Managed reference to ComputeDAGNode.
- * \sa ComputeDAGNode
- */
-class ComputeDAG : public ObjectRef {
- public:
-  /*! \brief The constructor.
-   * \param tensors `te::Tensor`s for a compute declaration.
-   */
-  explicit ComputeDAG(Array<te::Tensor> tensors);
-
-  /*!
-   * \brief Apply the history transform steps from a State to get a TVM schedule.
-   * \param transform_steps Transform steps of a state.
-   * \param stages A pointer to a `te::Stage` Array, default to be nullptr.
-   * Pass a valid pointer if these information needs to be used outside this function.
-   * \param stage_to_axes A pointer to a StageToAxesMap, default to be nullptr.
-   * Pass a valid pointer if these information needs to be used outside this function.
-   * \return A `te.schedule` and the a list of `te.Tensor` to be used in `tvm.lower` or `tvm.build`.
-   */
-  std::pair<te::Schedule, Array<te::Tensor>> ApplySteps(
-      const Array<Step>& transform_steps, Array<te::Stage>* stages = nullptr,
-      StageToAxesMap* stage_to_axes = nullptr) const;
-
-  /*!
-   * \brief Print transform steps as equivalent python schedule API.
-   * This can be used for debugging.
-   * \param transform_steps Transform steps of a state.
-   * \return The Python schedule code.
-   */
-  String PrintStepsAsPython(const Array<Step>& transform_steps) const;
-
-  /*!
-   * \brief Fill the correct bound information for a given state by calling ir_pass::InferBound.
-   * The states can lose complete bound information after some transform steps (e.g., compute_at).
-   * We can call this function to infer and fill all the bound information.
-   * This function calls TVM InferBound pass internally to get the bound.
-   * The returned state of this function is guaranteed to have complete iterator extent information.
-   * \param state The state to.
-   * \return The State after inferbound.
-   */
-  State InferBound(const State& state) const;
-
-  TVM_DEFINE_OBJECT_REF_METHODS(ComputeDAG, ObjectRef, ComputeDAGNode);
-  TVM_DEFINE_OBJECT_REF_COW_METHOD(ComputeDAGNode);
-};
-
-}  // namespace auto_scheduler
-}  // namespace tvm
-
-#endif  // TVM_AUTO_SCHEDULER_COMPUTE_DAG_H_
index bfe5478..35d899a 100644 (file)
  * see auto_scheduler/loop_state.h for more explanation.
  */
 
-#include "loop_state.h"
-
+#include <tvm/auto_scheduler/loop_state.h>
+#include <tvm/auto_scheduler/transform_step.h>
 #include <tvm/runtime/registry.h>
 #include <tvm/te/operation.h>
 
 #include <utility>
 
-#include "transform_step.h"
 #include "utils.h"
 
 namespace tvm {
index 6198f60..e249f7b 100644 (file)
@@ -22,8 +22,7 @@
  * \brief Distributed measurement infrastructure to measure the runtime costs of tensor programs.
  */
 
-#include "measure.h"
-
+#include <tvm/auto_scheduler/measure.h>
 #include <tvm/runtime/registry.h>
 
 #include <algorithm>
index 39f9ad8..02f244f 100644 (file)
  * \brief Json serialization format for dumping and loading tuning records.
  */
 
-#include "measure_record.h"
-
 #include <dmlc/json.h>
+#include <tvm/auto_scheduler/loop_state.h>
+#include <tvm/auto_scheduler/measure_record.h>
+#include <tvm/auto_scheduler/transform_step.h>
 #include <tvm/runtime/registry.h>
 
 #include <fstream>
@@ -33,8 +34,6 @@
 #include <utility>
 #include <vector>
 
-#include "loop_state.h"
-#include "transform_step.h"
 #include "utils.h"
 
 // Json serialization handler for MeasureInput, MeasureResult
index 1886203..4c85af4 100644 (file)
 
 #include "empty_policy.h"
 
+#include <tvm/auto_scheduler/measure.h>
 #include <tvm/runtime/registry.h>
 
-#include "../measure.h"
-
 namespace tvm {
 namespace auto_scheduler {
 
index 4ccc9c1..ef7d38d 100644 (file)
@@ -26,8 +26,8 @@
 #ifndef TVM_AUTO_SCHEDULER_SEARCH_POLICY_EMPTY_POLICY_H_
 #define TVM_AUTO_SCHEDULER_SEARCH_POLICY_EMPTY_POLICY_H_
 
-#include "../loop_state.h"
-#include "search_policy.h"
+#include <tvm/auto_scheduler/loop_state.h>
+#include <tvm/auto_scheduler/search_policy.h>
 
 namespace tvm {
 namespace auto_scheduler {
index fba5155..764b0a7 100644 (file)
@@ -22,8 +22,7 @@
  * \brief The base class of search policies.
  */
 
-#include "search_policy.h"
-
+#include <tvm/auto_scheduler/search_policy.h>
 #include <tvm/runtime/registry.h>
 
 namespace tvm {
index 912a310..9cc21f2 100644 (file)
@@ -22,8 +22,7 @@
  * \brief Meta information and hardware parameters for a search task.
  */
 
-#include "search_task.h"
-
+#include <tvm/auto_scheduler/search_task.h>
 #include <tvm/runtime/registry.h>
 #include <tvm/runtime/threading_backend.h>
 
index 6c672a5..b1b3b94 100644 (file)
 
 /*!
  * \file auto_scheduler/transform_step.cc
- * \brief Transformation steps. For each schedule primitive, there is a corresponding transform
- * step.
+ * \brief Transformation steps. These steps are used to manipulate the LoopState.
+ *        They are similar to the schedule primitives in te::Stage.
  */
 
-#include "transform_step.h"
-
+#include <tvm/auto_scheduler/loop_state.h>
+#include <tvm/auto_scheduler/transform_step.h>
 #include <tvm/runtime/registry.h>
 #include <tvm/te/operation.h>
 
@@ -32,7 +32,6 @@
 #include <utility>
 #include <vector>
 
-#include "loop_state.h"
 #include "utils.h"
 
 namespace tvm {
@@ -80,6 +79,7 @@ Step StepReadFromRecord(dmlc::JSONReader* reader) {
 }
 
 void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag) {
+  // We need this runtime dispatcher because different steps have different function signatures
   if (auto ps = step.as<AnnotationStepNode>()) {
     ps->ApplyToState(state);
   } else if (auto ps = step.as<FuseStepNode>()) {
@@ -101,6 +101,7 @@ void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag) {
 
 void StepApplyToSchedule(const Step& step, Array<te::Stage>* stages,
                          StageToAxesMap* stage_to_axes) {
+  // We need this runtime dispatcher because different steps have different function signatures
   if (auto ps = step.as<AnnotationStepNode>()) {
     ps->ApplyToSchedule(stages, stage_to_axes);
   } else if (auto ps = step.as<FuseStepNode>()) {
@@ -122,6 +123,7 @@ void StepApplyToSchedule(const Step& step, Array<te::Stage>* stages,
 
 String StepPrintAsPythonAPI(const Step& step, Array<te::Stage>* stages,
                             StageToAxesMap* stage_to_axes) {
+  // We need this runtime dispatcher because different steps have different function signatures
   if (auto ps = step.as<AnnotationStepNode>()) {
     return ps->PrintAsPythonAPI(stages, stage_to_axes);
   } else if (auto ps = step.as<FuseStepNode>()) {
@@ -142,7 +144,7 @@ String StepPrintAsPythonAPI(const Step& step, Array<te::Stage>* stages,
   return "";
 }
 
-/********** Primitives working on single stage **********/
+/********** Steps working on single stage **********/
 
 /********** Annotation **********/
 AnnotationStep::AnnotationStep(int stage_id, int iter_id, IteratorAnnotation ann) {
@@ -741,7 +743,7 @@ String SplitStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
   return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, lengths, inner_to_outer);
 }
 
-/********** Primitives working on multiple stages **********/
+/********** Steps working on multiple stages **********/
 
 /********** Compute At **********/
 ComputeAtStep::ComputeAtStep(int stage_id, int target_stage_id, int target_iter_id) {
index de800da..da5032e 100644 (file)
@@ -128,6 +128,24 @@ inline std::vector<int> IntArrayToVector(
   return out;
 }
 
+/*! \brief Return whether two int arrays are elementwise-equal */
+inline bool IntArrayEqual(const Array<PrimExpr>& arr1, const Array<PrimExpr>& arr2) {
+  if (arr1.size() != arr2.size()) {
+    return false;
+  }
+
+  for (size_t i = 0; i < arr1.size(); ++i) {
+    auto int1 = arr1[i].as<IntImmNode>();
+    auto int2 = arr2[i].as<IntImmNode>();
+    CHECK(int1 != nullptr);
+    CHECK(int2 != nullptr);
+    if (int1->value != int2->value) {
+      return false;
+    }
+  }
+  return true;
+}
+
 /********** Utilities for TVM Containers / ByteArray **********/
 /*! \brief Compute mean of a FloatImm array */
 inline double FloatArrayMean(const Array<PrimExpr>& float_array) {
diff --git a/tests/cpp/auto_scheduler_test.cc b/tests/cpp/auto_scheduler_test.cc
new file mode 100644 (file)
index 0000000..8526605
--- /dev/null
@@ -0,0 +1,178 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <dmlc/logging.h>
+#include <gtest/gtest.h>
+#include <topi/nn.h>
+#include <tvm/auto_scheduler/compute_dag.h>
+#include <tvm/runtime/container.h>
+#include <tvm/te/operation.h>
+
+#include <unordered_set>
+
+// Compute declaration for test
+tvm::Array<tvm::te::Tensor> conv2d_nchw_bn_relu_func(int N, int H, int W, int CI, int CO,
+                                                     int kernel_size, int strides, int padding,
+                                                     int dilation = 1) {
+  using namespace tvm;
+  using namespace tvm::te;
+
+  Tensor data = placeholder({N, CI, H, W}, DataType::Float(32), "Data");
+  Tensor kernel = placeholder({CO, CI, kernel_size, kernel_size}, DataType::Float(32), "Kernel");
+  Tensor bias = placeholder({CO, 1, 1}, DataType::Float(32), "Bias");
+  Tensor bn_scale = placeholder({CO, 1, 1}, DataType::Float(32), "Bn_scale");
+  Tensor bn_offset = placeholder({CO, 1, 1}, DataType::Float(32), "Bn_offset");
+
+  int OH = (H + 2 * padding - (kernel_size - 1) * dilation - 1) / strides + 1;
+  int OW = (W + 2 * padding - (kernel_size - 1) * dilation - 1) / strides + 1;
+
+  const auto& conv = topi::conv2d_nchw(data, kernel, padding, padding, strides, strides);
+  CHECK(conv->shape[2].as<IntImmNode>()->value == OH);
+  CHECK(conv->shape[3].as<IntImmNode>()->value == OW);
+
+  const auto& bias_add = compute(
+      {N, CO, OH, OW}, [&](Var i, Var j, Var k, Var l) { return conv[i][j][k][l] + bias[j][0][0]; },
+      "Bias_add");
+  const auto& bn_mul = compute(
+      {N, CO, OH, OW},
+      [&](Var i, Var j, Var k, Var l) { return bias_add[i][j][k][l] * bn_scale[j][0][0]; },
+      "Bn_mul");
+  const auto& bn_add = compute(
+      {N, CO, OH, OW},
+      [&](Var i, Var j, Var k, Var l) { return bn_mul[i][j][k][l] + bn_offset[j][0][0]; },
+      "Bn_add");
+  const auto& out = topi::relu<float>(bn_add);
+
+  return {data, kernel, bias, bn_scale, bn_offset, out};
+}
+
+using namespace tvm::auto_scheduler;
+
+// Test Access Analyzer
+TEST(ComputeDAG, AccessAnalyzer) {
+  const auto& tensors = conv2d_nchw_bn_relu_func(1, 224, 224, 3, 64, 7, 2, 3);
+  const auto& dag = tvm::auto_scheduler::ComputeDAG(tensors);
+  State s0 = dag->init_state;
+
+  int data = 0, padding = 1, kernel = 2, conv = 3, bias = 4, bias_add = 5;
+  int bn_scale = 6, bn_mul = 7, bn_offset = 8, bn_add = 9, relu = 10;
+
+  std::set<int> needs_multi_level_tiling = {conv};
+  for (size_t stage_id = 0; stage_id < dag->ops.size(); stage_id++) {
+    if (needs_multi_level_tiling.count(stage_id)) {
+      CHECK(dag->access_analyzer.NeedsMultiLevelTiling(dag->ops[stage_id]));
+    } else {
+      CHECK(!dag->access_analyzer.NeedsMultiLevelTiling(dag->ops[stage_id]));
+    }
+  }
+
+  std::set<int> is_simple_access = {data,     padding, kernel,    bias,   bias_add,
+                                    bn_scale, bn_mul,  bn_offset, bn_add, relu};
+  for (size_t stage_id = 0; stage_id < dag->ops.size(); stage_id++) {
+    if (is_simple_access.count(stage_id)) {
+      CHECK(dag->access_analyzer.IsSimpleAccess(dag->ops[stage_id]));
+    } else {
+      CHECK(!dag->access_analyzer.IsSimpleAccess(dag->ops[stage_id]));
+    }
+  }
+
+  std::set<int> is_strictly_inlinable = {bias_add, bn_mul, bn_add, relu};
+  for (size_t stage_id = 0; stage_id < dag->ops.size(); stage_id++) {
+    if (is_strictly_inlinable.count(stage_id)) {
+      CHECK(dag->access_analyzer.IsStrictInlineable(dag->ops[stage_id]));
+    } else {
+      CHECK(!dag->access_analyzer.IsStrictInlineable(dag->ops[stage_id]));
+    }
+  }
+
+  std::set<int> is_output = {relu};
+  for (size_t stage_id = 0; stage_id < dag->ops.size(); stage_id++) {
+    if (is_output.count(stage_id)) {
+      CHECK(dag->access_analyzer.IsOutput(dag->ops[stage_id]));
+    } else {
+      CHECK(!dag->access_analyzer.IsOutput(dag->ops[stage_id]));
+    }
+  }
+
+  CHECK_EQ(dag->access_analyzer.GetNumCommonOuterIterator(dag->ops[conv], dag->ops[bias_add]), 4);
+  CHECK_EQ(dag->access_analyzer.GetNumCommonOuterIterator(dag->ops[conv], dag->ops[relu]), 4);
+  CHECK_EQ(dag->access_analyzer.GetNumCommonOuterIterator(dag->ops[data], dag->ops[relu]), 1);
+
+  CHECK(dag->access_analyzer.ElementWiseMatch(dag->ops[conv], dag->ops[bias_add]));
+  CHECK(dag->access_analyzer.ElementWiseMatch(dag->ops[conv], dag->ops[relu]));
+  CHECK(!dag->access_analyzer.ElementWiseMatch(dag->ops[data], dag->ops[padding]));
+
+  std::unordered_set<tvm::te::Operation, tvm::ObjectHash, tvm::ObjectEqual> op_set;
+  {
+    std::vector<std::pair<int, int>> consumer_list = {
+        {data, padding},     {padding, conv},    {kernel, conv},     {conv, bias_add},
+        {bias, bias_add},    {bias_add, bn_mul}, {bn_scale, bn_mul}, {bn_mul, bn_add},
+        {bn_offset, bn_add}, {bn_add, relu}};
+    for (const auto& pair : consumer_list) {
+      op_set = dag->access_analyzer.GetConsumers(s0, s0->stages[pair.first]->op);
+      CHECK_EQ(op_set.size(), 1);
+      CHECK_EQ((*op_set.begin()), s0->stages[pair.second]->op);
+    }
+    std::vector<std::pair<int, std::vector<int>>> producer_list = {{padding, {data}},
+                                                                   {conv, {padding, kernel}},
+                                                                   {bias_add, {conv, bias}},
+                                                                   {bn_mul, {bias_add, bn_scale}},
+                                                                   {bn_add, {bn_mul, bn_offset}},
+                                                                   {relu, {bn_add}}};
+    for (const auto& pair : producer_list) {
+      op_set = dag->access_analyzer.GetProducers(s0, s0->stages[pair.first]->op);
+      CHECK_EQ(op_set.size(), pair.second.size());
+      for (const auto& target : pair.second) {
+        CHECK(op_set.count(s0->stages[target]->op));
+      }
+    }
+  }
+
+  s0.compute_inline(bn_add);
+  s0.compute_inline(bn_mul);
+  s0.compute_inline(bias_add);
+  s0.compute_inline(padding);
+  {
+    std::vector<std::pair<int, int>> consumer_list = {{data, conv}, {kernel, conv}, {conv, relu}};
+    for (const auto& pair : consumer_list) {
+      op_set = dag->access_analyzer.GetConsumers(s0, s0->stages[pair.first]->op);
+      CHECK_EQ(op_set.size(), 1);
+      CHECK_EQ((*op_set.begin()), s0->stages[pair.second]->op);
+    }
+    std::vector<std::pair<int, std::vector<int>>> producer_list = {{padding, {data}},
+                                                                   {conv, {padding, kernel}},
+                                                                   {bias_add, {conv, bias}},
+                                                                   {bn_mul, {bias_add, bn_scale}},
+                                                                   {bn_add, {bn_mul, bn_offset}},
+                                                                   {relu, {bn_add}}};
+    for (const auto& pair : producer_list) {
+      op_set = dag->access_analyzer.GetDirectProducers(s0->stages[pair.first]->op);
+      CHECK_EQ(op_set.size(), pair.second.size());
+      for (const auto& target : pair.second) {
+        CHECK(op_set.count(s0->stages[target]->op));
+      }
+    }
+  }
+}
+
+int main(int argc, char** argv) {
+  testing::InitGoogleTest(&argc, argv);
+  testing::FLAGS_gtest_death_test_style = "threadsafe";
+  return RUN_ALL_TESTS();
+}
index 4934463..d9c24b9 100644 (file)
 
 """Test ComputeDAG (replay, infer bound)"""
 
-import tvm
+import tvm, topi
 from tvm import auto_scheduler, te
 
-from test_auto_scheduler_common import get_tiled_matmul
+from test_auto_scheduler_common import get_tiled_matmul, matmul_auto_scheduler_test
 
 
 def test_apply_steps():
@@ -36,8 +36,19 @@ def test_infer_bound():
 
 
 def test_estimate_flop():
-    dag, s = get_tiled_matmul()
-    assert abs(dag.flop_ct - 2 * 512 ** 3) < 0.5
+    N = 512
+    A, B, C = matmul_auto_scheduler_test(N, N, N)
+    dag = auto_scheduler.ComputeDAG([A, B, C])
+    assert abs(dag.flop_ct - 2 * N ** 3) < 0.5
+
+    D = topi.nn.relu(C)
+    dag = auto_scheduler.ComputeDAG([A, B, D])
+    assert abs(dag.flop_ct - 2 * N ** 3 - N * N) < 0.5
+
+    # should not count the comparison operations in padding
+    D = topi.nn.pad(C, [1, 1])
+    dag = auto_scheduler.ComputeDAG([A, B, D])
+    assert abs(dag.flop_ct - 2 * N ** 3) < 0.5
 
 
 if __name__ == "__main__":