From c9cbd040be2a76029ceb56eadaa6ded96945dfe0 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sat, 25 Jul 2020 05:07:17 -0700 Subject: [PATCH] [Ansor][AutoTVM v2.0] Phase 1: Access Analyzer (#6103) * add access analyzer * add test cases * move header files and polish comments * fix lint * update * fix lint * address comments * fix lint --- .../tvm}/auto_scheduler/auto_schedule.h | 28 +- include/tvm/auto_scheduler/compute_dag.h | 248 ++++++++++++ {src => include/tvm}/auto_scheduler/loop_state.h | 59 +-- {src => include/tvm}/auto_scheduler/measure.h | 34 +- .../tvm}/auto_scheduler/measure_record.h | 28 +- .../tvm/auto_scheduler}/search_policy.h | 37 +- {src => include/tvm}/auto_scheduler/search_task.h | 5 +- .../tvm}/auto_scheduler/transform_step.h | 13 +- python/tvm/auto_scheduler/auto_schedule.py | 5 +- python/tvm/auto_scheduler/workload_registry.py | 2 +- src/auto_scheduler/auto_schedule.cc | 3 +- src/auto_scheduler/compute_dag.cc | 442 ++++++++++++++++++++- src/auto_scheduler/compute_dag.h | 124 ------ src/auto_scheduler/loop_state.cc | 5 +- src/auto_scheduler/measure.cc | 3 +- src/auto_scheduler/measure_record.cc | 7 +- src/auto_scheduler/search_policy/empty_policy.cc | 3 +- src/auto_scheduler/search_policy/empty_policy.h | 4 +- src/auto_scheduler/search_policy/search_policy.cc | 3 +- src/auto_scheduler/search_task.cc | 3 +- src/auto_scheduler/transform_step.cc | 16 +- src/auto_scheduler/utils.h | 18 + tests/cpp/auto_scheduler_test.cc | 178 +++++++++ .../unittest/test_auto_scheduler_compute_dag.py | 19 +- 24 files changed, 1015 insertions(+), 272 deletions(-) rename {src => include/tvm}/auto_scheduler/auto_schedule.h (81%) create mode 100644 include/tvm/auto_scheduler/compute_dag.h rename {src => include/tvm}/auto_scheduler/loop_state.h (91%) rename {src => include/tvm}/auto_scheduler/measure.h (93%) rename {src => include/tvm}/auto_scheduler/measure_record.h (83%) rename {src/auto_scheduler/search_policy => include/tvm/auto_scheduler}/search_policy.h (79%) 
rename {src => include/tvm}/auto_scheduler/search_task.h (97%) rename {src => include/tvm}/auto_scheduler/transform_step.h (98%) delete mode 100644 src/auto_scheduler/compute_dag.h create mode 100644 tests/cpp/auto_scheduler_test.cc diff --git a/src/auto_scheduler/auto_schedule.h b/include/tvm/auto_scheduler/auto_schedule.h similarity index 81% rename from src/auto_scheduler/auto_schedule.h rename to include/tvm/auto_scheduler/auto_schedule.h index 55c6992..8477966 100644 --- a/src/auto_scheduler/auto_schedule.h +++ b/include/tvm/auto_scheduler/auto_schedule.h @@ -18,19 +18,17 @@ */ /*! - * \file auto_scheduler/auto_schedule.h - * \brief The user interface of the TVM Auto-scheduler. This is the entry structure to get - * schedule search requirements from upper level (Python API), and returns a high performance - * schedule after search process. + * \file tvm/auto_scheduler/auto_schedule.h + * \brief The user interface of the auto scheduler. */ #ifndef TVM_AUTO_SCHEDULER_AUTO_SCHEDULE_H_ #define TVM_AUTO_SCHEDULER_AUTO_SCHEDULE_H_ -#include +#include +#include -#include "measure.h" -#include "search_policy/search_policy.h" +#include namespace tvm { namespace auto_scheduler { @@ -38,9 +36,9 @@ namespace auto_scheduler { /*! \brief Tuning and measurement options. */ class TuningOptionsNode : public Object { public: - /*! \brief Number of total measurement trials. */ + /*! \brief The number of total measurement trials. */ int num_measure_trials; - /*! \brief Stops early the tuning if no improvement after n measurements. */ + /*! \brief Stops the tuning early if no improvement after n measurements. */ int early_stopping; /*! \brief The number of programs to be measured at each search round. */ int num_measures_per_round; @@ -51,7 +49,7 @@ class TuningOptionsNode : public Object { int verbose; /*! \brief ProgramBuilder which builds the program */ ProgramBuilder builder; - /*! \brief ProgramRunner which runs the program and measure time costs */ + /*! 
\brief ProgramRunner which runs the program and measures time costs */ ProgramRunner runner; /*! \brief MeasureCallback functions to be called after each measure batch */ Optional> measure_callbacks; @@ -81,8 +79,8 @@ class TuningOptions : public ObjectRef { public: /*! * \brief The constructor - * \param num_measure_trials Number of total measurement trials. - * \param early_stopping Stops early the tuning if no improvement after n measurements. + * \param num_measure_trials The number of total measurement trials. + * \param early_stopping Stops the tuning early if no improvement after n measurements. * \param num_measures_per_round The number of programs to be measured at each search round. * \param verbose Verbosity level. 0 for silent, 1 to output information during schedule * search. @@ -100,11 +98,11 @@ class TuningOptions : public ObjectRef { }; /*! - * \brief Auto schedule search for a given compute declaration. + * \brief Run schedule search for a given compute declaration. * \param task The search task of the compute declaration. - * \param search_policy The search policy to be used for schedule search. + * \param search_policy The search policy to be used. * \param tuning_options Tuning and measurement options. - * \return A `te::schedule` and the a Array of `te::Tensor` to be used in `tvm.lower` or + * \return A `te::schedule` and an Array of `te::Tensor` to be used in `tvm.lower` or * `tvm.build`. */ TVM_DLL std::pair> AutoSchedule(SearchTask task, diff --git a/include/tvm/auto_scheduler/compute_dag.h b/include/tvm/auto_scheduler/compute_dag.h new file mode 100644 index 0000000..71652fd --- /dev/null +++ b/include/tvm/auto_scheduler/compute_dag.h @@ -0,0 +1,248 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/auto_scheduler/compute_dag.h + * \brief The auto-scheduler's computational graph and related program analyses. + * + * We convert a compute declaration described by `tvm.compute` (could be a single operator or a + * subgraph) to a ComputeDAG. It keeps the input/output tensors of the compute declaration, + * a list of all operations in the DAG as well as static analysis results for the DAG (e.g. the + * total float operation count, consumer/producer relations of each operation stage, whether an + * operation stage should be tiled/compute inlined ...). These analyses can help the search policy + * to make decisions during search process. + * ComputeDAG is also responsible for the interaction between TVM Auto-scheduler `LoopState` and + * TVM schedule (e.g. applying the `LoopState` transform steps to TVM schedule, providing + * `LoopState` with extra information got from TVM schedule ...). + */ + +#ifndef TVM_AUTO_SCHEDULER_COMPUTE_DAG_H_ +#define TVM_AUTO_SCHEDULER_COMPUTE_DAG_H_ + +#include +#include +#include + +#include +#include +#include +#include + +namespace tvm { +namespace auto_scheduler { + +/*! \brief Static analysis result for a ComputeDAG */ +class AccessAnalyzerNode : public Object { + public: + template + using OperationMap = std::unordered_map; + + /*! \brief Map an operation to all operations it reads from. 
+ * For each operation pair, use a two-dimensional array for multiple multi-dimensional accesses. + * The inner vector represents the indices of multi-dimensional access.*/ + OperationMap>>> read_from; + /*! \brief Map an operation to all operations it is read by. + * For each operation pair, use a two-dimensional array for multiple multi-dimensional accesses. + * The inner vector represents the indices of multi-dimensional access.*/ + OperationMap>>> read_by; + /*! \brief Store the number of common outer iterators for operation pairs that have + * read-write relations. */ + OperationMap> num_common_outer_iterators; + /*! \brief Store whether the operation is an op with only simple access. + * (e.g., injective, broadcast and elementwise ops without reduction) */ + OperationMap is_simple_access; + /*! \brief Store whether the operation is strictly-inlineable + * (e.g., injective, broadcast and elementwise without reduction, branch or expensive operations) + */ + OperationMap is_strict_inlineable; + /*! \brief Store whether the operation needs multi-level tiling + * (e.g., computation-intensive ops with data reuse opportunity like matmul, conv2d) */ + OperationMap needs_multi_level_tiling; + /*! \brief Store whether the operation is an output operation */ + OperationMap is_output; + /*! \brief Store the topological order of operations */ + Array ops_topo_order; + + static constexpr const char* _type_key = "auto_scheduler.AccessAnalyzer"; + TVM_DECLARE_FINAL_OBJECT_INFO(AccessAnalyzerNode, Object); +}; + +/*! + * \brief Managed reference to AccessAnalyzerNode. + * \sa AccessAnalyzerNode + */ +class AccessAnalyzer : public ObjectRef { + public: + explicit AccessAnalyzer(const Array& tensors); + + /*! + * \brief Return whether this operation is an injective operation + * (e.g., injective, broadcast and elementwise ops without reduction) + * \param op The operation + */ + TVM_DLL bool IsSimpleAccess(const te::Operation& op) const; + + /*! 
+ * \brief Return whether this operation is strictly inlinable + * (e.g., injective, broadcast and elementwise without reduction, branch or expensive operations) + * \param op The operation + */ + TVM_DLL bool IsStrictInlineable(const te::Operation& op) const; + + /*! + * \brief Return whether this operation needs multi-level tiling + * (e.g., computation-intensive ops with data reuse opportunity like matmul, conv2d) + * \param op The operation + */ + TVM_DLL bool NeedsMultiLevelTiling(const te::Operation& op) const; + + /*! + * \brief Return whether this operation is an output op + * \param op The operation + */ + TVM_DLL bool IsOutput(const te::Operation& op) const; + + /*! + * \brief Get all consumers of an operation + * \param state The current loop state + * \param op The operation + * \return The set of consumers + * \note This function propagates the relation for inlined ops + */ + TVM_DLL std::unordered_set GetConsumers( + const State& state, const te::Operation& op) const; + + /*! + * \brief Get all producers of an operation + * \param state The current loop state + * \param op The operation + * \return The set of producers + * \note This function propagates the relation for inlined ops + */ + TVM_DLL std::unordered_set GetProducers( + const State& state, const te::Operation& op) const; + + /*! + * \brief Get all direct producers of an operation + * \param op The operation + * \return The set of direct producers + * \note This function DOES NOT propagate the relation for inlined ops + */ + TVM_DLL std::unordered_set GetDirectProducers( + const te::Operation& op) const; + + /*! + * \brief Get the number of common outer iterators. + * \param op The operation + * \param target_op The target operation + * \note This function propagates the relation for chains with multiple ops. + */ + TVM_DLL int GetNumCommonOuterIterator(const te::Operation& op, + const te::Operation& target_op) const; + + /*! 
+ * \brief Return whether two operations are elementwise-matched + * (e.g. conv2d and relu are elementwise matched) + * \note This function propagates the relation for chains with multiple ops. + */ + TVM_DLL bool ElementWiseMatch(const te::Operation& op, const te::Operation& target_op) const; + + TVM_DEFINE_OBJECT_REF_METHODS(AccessAnalyzer, ObjectRef, AccessAnalyzerNode); +}; + +/*! \brief The TVM Auto-scheduler computational graph and related program analyses. */ +class ComputeDAGNode : public Object { + public: + /*! + * \brief Input and output tensors. + * This is used as the input of `tvm.lower` or `tvm.build`. + */ + Array tensors; + /*! \brief All related operations in topo order. */ + Array ops; + /*! \brief The number of total float operations for this ComputeDAG. */ + double flop_ct; + /*! \brief The initial state without any transform steps. */ + State init_state; + /*! \brief The static read-write access analyzer */ + AccessAnalyzer access_analyzer; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("tensors", &tensors); + v->Visit("ops", &ops); + v->Visit("flop_ct", &flop_ct); + v->Visit("init_state", &init_state); + } + + static constexpr const char* _type_key = "auto_scheduler.ComputeDAG"; + TVM_DECLARE_FINAL_OBJECT_INFO(ComputeDAGNode, Object); +}; + +/*! + * \brief Managed reference to ComputeDAGNode. + * \sa ComputeDAGNode + */ +class ComputeDAG : public ObjectRef { + public: + /*! \brief The constructor. + * \param tensors `te::Tensor`s for a compute declaration. + */ + TVM_DLL explicit ComputeDAG(Array tensors); + + /*! + * \brief Apply the history transform steps to get a TVM schedule. + * \param transform_steps Transform steps of a state. + * \param stages The list of stages after applying the steps. + * Pass a valid pointer if this information needs to be used outside this function. + * \param stage_to_axes The map that stores all axes for one stage. + * Pass a valid pointer if this information needs to be used outside this function. 
+ * \return A `te.schedule` and an Array of `te.Tensor` to be used in `tvm.lower` + * or `tvm.build`. + */ + std::pair> ApplySteps( + const Array& transform_steps, Array* stages = nullptr, + StageToAxesMap* stage_to_axes = nullptr) const; + + /*! + * \brief Print transform steps as equivalent python schedule API. + * This can be used for debugging. + * \param transform_steps Transform steps of a state. + * \return The Python schedule code. + */ + String PrintStepsAsPython(const Array& transform_steps) const; + + /*! + * \brief Fill the correct bound information for a given state by calling ir_pass::InferBound. + * The states can lose complete bound information after some transform steps (e.g., compute_at). + * We can call this function to infer and fill all the bound information. + * This function calls TVM InferBound pass internally to get the bound. + * The returned state of this function is guaranteed to have complete bound information. + * \param state The input state. + * \return The State with complete bound information + */ + State InferBound(const State& state) const; + + TVM_DEFINE_OBJECT_REF_METHODS(ComputeDAG, ObjectRef, ComputeDAGNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(ComputeDAGNode); +}; + +} // namespace auto_scheduler +} // namespace tvm + +#endif // TVM_AUTO_SCHEDULER_COMPUTE_DAG_H_ diff --git a/src/auto_scheduler/loop_state.h b/include/tvm/auto_scheduler/loop_state.h similarity index 91% rename from src/auto_scheduler/loop_state.h rename to include/tvm/auto_scheduler/loop_state.h index 4d6477b..4e9cb9b 100644 --- a/src/auto_scheduler/loop_state.h +++ b/include/tvm/auto_scheduler/loop_state.h @@ -48,6 +48,8 @@ #ifndef TVM_AUTO_SCHEDULER_LOOP_STATE_H_ #define TVM_AUTO_SCHEDULER_LOOP_STATE_H_ +#include +#include #include #include @@ -55,8 +57,6 @@ #include #include -#include "transform_step.h" - namespace tvm { namespace auto_scheduler { @@ -159,10 +159,16 @@ using IterKey = std::pair; */ class AttachMapNode : public Object { public: + struct 
IterKeyHash { + std::size_t operator()(const IterKey& k) const { + return ::dmlc::HashCombine(std::hash()(k.first), std::hash()(k.second)); + } + }; + /*! \brief A Map to store the mapping of stage to its attached iterator. */ std::unordered_map stage_to_attach_iter; /*! \brief A Map to store the mapping of iterator to the stage attached to it. */ - std::unordered_map> iter_to_attached_stages; + std::unordered_map, IterKeyHash> iter_to_attached_stages; static constexpr const char* _type_key = "auto_scheduler.AttachMap"; TVM_DECLARE_FINAL_OBJECT_INFO(AttachMapNode, Object); @@ -291,14 +297,14 @@ class State : public ObjectRef { * this input. * \return The iterator result after binded. */ - Iterator bind(int stage_id, const Iterator& it, IteratorAnnotation thread_type); + TVM_DLL Iterator bind(int stage_id, const Iterator& it, IteratorAnnotation thread_type); /*! * \brief Schedule primitive corresponds to te.parallel. * \param stage_id The index of the stage to be paralleled. * \param it The iterator to be paralleled. * \return The iterator result after parallel. */ - Iterator parallel(int stage_id, const Iterator& it); + TVM_DLL Iterator parallel(int stage_id, const Iterator& it); /*! * \brief Schedule primitive corresponds to te.unroll. * \param stage_id The index of the stage to be unrolled. @@ -307,14 +313,14 @@ class State : public ObjectRef { * skipped. * \return The iterator result after unrolled. */ - Iterator unroll(int stage_id, const Iterator& it, int max_unroll = -1); + TVM_DLL Iterator unroll(int stage_id, const Iterator& it, int max_unroll = -1); /*! * \brief Schedule primitive corresponds to te.vectorize. * \param stage_id The index of the stage to be vectorized. * \param it The iterator to be vectorized. * \return The iterator result after vectorize. */ - Iterator vectorize(int stage_id, const Iterator& it); + TVM_DLL Iterator vectorize(int stage_id, const Iterator& it); /*! * \brief Schedule primitive corresponds to te.fuse. 
* \param stage_id The index of the stage to be fused. @@ -323,13 +329,13 @@ class State : public ObjectRef { * \note If the iterators to be fused have stages attached at them(by compute_at), the fused * result will become the new attach point. */ - Iterator fuse(int stage_id, const Array& iters); + TVM_DLL Iterator fuse(int stage_id, const Array& iters); /*! * \brief Schedule primitive corresponds to te.reorder. * \param stage_id The index of the stage to be reordered. * \param order The expected iterator order. */ - void reorder(int stage_id, const Array& order); + TVM_DLL void reorder(int stage_id, const Array& order); /*! * \brief Schedule primitive corresponds to te.split. * \param stage_id The index of the stage to be split. @@ -340,8 +346,9 @@ class State : public ObjectRef { * \note If we do split on an iterator which has stages attached at it(by compute_at), the inner * most iterator of split results will become the new attach point. */ - Array split(int stage_id, const Iterator& it, const Array>& lengths, - bool inner_to_outer = true); + TVM_DLL Array split(int stage_id, const Iterator& it, + const Array>& lengths, + bool inner_to_outer = true); /********** Step APIs working on multiple stages **********/ @@ -355,12 +362,12 @@ class State : public ObjectRef { * bound for the newly created iterators. * Call ComputeDAG::InferBound on the updated state to get the complete bound information. */ - void compute_at(int stage_id, int target_stage_id, const Iterator& target_iter); + TVM_DLL void compute_at(int stage_id, int target_stage_id, const Iterator& target_iter); /*! * \brief Schedule primitive corresponds to te.compute_inline. * \param stage_id The index of the stage to be reordered. */ - void compute_inline(int stage_id); + TVM_DLL void compute_inline(int stage_id); /*! * \brief Schedule primitive corresponds to te.compute_root. * \param stage_id The index of the stage to be reordered. 
@@ -369,7 +376,7 @@ class State : public ObjectRef { * bound for the newly created iterators. * Call ComputeDAG::InferBound on the updated state to get the complete bound information. */ - void compute_root(int stage_id); + TVM_DLL void compute_root(int stage_id); TVM_DEFINE_OBJECT_REF_METHODS(State, ObjectRef, StateNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(StateNode); @@ -381,21 +388,11 @@ class State : public ObjectRef { // Hash and equal function for State namespace std { -/*! \brief The hash function for auto_scheduler::State. */ -template <> -struct hash<::tvm::auto_scheduler::State> { - std::size_t operator()(const ::tvm::auto_scheduler::State& state) const { - return tvm::runtime::ObjectHash()(state.ToStr()); - } -}; - /*! * \brief The equal_to function for auto_scheduler::State. - * We use the schedule result(its string format) of a state to check if two states are `euqal`. - * Equal States: 1. the transform steps are totally the same; 2. even with different steps, two - * states may still result in a same schedule. e.g. To split a axis with extent 512 to 3 parts - * [8, 16, 4]. We can split from inner to outter by factors [16, 4], while we can get a same result - * to split from outter to inner by factors [8, 16]) + * This function checks the equality by looking at the lowered string format of states. + * If two states with different transform history have the same lowered string format, + * they will be considered equal. */ template <> struct equal_to<::tvm::auto_scheduler::State> { @@ -405,6 +402,14 @@ struct equal_to<::tvm::auto_scheduler::State> { } }; +/*! \brief The hash function for auto_scheduler::State. 
*/ +template <> +struct hash<::tvm::auto_scheduler::State> { + std::size_t operator()(const ::tvm::auto_scheduler::State& state) const { + return tvm::runtime::ObjectHash()(state.ToStr()); + } +}; + } // namespace std #endif // TVM_AUTO_SCHEDULER_LOOP_STATE_H_ diff --git a/src/auto_scheduler/measure.h b/include/tvm/auto_scheduler/measure.h similarity index 93% rename from src/auto_scheduler/measure.h rename to include/tvm/auto_scheduler/measure.h index 02d6e87..83d7c8d 100644 --- a/src/auto_scheduler/measure.h +++ b/include/tvm/auto_scheduler/measure.h @@ -23,26 +23,28 @@ * These functions are responsible for building the tvm module, uploading it to remote devices, * recording the running time costs, and checking the correctness of the output. * - * We separate the measurement into two steps: build and run. + * The measurement is separated into two steps: build and run. * A builder builds the executable binary files and a runner runs the binary files to get the * measurement results. The flow of data structures is * * `ProgramBuilder` `ProgramRunner` * `MeasureInput` -----------------> `BuildResult` ----------------> `MeasureResult` * - * We implement these in python to utilize python's multiprocessing and error handling. + * The core functions are implemented in python to utilize python's multiprocessing + * and error handling (see also `python/tvm/auto_scheduler/measure.py`). + * This c++ file is just a wrapper for the python functions. */ #ifndef TVM_AUTO_SCHEDULER_MEASURE_H_ #define TVM_AUTO_SCHEDULER_MEASURE_H_ +#include +#include + #include #include #include -#include "loop_state.h" -#include "search_task.h" - namespace tvm { namespace auto_scheduler { @@ -209,7 +211,7 @@ class MeasureCallbackNode : public Object { public: /*! * \brief Callback function that will be called on measurement input/result pairs - * after measurement. + * after each measurement batch. * \param policy The current search policy. * \param inputs An Array of MeasureInput. 
* \param results An Array of MeasureResult. @@ -234,7 +236,7 @@ class MeasureCallback : public ObjectRef { /*! \brief ProgramBuilder that builds the programs */ class ProgramBuilderNode : public Object { public: - /*! \brief The number of tasks to run in parallel */ + /*! \brief The number of build processes to run in parallel */ int n_parallel; /*! \brief Timeout of a build */ int timeout; @@ -323,15 +325,15 @@ class LocalBuilder : public ProgramBuilder { * \brief The constructor. * \param timeout The timeout limit (in second) for each build thread. * This will be used in a wrapper of the multiprocessing.Process.join(). - * \param n_parallel Number of threads used to build in parallel. - * \param build_func The name of registered build function. + * \param n_parallel The number of threads used to build in parallel. + * \param build_func The name of the registered build function. */ LocalBuilder(int timeout, int n_parallel, const String& build_func); TVM_DEFINE_OBJECT_REF_METHODS(LocalBuilder, ProgramBuilder, LocalBuilderNode); }; -/*! \brief LocalRunner that uses local CPU/GPU to measures the time cost of programs */ +/*! \brief LocalRunner that uses local CPU/GPU to measure the time cost of programs */ class LocalRunnerNode : public ProgramRunnerNode { public: Array Run(const Array& inputs, @@ -373,13 +375,12 @@ class RPCRunnerNode : public ProgramRunnerNode { String key; /*! \brief The host address of the RPC Tracker. */ String host; - /*! \brief The port of RPC Tracker. */ + /*! \brief The port of the RPC Tracker. */ int port; /*! \brief The priority of this run request, larger is more prior. */ int priority; /*! \brief The number of tasks run in parallel. */ int n_parallel; - /*! \brief The number of times to run the generated code for taking average. */ Array Run(const Array& inputs, const Array& build_results, int verbose) final; @@ -395,10 +396,11 @@ class RPCRunnerNode : public ProgramRunnerNode { class RPCRunner : public ProgramRunner { public: /*! 
- * \brief The constructor. + * \brief The constructor. See the corresponding class in python/tvm/auto_scheduler/measure.py + * for more detailed parameter explanation. * \param key The key of the device registered in the RPC tracker. * \param host The host address of the RPC Tracker. - * \param prot The port of RPC Tracker. + * \param port The port of RPC Tracker. * \param priority The priority of this run request, larger is more prior. * \param n_parallel The number of tasks run in parallel. * \param timeout Timeout of a run. @@ -415,7 +417,7 @@ class RPCRunner : public ProgramRunner { /*! * \brief Measurer that measures the time costs of tvm programs - * This class combines ProgramBuilder and ProgramRunner, and provides a simpler API */ + * This class combines ProgramBuilder and ProgramRunner and provides a simpler API */ class ProgramMeasurerNode : public Object { public: /*! \brief Measured programs counter. */ @@ -483,7 +485,7 @@ class ProgramMeasurer : public ObjectRef { * \param callbacks MeasureCallback to be called after each measure batch. * \param verbose Verbosity level. 0 for silent, 1 to output information during program * measuring. - * \param max_continous_error The number of max continuous error. + * \param max_continous_error The maximum number of continuous errors allowed. */ ProgramMeasurer(ProgramBuilder builder, ProgramRunner runner, Optional> callbacks, int verbose, diff --git a/src/auto_scheduler/measure_record.h b/include/tvm/auto_scheduler/measure_record.h similarity index 83% rename from src/auto_scheduler/measure_record.h rename to include/tvm/auto_scheduler/measure_record.h index 1cfeab0..fa8fe2b 100644 --- a/src/auto_scheduler/measure_record.h +++ b/include/tvm/auto_scheduler/measure_record.h @@ -18,26 +18,26 @@ */ /*! - * \file auto_scheduler/measure_record.h - * \brief Json serialization format for dumping and loading tuning records. 
+ * \file tvm/auto_scheduler/measure_record.h + * \brief Json serialization format for dumping and loading measurement records. */ #ifndef TVM_AUTO_SCHEDULER_MEASURE_RECORD_H_ #define TVM_AUTO_SCHEDULER_MEASURE_RECORD_H_ +#include + #include #include #include -#include "measure.h" - namespace tvm { namespace auto_scheduler { /*! \brief Callback for logging the input and results of measurements to file */ class RecordToFileNode : public MeasureCallbackNode { public: - /*! \brief File name for this callback to write log to. */ + /*! \brief The name of output file. */ String filename; void Callback(const SearchPolicy& policy, const Array& inputs, @@ -55,7 +55,7 @@ class RecordToFile : public MeasureCallback { public: /*! * \brief The constructor. - * \param filename File name for this callback to write log. + * \param filename The name of output file */ explicit RecordToFile(String filename); @@ -65,7 +65,7 @@ class RecordToFile : public MeasureCallback { /*! \brief Log reader to load step logs from a file.*/ class RecordReaderNode : public Object { public: - /*! \brief File name for this reader to load log from. */ + /*! \brief The name of input file. */ String filename; /*! \brief The reading file stream. */ std::ifstream infile; @@ -92,7 +92,7 @@ class RecordReaderNode : public Object { TVM_DECLARE_FINAL_OBJECT_INFO(RecordReaderNode, Object); private: - /*! \brief A string object to store the next line. */ + /*! \brief A string storing the current line. */ std::string cur_line_; }; @@ -104,7 +104,7 @@ class RecordReader : public ObjectRef { public: /*! * \brief The constructor. - * \param filename File name for this callback to write log. + * \param filename The name of input file */ explicit RecordReader(String filename); @@ -112,7 +112,7 @@ class RecordReader : public ObjectRef { }; /*! - * \brief Write measure records to an output stream. + * \brief Append measure records to an output stream. * \param os A pointer to a output stream. 
* \param inputs The MeasureInputs to be written. * \param results The MeasureResults to be written. @@ -122,10 +122,10 @@ void WriteMeasureRecords(std::ostream* os, const Array& inputs, /*! * \brief Read one measure record from a string. - * \param str The record string to be extract. - * \param inp A pointer to a MeasureInputNode, this is used as output. - * \param res A pointer to a MeasureResultNode, this is used as output. - * \param log_version A pointer to a log version string. + * \param str The record string to be parsed. + * \param inp A pointer to a MeasureInputNode used to store the return value. + * \param res A pointer to a MeasureResultNode used to store the return value. + * \param log_version A pointer to a string used to store the log version. */ void ReadMeasureRecord(const std::string& str, MeasureInputNode* inp, MeasureResultNode* res, std::string* log_version); diff --git a/src/auto_scheduler/search_policy/search_policy.h b/include/tvm/auto_scheduler/search_policy.h similarity index 79% rename from src/auto_scheduler/search_policy/search_policy.h rename to include/tvm/auto_scheduler/search_policy.h index 70f94ad..457aca1 100644 --- a/src/auto_scheduler/search_policy/search_policy.h +++ b/include/tvm/auto_scheduler/search_policy.h @@ -18,11 +18,11 @@ */ /*! - * \file auto_scheduler/search_policy/search_policy.h + * \file tvm/auto_scheduler/search_policy.h * \brief The base class of search policies, including the abstract definition of search policy and * other supporting data structures. * - * The basic schedule search process for TVM Auto-scheduler is design to be: + * The basic schedule search process for the auto-scheduler is designed to be: * `Program sampling` -> `Performance Tuning`. * * In `Program sampling`, we use some predefined precise or heuristic rules to generate several @@ -31,7 +31,7 @@ * * Candidate schedules are measured against the specific hardware target. * - * \note Adding a new search policy. 
+ * \note How to add a new search policy. * In design, there's no need for users to implement their own search policy, our formal search * policy(will be brought later) should be enough to cover most use cases. Meanwhile, a custom rule * mechanism will be provided to enable user-defined template search to serve the same functionality @@ -48,16 +48,15 @@ * during the search process. */ -#ifndef TVM_AUTO_SCHEDULER_SEARCH_POLICY_SEARCH_POLICY_H_ -#define TVM_AUTO_SCHEDULER_SEARCH_POLICY_SEARCH_POLICY_H_ +#ifndef TVM_AUTO_SCHEDULER_SEARCH_POLICY_H_ +#define TVM_AUTO_SCHEDULER_SEARCH_POLICY_H_ +#include #include #include #include -#include "../search_task.h" - namespace tvm { namespace auto_scheduler { @@ -110,16 +109,16 @@ class SearchPolicyNode : public Object { /*! * \brief Do schedule search for a task. Takes the SearchTask as input and returns the best state - * get during the search process. - * \param task The SearchTask or workload key for the computation declaration - * \param num_measure_trials Total schedules to be tried during this search. - * \param early_stopping Early stop if no better schedule is found. - * \param num_measures_per_round Max measure batch in one search round. + * found during the search. + * \param task The SearchTask for the computation declaration + * \param num_measure_trials The number of total measurement trials. + * \param early_stopping Stops the tuning early if no improvement after n measurements. + * \param num_measures_per_round The number of programs to be measured at each search round. * \param verbose Verbose level. 0 for silent, 1 to output information during schedule * search. - * \param measurer A ProgramMeasurer which packs ProgramBuilder & ProgramRunner inside. + * \param measurer A ProgramMeasurer to build and measure programs * \param pre_search_callbacks SearchCallback to be called before schedule search. - * \return The best state get. + * \return The best state found. 
*/ virtual State Search(SearchTask task, int num_measure_trials, int early_stopping, int num_measures_per_round, int verbose, ProgramMeasurer measurer, @@ -137,16 +136,12 @@ class SearchPolicyNode : public Object { protected: /*! * \brief The set of already measured states. - * During the schedule search process, we may generate `equal states` through different search - * branches. (Equal States: 1. the transform steps are totally the same; 2. even with different - * steps, two states may still result in a same schedule. e.g. To split a axis with extent 512 - * to 3 parts [8, 16, 4]. We can split from inner to outter by factors [16, 4], while we can - * get a same result to split from outter to inner by factors [8, 16]) * We store the string format of a state for redundancy check. This is used to make sure a * measured state will never be measured again. */ std::unordered_set measured_states_set_; - /*! \brief The array of already measured states. This can be used in evolutionary search. */ + /*! \brief The array of already measured states. + * The good states can be used as the initial population in evolutionary search. */ std::vector measured_states_vector_; /*! 
\brief The throughputs of already measured states */ std::vector measured_states_throughputs_; @@ -164,4 +159,4 @@ class SearchPolicy : public ObjectRef { } // namespace auto_scheduler } // namespace tvm -#endif // TVM_AUTO_SCHEDULER_SEARCH_POLICY_SEARCH_POLICY_H_ +#endif // TVM_AUTO_SCHEDULER_SEARCH_POLICY_H_ diff --git a/src/auto_scheduler/search_task.h b/include/tvm/auto_scheduler/search_task.h similarity index 97% rename from src/auto_scheduler/search_task.h rename to include/tvm/auto_scheduler/search_task.h index ca31350..85154b5 100644 --- a/src/auto_scheduler/search_task.h +++ b/include/tvm/auto_scheduler/search_task.h @@ -25,16 +25,15 @@ #ifndef TVM_AUTO_SCHEDULER_SEARCH_TASK_H_ #define TVM_AUTO_SCHEDULER_SEARCH_TASK_H_ +#include #include -#include "compute_dag.h" - namespace tvm { namespace auto_scheduler { class HardwareParams; -/*! \brief The parameters of target hardware used to guide the search process of SearchPolicy. */ +/*! \brief The parameters of target hardware used to guide the SearchPolicy. */ class HardwareParamsNode : public Object { public: /*! \brief The number of cores. */ diff --git a/src/auto_scheduler/transform_step.h b/include/tvm/auto_scheduler/transform_step.h similarity index 98% rename from src/auto_scheduler/transform_step.h rename to include/tvm/auto_scheduler/transform_step.h index ce3ca50..b23137a 100644 --- a/src/auto_scheduler/transform_step.h +++ b/include/tvm/auto_scheduler/transform_step.h @@ -19,10 +19,10 @@ /*! * \file auto_scheduler/transform_step.h - * \brief Transformation steps. For each schedule primitive, there is a corresponding transform - * step. + * \brief Transformation steps. These steps are used to manipulate the LoopState. + * They are similar to the schedule primitives in te::Stage. * - * \note To add a new transform step: + * \note How to add a new transform step: * Take fuse step for example: * 1. 
Define class `FuseStepNode`, `FuseStep` in `transform_steps.h`, and implement its first * construction function `FuseStep::FuseStep()` in `transform_steps.cc`. @@ -51,8 +51,6 @@ #include #include -#include "utils.h" - namespace tvm { namespace auto_scheduler { @@ -187,7 +185,6 @@ Step StepReadFromRecord(dmlc::JSONReader* reader); * \param step The step to be applied to State. * \param state A mutable pointer to State. * \param dag The original ComputeDAG of this state. - * \return The iterator result after annotate. */ void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag); @@ -209,7 +206,7 @@ void StepApplyToSchedule(const Step& step, Array* stages, StageToAxes String StepPrintAsPythonAPI(const Step& step, Array* stages, StageToAxesMap* stage_to_axes); -/********** Primitives working on single stage **********/ +/********** Steps working on single stage **********/ /*! * \brief Annotation step that corresponds to vectorize, parallel, unroll and thread binding. @@ -478,7 +475,7 @@ class SplitStep : public Step { TVM_DEFINE_OBJECT_REF_METHODS(SplitStep, Step, SplitStepNode); }; -/********** Primitives working on multiple stages **********/ +/********** Steps working on multiple stages **********/ /*! \brief Compute at step that corresponds to te::Stage::compute_at */ class ComputeAtStepNode : public StepNode { diff --git a/python/tvm/auto_scheduler/auto_schedule.py b/python/tvm/auto_scheduler/auto_schedule.py index d45dbf8..52aa62b 100644 --- a/python/tvm/auto_scheduler/auto_schedule.py +++ b/python/tvm/auto_scheduler/auto_schedule.py @@ -57,7 +57,7 @@ class HardwareParams(Object): @tvm._ffi.register_object("auto_scheduler.SearchTask") class SearchTask(Object): - """ The computation information and hardware parameters for a specific schedule search task. + """ The computation information and hardware parameters for a schedule search task. 
Parameters ---------- @@ -158,9 +158,6 @@ class TuningOptions(Object): def auto_schedule(task, search_policy='default', tuning_options=None): """ Do auto scheduling for a computation declaration. - The task parameter can be a `string` as workload_key, or directly - passing a `SearchTask` as input. - Parameters ---------- task : SearchTask diff --git a/python/tvm/auto_scheduler/workload_registry.py b/python/tvm/auto_scheduler/workload_registry.py index 36c2037..045720a 100644 --- a/python/tvm/auto_scheduler/workload_registry.py +++ b/python/tvm/auto_scheduler/workload_registry.py @@ -95,7 +95,7 @@ def make_workload_key(func, args): Returns ------- - workload_key : Str + workload_key : str The workload key of the function. """ global WORKLOAD_FUNC_REGISTRY diff --git a/src/auto_scheduler/auto_schedule.cc b/src/auto_scheduler/auto_schedule.cc index b515b3a..c537ca7 100644 --- a/src/auto_scheduler/auto_schedule.cc +++ b/src/auto_scheduler/auto_schedule.cc @@ -24,8 +24,7 @@ * schedule after search process. */ -#include "auto_schedule.h" - +#include #include namespace tvm { diff --git a/src/auto_scheduler/compute_dag.cc b/src/auto_scheduler/compute_dag.cc index d81dff6..68d1bb4 100644 --- a/src/auto_scheduler/compute_dag.cc +++ b/src/auto_scheduler/compute_dag.cc @@ -22,12 +22,13 @@ * \brief Compute declaration graph and its related analysis tools. */ -#include "compute_dag.h" - +#include +#include #include #include #include #include +#include #include #include @@ -36,7 +37,7 @@ #include #include -#include "loop_state.h" +#include "../arith/pattern_match.h" #include "utils.h" namespace tvm { @@ -44,6 +45,10 @@ namespace auto_scheduler { using namespace tvm::tir; +template +using OperationMap = AccessAnalyzerNode::OperationMap; +using OperationSet = std::unordered_set; + TVM_REGISTER_NODE_TYPE(ComputeDAGNode); // Topo-sort ops from tensors according to their read-write relations. 
@@ -114,7 +119,416 @@ Array TopoSortOps(const Array& tensors) { return ops; } -// Estimate number of float operations in an expression +// Extract all tensor accesses in an expr +class ReadAccessExtractor : public StmtExprVisitor { + public: + void Extract(PrimExpr expr) { this->VisitExpr(expr); } + + void VisitExpr_(const CallNode* op) final { + if (op->op.same_as(builtin::if_then_else())) { + has_branch = true; + } + StmtExprVisitor::VisitExpr_(op); + } + + void VisitExpr_(const ProducerLoadNode* op) final { + read_access[Downcast(op->producer)->op].emplace_back(op->indices.begin(), + op->indices.end()); + StmtExprVisitor::VisitExpr_(op); + } + + void VisitStmt_(const IfThenElseNode* op) final { + has_branch = true; + StmtExprVisitor::VisitStmt_(op); + } + + void VisitExpr_(const SelectNode* op) final { + has_branch = true; + StmtExprVisitor::VisitExpr_(op); + } + + // All read accesses to all operations + // The innermost vector stores mulit-dimentional indices. + // The middle vector stores possible multiple accesses + OperationMap>> read_access; + // Whether this expression has branch + bool has_branch{false}; +}; + +// Returns whether the expr equals to the var with an optional const shift +bool IsConstShiftEqual(const Var& var, const PrimExpr& expr) { + arith::PVar x; + arith::PVar c; + + if (((x + c).Match(expr) || (x - c).Match(expr) || (c + x).Match(expr) || x.Match(expr)) && + x.Eval().same_as(var)) { + return true; + } + return false; +} + +// Return whether the access to an operation is a simple access +// (i.e. all index is just a variable with an optional constant shift) +// For example, A[i][j], A[i+1][j] are simple accesses but A[i][j+i] is not. 
+bool IsSimpleAccess(const te::Operation& op, const std::vector& indices, + bool* axis_missing, bool* axis_duplicated, bool* same_order) { + auto cop = op.as(); + if (cop == nullptr) { + return false; + } + + std::vector index_to_var_idx; + std::vector var_idx_ct(cop->axis.size(), 0); + + for (const auto& expr : indices) { + if (!is_const_int(expr)) { + bool found = false; + for (size_t i = 0; i < cop->axis.size(); ++i) { + if (IsConstShiftEqual(cop->axis[i]->var, expr)) { + index_to_var_idx.push_back(i); + var_idx_ct[i]++; + found = true; + break; + } + } + if (!found) { + return false; + } + } + } + + *axis_missing = false; // Some axes are missing + *axis_duplicated = false; // Some axes appear more than once + *same_order = true; // The axis order is the same as op->axis + for (int ct : var_idx_ct) { + if (ct == 0) { + *axis_missing = true; + } else if (ct > 1) { + *axis_duplicated = true; + } + } + for (size_t i = 1; i < index_to_var_idx.size(); ++i) { + if (index_to_var_idx[i] < index_to_var_idx[i - 1]) { + *same_order = false; + break; + } + } + + return true; +} + +// Gather all VarNodes in an expr +void GatherVars(const PrimExpr& expr, std::unordered_set* vars) { + PostOrderVisit(expr, [&vars](const ObjectRef& node) { + if (const VarNode* op = node.as()) { + vars->insert(op); + } + }); +} + +// Check whether an expr has expensive operations (e.g. 
exp) +bool HasExpensiveOp(const PrimExpr& expr) { + bool found = false; + PostOrderVisit(expr, [&found](const ObjectRef& node) { + if (const CallNode* op = node.as()) { + if (op->op.as()->name == "tir.exp") { + found = true; + } + } + }); + return found; +} + +AccessAnalyzer::AccessAnalyzer(const Array& tensors) { + auto node = make_object(); + OperationMap has_branch; + + // Get all ops in topological order + node->ops_topo_order = TopoSortOps(tensors); + + arith::Analyzer analyzer; + + // Build read & write access map + for (const auto& op : node->ops_topo_order) { + if (op->IsInstance()) { + node->read_from[op] = OperationMap>>(); + } else if (auto cop = op.as()) { + ReadAccessExtractor extractor; + for (const auto& exp : cop->body) { + extractor.Extract(exp); + } + + // read_by and read_from map + for (const auto& iter : extractor.read_access) { + std::vector>& accesses = node->read_by[iter.first][op]; + accesses.insert(accesses.begin(), iter.second.begin(), iter.second.end()); + } + + node->read_from[op] = std::move(extractor.read_access); + has_branch[op] = extractor.has_branch; + + // compute number of common outer iterators + for (const auto& pair : node->read_from[op]) { + const te::Operation& producer = pair.first; + const std::vector>& access_list = pair.second; + const Array& output_shape = op->output_shape(0); + const Array& producer_shape = producer->output_shape(0); + + int n_common; + for (n_common = 0; + n_common < static_cast(std::min(output_shape.size(), producer_shape.size())); + n_common++) { + if (!is_zero(analyzer.Simplify(output_shape[n_common] - producer_shape[n_common]))) { + break; + } + + bool injective = true; + for (const auto& access : access_list) { + if (!IsConstShiftEqual(cop->axis[n_common]->var, access[n_common])) { + injective = false; + break; + } + } + + if (!injective) { + break; + } + } + + node->num_common_outer_iterators[op][producer] = n_common; + node->num_common_outer_iterators[producer][op] = n_common; + } + } else { + 
LOG(FATAL) << "Invalid op: " << op; + } + } + + // Do some static analysis on ComputeOps + for (const auto& op : node->ops_topo_order) { + if (op->IsInstance()) { + node->is_simple_access[op] = true; + node->needs_multi_level_tiling[op] = false; + node->is_strict_inlineable[op] = false; + node->is_output[op] = false; + } else if (auto cop = op.as()) { + // check whether this op is element-wise and strict-inlineable + bool is_simple_access = true; + bool is_strict_inlineable = true; + + bool axis_missing, axis_duplicated, same_order; + for (const auto& pair : node->read_from[op]) { + const std::vector>& access_list = pair.second; + for (const auto& access : access_list) { + if (!auto_scheduler::IsSimpleAccess(op, access, &axis_missing, &axis_duplicated, + &same_order)) { + is_simple_access = false; + is_strict_inlineable = false; + break; + } + if (!same_order || axis_duplicated) { + // do not strictly inline transpose + is_strict_inlineable = false; + } + } + if (!is_simple_access) { + break; + } + } + + // don't strictly inline expensive op (e.g. 
exp) + bool has_expensive_op = false; + for (const auto& expr : cop->body) { + has_expensive_op |= HasExpensiveOp(expr); + } + if (has_expensive_op || has_branch[op]) { + is_strict_inlineable = false; + } + + node->is_simple_access[op] = is_simple_access; + node->is_strict_inlineable[op] = is_strict_inlineable; + + // check whether the op needs multi-level tiling + bool needs_multi_level_tiling = false; + int n_missing = 0; + + for (const auto& pair : node->read_from[op]) { + const std::vector>& access_list = pair.second; + std::unordered_set vars; + for (const std::vector& access : access_list) { + for (const PrimExpr& expr : access) { + GatherVars(expr, &vars); + } + } + + for (const auto& axis : cop->axis) { + if (GetIntImm(axis->dom->extent) > 1 && vars.count(axis->var.get()) == 0) { + n_missing++; + break; + } + } + + if (n_missing >= 2 || (n_missing >= 1 && !cop->reduce_axis.empty())) { + needs_multi_level_tiling = true; + break; + } + } + + node->needs_multi_level_tiling[op] = needs_multi_level_tiling; + + // check whether the op is output + node->is_output[op] = node->read_by[op].empty(); + } else { + LOG(FATAL) << "Invalid op" << op; + } + } + + data_ = std::move(node); +} + +bool AccessAnalyzer::NeedsMultiLevelTiling(const te::Operation& op) const { + return operator->()->needs_multi_level_tiling.at(op); +} + +bool AccessAnalyzer::IsOutput(const te::Operation& op) const { + return operator->()->is_output.at(op); +} + +bool AccessAnalyzer::IsSimpleAccess(const te::Operation& op) const { + return operator->()->is_simple_access.at(op); +} + +bool AccessAnalyzer::IsStrictInlineable(const te::Operation& op) const { + return operator->()->is_strict_inlineable.at(op); +} + +OperationSet AccessAnalyzer::GetConsumers(const State& state, const te::Operation& op) const { + OperationSet inlined_ops; + for (const auto& stage : state->stages) { + if (stage->compute_at == ComputeAtKind::kInlined) { + inlined_ops.insert(stage->op); + } + } + + OperationSet consumers; + 
std::function collect; + collect = [this, &collect, &inlined_ops, &consumers](const te::Operation& op) { + for (const auto& iter : operator->()->read_by.at(op)) { + if (inlined_ops.count(iter.first)) { + collect(iter.first); + } else { + consumers.insert(iter.first); + } + } + }; + + collect(op); + return consumers; +} + +OperationSet AccessAnalyzer::GetDirectProducers(const te::Operation& op) const { + OperationSet producers; + for (const auto& iter : operator->()->read_from.at(op)) { + producers.insert(iter.first); + } + return producers; +} + +OperationSet AccessAnalyzer::GetProducers(const State& state, const te::Operation& op) const { + OperationSet inlined_ops; + for (const auto& stage : state->stages) { + if (stage->compute_at == ComputeAtKind::kInlined) { + inlined_ops.insert(stage->op); + } + } + + OperationSet producers; + std::function collect; + collect = [this, &collect, &inlined_ops, &producers](const te::Operation& op) { + for (const auto& iter : operator->()->read_from.at(op)) { + if (inlined_ops.count(iter.first)) { + collect(iter.first); + } else { + producers.insert(iter.first); + } + } + }; + + collect(op); + return producers; +} + +int AccessAnalyzer::GetNumCommonOuterIterator(const te::Operation& op, + const te::Operation& target_op) const { + int ret = INT32_MAX; + bool meet = false; + + std::function traverse; + traverse = [this, &traverse, &target_op, &ret, &meet](const te::Operation& cur_op, int cur_num) { + if (cur_op == target_op) { + ret = std::min(ret, cur_num); + meet = true; + return; + } + + for (const auto& iter : operator->()->read_by.at(cur_op)) { + traverse( + iter.first, + std::min(cur_num, operator->()->num_common_outer_iterators.at(cur_op).at(iter.first))); + } + }; + + traverse(op, op->output_shape(0).size()); + return meet ? 
ret : 0; +} + +bool AccessAnalyzer::ElementWiseMatch(const te::Operation& op, + const te::Operation& target_op) const { + te::Operation cur_op = op; + while (cur_op != target_op) { + const AccessAnalyzerNode::OperationMap>>& map = + operator->()->read_by.at(cur_op); + + if (map.size() != 1) { + return false; + } + te::Operation next_op = map.begin()->first; + + // Check condition 1: They have the same output size + auto p_cur = cur_op.as(); + auto p_next = next_op.as(); + if (p_cur == nullptr || p_next == nullptr) { + return false; + } + + Array output_shape = p_cur->output_shape(0); + for (int i = 1; i < p_cur->num_outputs(); ++i) { + if (!IntArrayEqual(p_cur->output_shape(i), output_shape)) { + return false; + } + } + for (int i = 0; i < p_next->num_outputs(); ++i) { + if (!IntArrayEqual(p_next->output_shape(i), output_shape)) { + return false; + } + } + + // Check condition 2: The read is elementwise + const std::vector> reads = map.begin()->second; + bool is_simple_access, axis_missing, axis_duplicated, same_order; + for (const auto& read : reads) { + is_simple_access = auto_scheduler::IsSimpleAccess(next_op, read, &axis_missing, + &axis_duplicated, &same_order); + if (!is_simple_access || axis_missing || axis_duplicated || !same_order) { + return false; + } + } + + cur_op = std::move(next_op); + } + return true; +} + +// Estimate the number of float operations in an expression class FlopEstimator : public ExprFunctor { public: double EstimateFlop(const Array& ops) { @@ -126,6 +540,7 @@ class FlopEstimator : public ExprFunctor { fail_ = true; break; } + cur_type_code_ = pop->output_dtype(0).code(); double op_per_element = 0; for (const auto& x : pop->body) { op_per_element += VisitExpr(x); @@ -171,10 +586,17 @@ class FlopEstimator : public ExprFunctor { std::max(VisitExpr(op->true_value), VisitExpr(op->false_value)); } -#define VisitBinary(Node) \ - double VisitExpr_(const Node* op) final { return 1.0 + VisitExpr(op->a) + VisitExpr(op->b); } -#define 
VisitUnary(Node) \ - double VisitExpr_(const Node* op) final { return 1.0 + VisitExpr(op->a); } +#define VisitBinary(Node) \ + double VisitExpr_(const Node* op) final { \ + double base = op->dtype.code() == cur_type_code_ ? 1.0 : 0.0; \ + return base + VisitExpr(op->a) + VisitExpr(op->b); \ + } + +#define VisitUnary(Node) \ + double VisitExpr_(const Node* op) final { \ + double base = op->dtype.code() == cur_type_code_ ? 1.0 : 0.0; \ + return base + VisitExpr(op->a); \ + } VisitBinary(AddNode); VisitBinary(SubNode); @@ -210,12 +632,14 @@ class FlopEstimator : public ExprFunctor { private: bool fail_{false}; + int cur_type_code_; }; ComputeDAG::ComputeDAG(Array tensors) { auto node = make_object(); node->tensors = std::move(tensors); - node->ops = TopoSortOps(node->tensors); + node->access_analyzer = AccessAnalyzer(node->tensors); + node->ops = node->access_analyzer->ops_topo_order; node->flop_ct = FlopEstimator().EstimateFlop(node->ops); node->init_state = State(node->ops); data_ = std::move(node); diff --git a/src/auto_scheduler/compute_dag.h b/src/auto_scheduler/compute_dag.h deleted file mode 100644 index 2417d72..0000000 --- a/src/auto_scheduler/compute_dag.h +++ /dev/null @@ -1,124 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. 
See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file auto_scheduler/compute_dag.h - * \brief The TVM Auto-scheduler computational graph and related program analyses. - * - * We convert a compute declaration described by `tvm.compute` (could be a single operator or a - * subgraph) to a ComputeDAG. It keeps the input/output tensors of the compute declaration, - * a list of all operations in the DAG as well as static analysis results for the DAG (e.g. the - * total float operation count, consumer/producer relations of each operation stage, whether an - * operation stage should be tiled/compute inlined ...). These analyses can help the search policy - * to make decisions during search process. - * ComputeDAG is also responsible for the interaction between TVM Auto-scheduler `LoopState` and - * TVM schedule (e.g. applying the `LoopState` transform steps to TVM schedule, providing - * `LoopState` with extra information got from TVM schedule ...). - */ - -#ifndef TVM_AUTO_SCHEDULER_COMPUTE_DAG_H_ -#define TVM_AUTO_SCHEDULER_COMPUTE_DAG_H_ - -#include - -#include - -#include "loop_state.h" - -namespace tvm { -namespace auto_scheduler { - -/*! \brief The TVM Auto-scheduler computational graph and related program analyses. */ -class ComputeDAGNode : public Object { - public: - /*! - * \brief Input and output tensors. - * This is used as the input of `tvm.lower` or `tvm.build`. - */ - Array tensors; - /*! \brief All related operations in topo order. */ - Array ops; - /*! \brief Number of total float operations for this ComputeDAG. */ - double flop_ct; - /*! \brief The initial state without any transform steps. */ - State init_state; - // TODO(merrymercy): Add more analyses later. 
- - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("tensors", &tensors); - v->Visit("ops", &ops); - v->Visit("flop_ct", &flop_ct); - v->Visit("init_state", &init_state); - } - - static constexpr const char* _type_key = "auto_scheduler.ComputeDAG"; - TVM_DECLARE_FINAL_OBJECT_INFO(ComputeDAGNode, Object); -}; - -/*! - * \brief Managed reference to ComputeDAGNode. - * \sa ComputeDAGNode - */ -class ComputeDAG : public ObjectRef { - public: - /*! \brief The constructor. - * \param tensors `te::Tensor`s for a compute declaration. - */ - explicit ComputeDAG(Array tensors); - - /*! - * \brief Apply the history transform steps from a State to get a TVM schedule. - * \param transform_steps Transform steps of a state. - * \param stages A pointer to a `te::Stage` Array, default to be nullptr. - * Pass a valid pointer if these information needs to be used outside this function. - * \param stage_to_axes A pointer to a StageToAxesMap, default to be nullptr. - * Pass a valid pointer if these information needs to be used outside this function. - * \return A `te.schedule` and the a list of `te.Tensor` to be used in `tvm.lower` or `tvm.build`. - */ - std::pair> ApplySteps( - const Array& transform_steps, Array* stages = nullptr, - StageToAxesMap* stage_to_axes = nullptr) const; - - /*! - * \brief Print transform steps as equivalent python schedule API. - * This can be used for debugging. - * \param transform_steps Transform steps of a state. - * \return The Python schedule code. - */ - String PrintStepsAsPython(const Array& transform_steps) const; - - /*! - * \brief Fill the correct bound information for a given state by calling ir_pass::InferBound. - * The states can lose complete bound information after some transform steps (e.g., compute_at). - * We can call this function to infer and fill all the bound information. - * This function calls TVM InferBound pass internally to get the bound. 
- * The returned state of this function is guaranteed to have complete iterator extent information. - * \param state The state to. - * \return The State after inferbound. - */ - State InferBound(const State& state) const; - - TVM_DEFINE_OBJECT_REF_METHODS(ComputeDAG, ObjectRef, ComputeDAGNode); - TVM_DEFINE_OBJECT_REF_COW_METHOD(ComputeDAGNode); -}; - -} // namespace auto_scheduler -} // namespace tvm - -#endif // TVM_AUTO_SCHEDULER_COMPUTE_DAG_H_ diff --git a/src/auto_scheduler/loop_state.cc b/src/auto_scheduler/loop_state.cc index bfe5478..35d899a 100644 --- a/src/auto_scheduler/loop_state.cc +++ b/src/auto_scheduler/loop_state.cc @@ -23,14 +23,13 @@ * see auto_scheduler/loop_state.h for more explanation. */ -#include "loop_state.h" - +#include +#include #include #include #include -#include "transform_step.h" #include "utils.h" namespace tvm { diff --git a/src/auto_scheduler/measure.cc b/src/auto_scheduler/measure.cc index 6198f60..e249f7b 100644 --- a/src/auto_scheduler/measure.cc +++ b/src/auto_scheduler/measure.cc @@ -22,8 +22,7 @@ * \brief Distributed measurement infrastructure to measure the runtime costs of tensor programs. */ -#include "measure.h" - +#include #include #include diff --git a/src/auto_scheduler/measure_record.cc b/src/auto_scheduler/measure_record.cc index 39f9ad8..02f244f 100644 --- a/src/auto_scheduler/measure_record.cc +++ b/src/auto_scheduler/measure_record.cc @@ -22,9 +22,10 @@ * \brief Json serialization format for dumping and loading tuning records. 
*/ -#include "measure_record.h" - #include +#include +#include +#include #include #include @@ -33,8 +34,6 @@ #include #include -#include "loop_state.h" -#include "transform_step.h" #include "utils.h" // Json serialization handler for MeasureInput, MeasureResult diff --git a/src/auto_scheduler/search_policy/empty_policy.cc b/src/auto_scheduler/search_policy/empty_policy.cc index 1886203..4c85af4 100644 --- a/src/auto_scheduler/search_policy/empty_policy.cc +++ b/src/auto_scheduler/search_policy/empty_policy.cc @@ -24,10 +24,9 @@ #include "empty_policy.h" +#include #include -#include "../measure.h" - namespace tvm { namespace auto_scheduler { diff --git a/src/auto_scheduler/search_policy/empty_policy.h b/src/auto_scheduler/search_policy/empty_policy.h index 4ccc9c1..ef7d38d 100644 --- a/src/auto_scheduler/search_policy/empty_policy.h +++ b/src/auto_scheduler/search_policy/empty_policy.h @@ -26,8 +26,8 @@ #ifndef TVM_AUTO_SCHEDULER_SEARCH_POLICY_EMPTY_POLICY_H_ #define TVM_AUTO_SCHEDULER_SEARCH_POLICY_EMPTY_POLICY_H_ -#include "../loop_state.h" -#include "search_policy.h" +#include +#include namespace tvm { namespace auto_scheduler { diff --git a/src/auto_scheduler/search_policy/search_policy.cc b/src/auto_scheduler/search_policy/search_policy.cc index fba5155..764b0a7 100644 --- a/src/auto_scheduler/search_policy/search_policy.cc +++ b/src/auto_scheduler/search_policy/search_policy.cc @@ -22,8 +22,7 @@ * \brief The base class of search policies. */ -#include "search_policy.h" - +#include #include namespace tvm { diff --git a/src/auto_scheduler/search_task.cc b/src/auto_scheduler/search_task.cc index 912a310..9cc21f2 100644 --- a/src/auto_scheduler/search_task.cc +++ b/src/auto_scheduler/search_task.cc @@ -22,8 +22,7 @@ * \brief Meta information and hardware parameters for a search task. 
*/ -#include "search_task.h" - +#include #include #include diff --git a/src/auto_scheduler/transform_step.cc b/src/auto_scheduler/transform_step.cc index 6c672a5..b1b3b94 100644 --- a/src/auto_scheduler/transform_step.cc +++ b/src/auto_scheduler/transform_step.cc @@ -19,12 +19,12 @@ /*! * \file auto_scheduler/transform_step.cc - * \brief Transformation steps. For each schedule primitive, there is a corresponding transform - * step. + * \brief Transformation steps. These steps are used to manipulate the LoopState. + * They are similar to the schedule primitives in te::Stage. */ -#include "transform_step.h" - +#include +#include #include #include @@ -32,7 +32,6 @@ #include #include -#include "loop_state.h" #include "utils.h" namespace tvm { @@ -80,6 +79,7 @@ Step StepReadFromRecord(dmlc::JSONReader* reader) { } void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag) { + // We need this runtime dispatcher because different steps have different function signatures if (auto ps = step.as()) { ps->ApplyToState(state); } else if (auto ps = step.as()) { @@ -101,6 +101,7 @@ void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag) { void StepApplyToSchedule(const Step& step, Array* stages, StageToAxesMap* stage_to_axes) { + // We need this runtime dispatcher because different steps have different function signatures if (auto ps = step.as()) { ps->ApplyToSchedule(stages, stage_to_axes); } else if (auto ps = step.as()) { @@ -122,6 +123,7 @@ void StepApplyToSchedule(const Step& step, Array* stages, String StepPrintAsPythonAPI(const Step& step, Array* stages, StageToAxesMap* stage_to_axes) { + // We need this runtime dispatcher because different steps have different function signatures if (auto ps = step.as()) { return ps->PrintAsPythonAPI(stages, stage_to_axes); } else if (auto ps = step.as()) { @@ -142,7 +144,7 @@ String StepPrintAsPythonAPI(const Step& step, Array* stages, return ""; } -/********** Primitives working on single stage 
**********/ +/********** Steps working on single stage **********/ /********** Annotation **********/ AnnotationStep::AnnotationStep(int stage_id, int iter_id, IteratorAnnotation ann) { @@ -741,7 +743,7 @@ String SplitStepNode::PrintAsPythonAPI(Array* stages, return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, lengths, inner_to_outer); } -/********** Primitives working on multiple stages **********/ +/********** Steps working on multiple stages **********/ /********** Compute At **********/ ComputeAtStep::ComputeAtStep(int stage_id, int target_stage_id, int target_iter_id) { diff --git a/src/auto_scheduler/utils.h b/src/auto_scheduler/utils.h index de800da..da5032e 100644 --- a/src/auto_scheduler/utils.h +++ b/src/auto_scheduler/utils.h @@ -128,6 +128,24 @@ inline std::vector IntArrayToVector( return out; } +/*! \brief Return whether two int arrays are elementwise-equal */ +inline bool IntArrayEqual(const Array& arr1, const Array& arr2) { + if (arr1.size() != arr2.size()) { + return false; + } + + for (size_t i = 0; i < arr1.size(); ++i) { + auto int1 = arr1[i].as(); + auto int2 = arr2[i].as(); + CHECK(int1 != nullptr); + CHECK(int2 != nullptr); + if (int1->value != int2->value) { + return false; + } + } + return true; +} + /********** Utilities for TVM Containers / ByteArray **********/ /*! \brief Compute mean of a FloatImm array */ inline double FloatArrayMean(const Array& float_array) { diff --git a/tests/cpp/auto_scheduler_test.cc b/tests/cpp/auto_scheduler_test.cc new file mode 100644 index 0000000..8526605 --- /dev/null +++ b/tests/cpp/auto_scheduler_test.cc @@ -0,0 +1,178 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <dmlc/logging.h>
+#include <gtest/gtest.h>
+#include <topi/nn.h>
+#include <tvm/auto_scheduler/compute_dag.h>
+#include <tvm/runtime/container.h>
+#include <tvm/te/operation.h>
+
+#include <set>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+/*!
+ * \brief Compute declaration for test: conv2d (NCHW) + bias add + batch norm + relu.
+ * \note `dilation` only enters the expected output size; it is not passed to topi::conv2d_nchw.
+ * \return The placeholders data, kernel, bias, bn_scale and bn_offset, followed by the output.
+ */
+tvm::Array<tvm::te::Tensor> conv2d_nchw_bn_relu_func(int N, int H, int W, int CI, int CO,
+                                                     int kernel_size, int strides, int padding,
+                                                     int dilation = 1) {
+  using namespace tvm;
+  using namespace tvm::te;
+
+  Tensor data = placeholder({N, CI, H, W}, DataType::Float(32), "Data");
+  Tensor kernel = placeholder({CO, CI, kernel_size, kernel_size}, DataType::Float(32), "Kernel");
+  Tensor bias = placeholder({CO, 1, 1}, DataType::Float(32), "Bias");
+  Tensor bn_scale = placeholder({CO, 1, 1}, DataType::Float(32), "Bn_scale");
+  Tensor bn_offset = placeholder({CO, 1, 1}, DataType::Float(32), "Bn_offset");
+
+  int OH = (H + 2 * padding - (kernel_size - 1) * dilation - 1) / strides + 1;
+  int OW = (W + 2 * padding - (kernel_size - 1) * dilation - 1) / strides + 1;
+
+  // The shape inferred by topi must agree with the output size computed above.
+  const auto& conv = topi::conv2d_nchw(data, kernel, padding, padding, strides, strides);
+  CHECK(conv->shape[2].as<IntImmNode>()->value == OH);
+  CHECK(conv->shape[3].as<IntImmNode>()->value == OW);
+
+  const auto& bias_add = compute(
+      {N, CO, OH, OW}, [&](Var i, Var j, Var k, Var l) { return conv[i][j][k][l] + bias[j][0][0]; },
+      "Bias_add");
+  const auto& bn_mul = compute(
+      {N, CO, OH, OW},
+      [&](Var i, Var j, Var k, Var l) { return bias_add[i][j][k][l] * bn_scale[j][0][0]; },
+      "Bn_mul");
+  const auto& bn_add = compute(
+      {N, CO, OH, OW},
+      [&](Var i, Var j, Var k, Var l) { return bn_mul[i][j][k][l] + bn_offset[j][0][0]; },
+      "Bn_add");
+  const auto& out = topi::relu(bn_add);
+
+  return {data, kernel, bias, bn_scale, bn_offset, out};
+}
+
+using namespace tvm::auto_scheduler;
+
+// Test Access Analyzer: per-op predicates plus producer/consumer queries, checked both before
+// and after bn_add, bn_mul, bias_add and padding are inlined.
+TEST(ComputeDAG, AccessAnalyzer) {
+  const auto& tensors = conv2d_nchw_bn_relu_func(1, 224, 224, 3, 64, 7, 2, 3);
+  const auto& dag = tvm::auto_scheduler::ComputeDAG(tensors);
+  State s0 = dag->init_state;
+  const int num_ops = static_cast<int>(dag->ops.size());
+
+  // Positions of the stages in dag->ops and s0->stages.
+  int data = 0, padding = 1, kernel = 2, conv = 3, bias = 4, bias_add = 5;
+  int bn_scale = 6, bn_mul = 7, bn_offset = 8, bn_add = 9, relu = 10;
+
+  std::set<int> needs_multi_level_tiling = {conv};
+  for (int stage_id = 0; stage_id < num_ops; ++stage_id) {
+    CHECK_EQ(dag->access_analyzer.NeedsMultiLevelTiling(dag->ops[stage_id]),
+             needs_multi_level_tiling.count(stage_id) > 0)
+        << "stage " << stage_id;
+  }
+
+  std::set<int> is_simple_access = {data,     padding, kernel,    bias,   bias_add,
+                                    bn_scale, bn_mul,  bn_offset, bn_add, relu};
+  for (int stage_id = 0; stage_id < num_ops; ++stage_id) {
+    CHECK_EQ(dag->access_analyzer.IsSimpleAccess(dag->ops[stage_id]),
+             is_simple_access.count(stage_id) > 0)
+        << "stage " << stage_id;
+  }
+
+  std::set<int> is_strictly_inlinable = {bias_add, bn_mul, bn_add, relu};
+  for (int stage_id = 0; stage_id < num_ops; ++stage_id) {
+    CHECK_EQ(dag->access_analyzer.IsStrictInlineable(dag->ops[stage_id]),
+             is_strictly_inlinable.count(stage_id) > 0)
+        << "stage " << stage_id;
+  }
+
+  std::set<int> is_output = {relu};
+  for (int stage_id = 0; stage_id < num_ops; ++stage_id) {
+    CHECK_EQ(dag->access_analyzer.IsOutput(dag->ops[stage_id]), is_output.count(stage_id) > 0)
+        << "stage " << stage_id;
+  }
+
+  CHECK_EQ(dag->access_analyzer.GetNumCommonOuterIterator(dag->ops[conv], dag->ops[bias_add]), 4);
+  CHECK_EQ(dag->access_analyzer.GetNumCommonOuterIterator(dag->ops[conv], dag->ops[relu]), 4);
+  CHECK_EQ(dag->access_analyzer.GetNumCommonOuterIterator(dag->ops[data], dag->ops[relu]), 1);
+
+  CHECK(dag->access_analyzer.ElementWiseMatch(dag->ops[conv], dag->ops[bias_add]));
+  CHECK(dag->access_analyzer.ElementWiseMatch(dag->ops[conv], dag->ops[relu]));
+  CHECK(!dag->access_analyzer.ElementWiseMatch(dag->ops[data], dag->ops[padding]));
+
+  // Expected producers, keyed by consumer stage; shared by both phases below.
+  const std::vector<std::pair<int, std::vector<int>>> producer_list = {
+      {padding, {data}},          {conv, {padding, kernel}},     {bias_add, {conv, bias}},
+      {bn_mul, {bias_add, bn_scale}}, {bn_add, {bn_mul, bn_offset}}, {relu, {bn_add}}};
+
+  // Initial state: every listed stage is expected to have exactly one consumer.
+  {
+    const std::vector<std::pair<int, int>> consumer_list = {
+        {data, padding},     {padding, conv},    {kernel, conv},     {conv, bias_add},
+        {bias, bias_add},    {bias_add, bn_mul}, {bn_scale, bn_mul}, {bn_mul, bn_add},
+        {bn_offset, bn_add}, {bn_add, relu}};
+    for (const auto& pair : consumer_list) {
+      const auto consumers = dag->access_analyzer.GetConsumers(s0, s0->stages[pair.first]->op);
+      CHECK_EQ(consumers.size(), 1);
+      CHECK_EQ((*consumers.begin()), s0->stages[pair.second]->op);
+    }
+    for (const auto& pair : producer_list) {
+      const auto producers = dag->access_analyzer.GetProducers(s0, s0->stages[pair.first]->op);
+      CHECK_EQ(producers.size(), pair.second.size());
+      for (const auto& target : pair.second) {
+        CHECK(producers.count(s0->stages[target]->op));
+      }
+    }
+  }
+
+  // Inline every stage between conv and relu, plus the padding stage.
+  s0.compute_inline(bn_add);
+  s0.compute_inline(bn_mul);
+  s0.compute_inline(bias_add);
+  s0.compute_inline(padding);
+  {
+    // With those stages inlined, GetConsumers is expected to look through them.
+    const std::vector<std::pair<int, int>> consumer_list = {
+        {data, conv}, {kernel, conv}, {conv, relu}};
+    for (const auto& pair : consumer_list) {
+      const auto consumers = dag->access_analyzer.GetConsumers(s0, s0->stages[pair.first]->op);
+      CHECK_EQ(consumers.size(), 1);
+      CHECK_EQ((*consumers.begin()), s0->stages[pair.second]->op);
+    }
+    // GetDirectProducers takes no state; the same producers as before inlining are expected.
+    for (const auto& pair : producer_list) {
+      const auto producers = dag->access_analyzer.GetDirectProducers(s0->stages[pair.first]->op);
+      CHECK_EQ(producers.size(), pair.second.size());
+      for (const auto& target : pair.second) {
+        CHECK(producers.count(s0->stages[target]->op));
+      }
+    }
+  }
+}
+
+// Test runner entry point: initialise GoogleTest and run every registered test.
+int main(int argc, char** argv) {
+  testing::InitGoogleTest(&argc, argv);
+  testing::FLAGS_gtest_death_test_style = "threadsafe";
+  return RUN_ALL_TESTS();
+}
diff --git a/tests/python/unittest/test_auto_scheduler_compute_dag.py b/tests/python/unittest/test_auto_scheduler_compute_dag.py
index 4934463..d9c24b9 100644
--- a/tests/python/unittest/test_auto_scheduler_compute_dag.py
+++ b/tests/python/unittest/test_auto_scheduler_compute_dag.py
@@ -17,10 +17,10 @@
 
 """Test ComputeDAG (replay, infer bound)"""
 
-import tvm
+import tvm, topi
 from tvm import auto_scheduler, te
 
-from test_auto_scheduler_common import get_tiled_matmul
+from test_auto_scheduler_common import get_tiled_matmul, matmul_auto_scheduler_test
 
 
 def test_apply_steps():
@@ -36,8 +36,19 @@ def test_infer_bound():
 
 
 def test_estimate_flop():
-    dag, s = get_tiled_matmul()
-    assert abs(dag.flop_ct - 2 * 512 ** 3) < 0.5
+    N = 512
+    A, B, C = matmul_auto_scheduler_test(N, N, N)
+    dag = auto_scheduler.ComputeDAG([A, B, C])
+    assert abs(dag.flop_ct - 2 * N ** 3) < 0.5
+
+    D = topi.nn.relu(C)
+    dag = auto_scheduler.ComputeDAG([A, B, D])
+    assert abs(dag.flop_ct - 2 * N ** 3 - N * N) < 0.5
+
+    # should not count the comparison operations in padding
+    D = topi.nn.pad(C, [1, 1])
+    dag = auto_scheduler.ComputeDAG([A, B, D])
+    assert abs(dag.flop_ct - 2 * N ** 3) < 0.5
 
 
 if __name__ == "__main__":
-- 
2.7.4