*/
/*!
- * \file auto_scheduler/auto_schedule.h
- * \brief The user interface of the TVM Auto-scheduler. This is the entry structure to get
- * schedule search requirements from upper level (Python API), and returns a high performance
- * schedule after search process.
+ * \file tvm/auto_scheduler/auto_schedule.h
+ * \brief The user interface of the auto scheduler.
*/
#ifndef TVM_AUTO_SCHEDULER_AUTO_SCHEDULE_H_
#define TVM_AUTO_SCHEDULER_AUTO_SCHEDULE_H_
-#include <utility>
+#include <tvm/auto_scheduler/measure.h>
+#include <tvm/auto_scheduler/search_policy.h>
-#include "measure.h"
-#include "search_policy/search_policy.h"
+#include <utility>
namespace tvm {
namespace auto_scheduler {
/*! \brief Tuning and measurement options. */
class TuningOptionsNode : public Object {
public:
- /*! \brief Number of total measurement trials. */
+ /*! \brief The number of total measurement trials. */
int num_measure_trials;
- /*! \brief Stops early the tuning if no improvement after n measurements. */
+ /*! \brief Stops the tuning early if no improvement after n measurements. */
int early_stopping;
/*! \brief The number of programs to be measured at each search round. */
int num_measures_per_round;
int verbose;
/*! \brief ProgramBuilder which builds the program */
ProgramBuilder builder;
- /*! \brief ProgramRunner which runs the program and measure time costs */
+ /*! \brief ProgramRunner which runs the program and measures time costs */
ProgramRunner runner;
/*! \brief MeasureCallback functions to be called after each measure batch */
Optional<Array<MeasureCallback>> measure_callbacks;
public:
/*!
* \brief The constructor
- * \param num_measure_trials Number of total measurement trials.
- * \param early_stopping Stops early the tuning if no improvement after n measurements.
+ * \param num_measure_trials The number of total measurement trials.
+ * \param early_stopping Stops the tuning early if no improvement after n measurements.
* \param num_measures_per_round The number of programs to be measured at each search round.
* \param verbose Verbosity level. 0 for silent, 1 to output information during schedule
* search.
};
/*!
- * \brief Auto schedule search for a given compute declaration.
+ * \brief Run schedule search for a given compute declaration.
* \param task The search task of the compute declaration.
- * \param search_policy The search policy to be used for schedule search.
+ * \param search_policy The search policy to be used.
* \param tuning_options Tuning and measurement options.
- * \return A `te::schedule` and the a Array of `te::Tensor` to be used in `tvm.lower` or
+ * \return A `te::schedule` and an Array of `te::Tensor` to be used in `tvm.lower` or
* `tvm.build`.
*/
TVM_DLL std::pair<te::Schedule, Array<te::Tensor>> AutoSchedule(SearchTask task,
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/auto_scheduler/compute_dag.h
+ * \brief The auto-scheduler's computational graph and related program analyses.
+ *
+ * We convert a compute declaration described by `tvm.compute` (could be a single operator or a
+ * subgraph) to a ComputeDAG. It keeps the input/output tensors of the compute declaration,
+ * a list of all operations in the DAG as well as static analysis results for the DAG (e.g. the
+ * total float operation count, consumer/producer relations of each operation stage, whether an
+ * operation stage should be tiled/compute inlined ...). These analyses can help the search policy
+ * to make decisions during search process.
+ * ComputeDAG is also responsible for the interaction between TVM Auto-scheduler `LoopState` and
+ * TVM schedule (e.g. applying the `LoopState` transform steps to TVM schedule, providing
+ * `LoopState` with extra information got from TVM schedule ...).
+ */
+
+#ifndef TVM_AUTO_SCHEDULER_COMPUTE_DAG_H_
+#define TVM_AUTO_SCHEDULER_COMPUTE_DAG_H_
+
+#include <tvm/auto_scheduler/loop_state.h>
+#include <tvm/runtime/c_runtime_api.h>
+#include <tvm/te/schedule.h>
+
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace auto_scheduler {
+
+/*! \brief Static analysis result for a ComputeDAG */
+class AccessAnalyzerNode : public Object {
+ public:
+ template <class T>
+ using OperationMap = std::unordered_map<te::Operation, T, ObjectPtrHash, ObjectPtrEqual>;
+
+ /*! \brief Map an operation to all operations it reads from.
+ * For each operation pair, use a two-dimensional array to represent multiple multi-dimensional
+ * accesses. The inner vector represents the indices of a multi-dimensional access. */
+ OperationMap<OperationMap<std::vector<std::vector<PrimExpr>>>> read_from;
+ /*! \brief Map an operation to all operations it is read by.
+ * For each operation pair, use a two-dimensional array to represent multiple multi-dimensional
+ * accesses. The inner vector represents the indices of a multi-dimensional access. */
+ OperationMap<OperationMap<std::vector<std::vector<PrimExpr>>>> read_by;
+ /*! \brief Store the number of common outer iterators for operation pairs that have
+ * read-write relations. */
+ OperationMap<OperationMap<int>> num_common_outer_iterators;
+ /*! \brief Store whether the operation is an op with only simple access.
+ * (e.g., injective, broadcast and elementwise ops without reduction) */
+ OperationMap<bool> is_simple_access;
+ /*! \brief Store whether the operation is strictly-inlineable
+ * (e.g., injective, broadcast and elementwise without reduction, branch or expensive operations)
+ */
+ OperationMap<bool> is_strict_inlineable;
+ /*! \brief Store whether the operation needs multi-level tiling
+ * (e.g., computation-intensive ops with data reuse opportunity like matmul, conv2d) */
+ OperationMap<bool> needs_multi_level_tiling;
+ /*! \brief Store whether the operation is an output operation */
+ OperationMap<bool> is_output;
+ /*! \brief Store the topological order of operations */
+ Array<te::Operation> ops_topo_order;
+
+ static constexpr const char* _type_key = "auto_scheduler.AccessAnalyzer";
+ TVM_DECLARE_FINAL_OBJECT_INFO(AccessAnalyzerNode, Object);
+};
+
+/*!
+ * \brief Managed reference to AccessAnalyzerNode.
+ * \sa AccessAnalyzerNode
+ */
+class AccessAnalyzer : public ObjectRef {
+ public:
+ /*!
+ * \brief The constructor.
+ * \param tensors The `te::Tensor`s to construct the analyzer from.
+ */
+ explicit AccessAnalyzer(const Array<te::Tensor>& tensors);
+
+ /*!
+ * \brief Return whether this operation is an op with only simple access
+ * (e.g., injective, broadcast and elementwise ops without reduction)
+ * \param op The operation
+ */
+ TVM_DLL bool IsSimpleAccess(const te::Operation& op) const;
+
+ /*!
+ * \brief Return whether this operation is strictly inlineable
+ * (e.g., injective, broadcast and elementwise without reduction, branch or expensive operations)
+ * \param op The operation
+ */
+ TVM_DLL bool IsStrictInlineable(const te::Operation& op) const;
+
+ /*!
+ * \brief Return whether this operation needs multi-level tiling
+ * (e.g., computation-intensive ops with data reuse opportunity like matmul, conv2d)
+ * \param op The operation
+ */
+ TVM_DLL bool NeedsMultiLevelTiling(const te::Operation& op) const;
+
+ /*!
+ * \brief Return whether this operation is an output op
+ * \param op The operation
+ */
+ TVM_DLL bool IsOutput(const te::Operation& op) const;
+
+ /*!
+ * \brief Get all consumers of an operation
+ * \param state The current loop state
+ * \param op The operation
+ * \return The set of consumers
+ * \note This function propagates the relation for inlined ops
+ */
+ TVM_DLL std::unordered_set<te::Operation, ObjectHash, ObjectEqual> GetConsumers(
+ const State& state, const te::Operation& op) const;
+
+ /*!
+ * \brief Get all producers of an operation
+ * \param state The current loop state
+ * \param op The operation
+ * \return The set of producers
+ * \note This function propagates the relation for inlined ops
+ */
+ TVM_DLL std::unordered_set<te::Operation, ObjectHash, ObjectEqual> GetProducers(
+ const State& state, const te::Operation& op) const;
+
+ /*!
+ * \brief Get all direct producers of an operation
+ * \param op The operation
+ * \return The set of direct producers
+ * \note This function DOES NOT propagate the relation for inlined ops
+ */
+ TVM_DLL std::unordered_set<te::Operation, ObjectHash, ObjectEqual> GetDirectProducers(
+ const te::Operation& op) const;
+
+ /*!
+ * \brief Get the number of common outer iterators.
+ * \param op The operation
+ * \param target_op The target operation
+ * \return The number of common outer iterators
+ * \note This function propagates the relation for chains with multiple ops.
+ */
+ TVM_DLL int GetNumCommonOuterIterator(const te::Operation& op,
+ const te::Operation& target_op) const;
+
+ /*!
+ * \brief Return whether two operations are elementwise-matched
+ * (e.g. conv2d and relu are elementwise matched)
+ * \param op The operation
+ * \param target_op The target operation
+ * \note This function propagates the relation for chains with multiple ops.
+ */
+ TVM_DLL bool ElementWiseMatch(const te::Operation& op, const te::Operation& target_op) const;
+
+ TVM_DEFINE_OBJECT_REF_METHODS(AccessAnalyzer, ObjectRef, AccessAnalyzerNode);
+};
+
+/*! \brief The TVM Auto-scheduler computational graph and related program analyses. */
+class ComputeDAGNode : public Object {
+ public:
+ /*!
+ * \brief Input and output tensors.
+ * This is used as the input of `tvm.lower` or `tvm.build`.
+ */
+ Array<te::Tensor> tensors;
+ /*! \brief All related operations in topo order. */
+ Array<te::Operation> ops;
+ /*! \brief The number of total float operations for this ComputeDAG. */
+ double flop_ct;
+ /*! \brief The initial state without any transform steps. */
+ State init_state;
+ /*! \brief The static read-write access analyzer */
+ AccessAnalyzer access_analyzer;
+
+ /*!
+ * \brief Visit the attributes of this node.
+ * \param v The attribute visitor.
+ * \note Only `tensors`, `ops`, `flop_ct` and `init_state` are visited;
+ * `access_analyzer` is not passed to the visitor.
+ */
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ v->Visit("tensors", &tensors);
+ v->Visit("ops", &ops);
+ v->Visit("flop_ct", &flop_ct);
+ v->Visit("init_state", &init_state);
+ }
+
+ static constexpr const char* _type_key = "auto_scheduler.ComputeDAG";
+ TVM_DECLARE_FINAL_OBJECT_INFO(ComputeDAGNode, Object);
+};
+
+/*!
+ * \brief Managed reference to ComputeDAGNode.
+ * \sa ComputeDAGNode
+ */
+class ComputeDAG : public ObjectRef {
+ public:
+ /*! \brief The constructor.
+ * \param tensors `te::Tensor`s for a compute declaration.
+ */
+ TVM_DLL explicit ComputeDAG(Array<te::Tensor> tensors);
+
+ /*!
+ * \brief Apply the history transform steps to get a TVM schedule.
+ * \param transform_steps Transform steps of a state.
+ * \param stages The list of stages after applying the steps.
+ * Pass a valid pointer if this information needs to be used outside this function.
+ * \param stage_to_axes The map that stores all axes for one stage.
+ * Pass a valid pointer if this information needs to be used outside this function.
+ * \return A `te.schedule` and an Array of `te.Tensor` to be used in `tvm.lower`
+ * or `tvm.build`.
+ */
+ std::pair<te::Schedule, Array<te::Tensor>> ApplySteps(
+ const Array<Step>& transform_steps, Array<te::Stage>* stages = nullptr,
+ StageToAxesMap* stage_to_axes = nullptr) const;
+
+ /*!
+ * \brief Print transform steps as equivalent python schedule API.
+ * This can be used for debugging.
+ * \param transform_steps Transform steps of a state.
+ * \return The Python schedule code.
+ */
+ String PrintStepsAsPython(const Array<Step>& transform_steps) const;
+
+ /*!
+ * \brief Fill the correct bound information for a given state by calling ir_pass::InferBound.
+ * The states can lose complete bound information after some transform steps (e.g., compute_at).
+ * We can call this function to infer and fill all the bound information.
+ * This function calls TVM InferBound pass internally to get the bound.
+ * The returned state of this function is guaranteed to have complete bound information.
+ * \param state The input state.
+ * \return The State with complete bound information
+ */
+ State InferBound(const State& state) const;
+
+ TVM_DEFINE_OBJECT_REF_METHODS(ComputeDAG, ObjectRef, ComputeDAGNode);
+ TVM_DEFINE_OBJECT_REF_COW_METHOD(ComputeDAGNode);
+};
+
+} // namespace auto_scheduler
+} // namespace tvm
+
+#endif // TVM_AUTO_SCHEDULER_COMPUTE_DAG_H_
#ifndef TVM_AUTO_SCHEDULER_LOOP_STATE_H_
#define TVM_AUTO_SCHEDULER_LOOP_STATE_H_
+#include <dmlc/common.h>
+#include <tvm/auto_scheduler/transform_step.h>
#include <tvm/runtime/container.h>
#include <functional>
#include <utility>
#include <vector>
-#include "transform_step.h"
-
namespace tvm {
namespace auto_scheduler {
*/
class AttachMapNode : public Object {
public:
+ /*! \brief Hash functor for IterKey: combines the hashes of its two int members
+ * (`first` and `second`) with dmlc::HashCombine. */
+ struct IterKeyHash {
+ std::size_t operator()(const IterKey& k) const {
+ return ::dmlc::HashCombine(std::hash<int>()(k.first), std::hash<int>()(k.second));
+ }
+ };
+
/*! \brief A Map to store the mapping of stage to its attached iterator. */
std::unordered_map<StageKey, IterKey> stage_to_attach_iter;
/*! \brief A Map to store the mapping of iterator to the stage attached to it. */
- std::unordered_map<IterKey, std::vector<StageKey>> iter_to_attached_stages;
+ std::unordered_map<IterKey, std::vector<StageKey>, IterKeyHash> iter_to_attached_stages;
static constexpr const char* _type_key = "auto_scheduler.AttachMap";
TVM_DECLARE_FINAL_OBJECT_INFO(AttachMapNode, Object);
* this input.
* \return The iterator result after binded.
*/
- Iterator bind(int stage_id, const Iterator& it, IteratorAnnotation thread_type);
+ TVM_DLL Iterator bind(int stage_id, const Iterator& it, IteratorAnnotation thread_type);
/*!
* \brief Schedule primitive corresponds to te.parallel.
* \param stage_id The index of the stage to be paralleled.
* \param it The iterator to be paralleled.
* \return The iterator result after parallel.
*/
- Iterator parallel(int stage_id, const Iterator& it);
+ TVM_DLL Iterator parallel(int stage_id, const Iterator& it);
/*!
* \brief Schedule primitive corresponds to te.unroll.
* \param stage_id The index of the stage to be unrolled.
* skipped.
* \return The iterator result after unrolled.
*/
- Iterator unroll(int stage_id, const Iterator& it, int max_unroll = -1);
+ TVM_DLL Iterator unroll(int stage_id, const Iterator& it, int max_unroll = -1);
/*!
* \brief Schedule primitive corresponds to te.vectorize.
* \param stage_id The index of the stage to be vectorized.
* \param it The iterator to be vectorized.
* \return The iterator result after vectorize.
*/
- Iterator vectorize(int stage_id, const Iterator& it);
+ TVM_DLL Iterator vectorize(int stage_id, const Iterator& it);
/*!
* \brief Schedule primitive corresponds to te.fuse.
* \param stage_id The index of the stage to be fused.
* \note If the iterators to be fused have stages attached at them(by compute_at), the fused
* result will become the new attach point.
*/
- Iterator fuse(int stage_id, const Array<Iterator>& iters);
+ TVM_DLL Iterator fuse(int stage_id, const Array<Iterator>& iters);
/*!
* \brief Schedule primitive corresponds to te.reorder.
* \param stage_id The index of the stage to be reordered.
* \param order The expected iterator order.
*/
- void reorder(int stage_id, const Array<Iterator>& order);
+ TVM_DLL void reorder(int stage_id, const Array<Iterator>& order);
/*!
* \brief Schedule primitive corresponds to te.split.
* \param stage_id The index of the stage to be split.
* \note If we do split on an iterator which has stages attached at it(by compute_at), the inner
* most iterator of split results will become the new attach point.
*/
- Array<Iterator> split(int stage_id, const Iterator& it, const Array<Optional<Integer>>& lengths,
- bool inner_to_outer = true);
+ TVM_DLL Array<Iterator> split(int stage_id, const Iterator& it,
+ const Array<Optional<Integer>>& lengths,
+ bool inner_to_outer = true);
/********** Step APIs working on multiple stages **********/
* bound for the newly created iterators.
* Call ComputeDAG::InferBound on the updated state to get the complete bound information.
*/
- void compute_at(int stage_id, int target_stage_id, const Iterator& target_iter);
+ TVM_DLL void compute_at(int stage_id, int target_stage_id, const Iterator& target_iter);
/*!
* \brief Schedule primitive corresponds to te.compute_inline.
* \param stage_id The index of the stage to be reordered.
*/
- void compute_inline(int stage_id);
+ TVM_DLL void compute_inline(int stage_id);
/*!
* \brief Schedule primitive corresponds to te.compute_root.
* \param stage_id The index of the stage to be reordered.
* bound for the newly created iterators.
* Call ComputeDAG::InferBound on the updated state to get the complete bound information.
*/
- void compute_root(int stage_id);
+ TVM_DLL void compute_root(int stage_id);
TVM_DEFINE_OBJECT_REF_METHODS(State, ObjectRef, StateNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(StateNode);
// Hash and equal function for State
namespace std {
-/*! \brief The hash function for auto_scheduler::State. */
-template <>
-struct hash<::tvm::auto_scheduler::State> {
- std::size_t operator()(const ::tvm::auto_scheduler::State& state) const {
- return tvm::runtime::ObjectHash()(state.ToStr());
- }
-};
-
/*!
* \brief The equal_to function for auto_scheduler::State.
- * We use the schedule result(its string format) of a state to check if two states are `euqal`.
- * Equal States: 1. the transform steps are totally the same; 2. even with different steps, two
- * states may still result in a same schedule. e.g. To split a axis with extent 512 to 3 parts
- * [8, 16, 4]. We can split from inner to outter by factors [16, 4], while we can get a same result
- * to split from outter to inner by factors [8, 16])
+ * This function checks the equality by looking at the lowered string format of states.
+ * If two states with different transform history have the same lowered string format,
+ * they will be considered being equal.
*/
template <>
struct equal_to<::tvm::auto_scheduler::State> {
}
};
+/*!
+ * \brief The hash function for auto_scheduler::State.
+ * The hash is computed with tvm::runtime::ObjectHash over the string returned by
+ * `state.ToStr()`.
+ */
+template <>
+struct hash<::tvm::auto_scheduler::State> {
+ std::size_t operator()(const ::tvm::auto_scheduler::State& state) const {
+ return tvm::runtime::ObjectHash()(state.ToStr());
+ }
+};
+
} // namespace std
#endif // TVM_AUTO_SCHEDULER_LOOP_STATE_H_
* These functions are responsible for building the tvm module, uploading it to remote devices,
* recording the running time costs, and checking the correctness of the output.
*
- * We separate the measurement into two steps: build and run.
+ * The measurement is separated into two steps: build and run.
* A builder builds the executable binary files and a runner runs the binary files to get the
* measurement results. The flow of data structures is
*
* `ProgramBuilder` `ProgramRunner`
* `MeasureInput` -----------------> `BuildResult` ----------------> `MeasureResult`
*
- * We implement these in python to utilize python's multiprocessing and error handling.
+ * The core functions are implemented in python to utilize python's multiprocessing
+ * and error handling (see also `python/tvm/auto_scheduler/measure.py`).
+ * This c++ file is just a wrapper for the python functions.
*/
#ifndef TVM_AUTO_SCHEDULER_MEASURE_H_
#define TVM_AUTO_SCHEDULER_MEASURE_H_
+#include <tvm/auto_scheduler/loop_state.h>
+#include <tvm/auto_scheduler/search_task.h>
+
#include <string>
#include <unordered_map>
#include <utility>
-#include "loop_state.h"
-#include "search_task.h"
-
namespace tvm {
namespace auto_scheduler {
public:
/*!
* \brief Callback function that will be called on measurement input/result pairs
- * after measurement.
+ * after each measurement batch.
* \param policy The current search policy.
* \param inputs An Array of MeasureInput.
* \param results An Array of MeasureResult.
/*! \brief ProgramBuilder that builds the programs */
class ProgramBuilderNode : public Object {
public:
- /*! \brief The number of tasks to run in parallel */
+ /*! \brief The number of build processes to run in parallel */
int n_parallel;
/*! \brief Timeout of a build */
int timeout;
* \brief The constructor.
* \param timeout The timeout limit (in second) for each build thread.
* This will be used in a wrapper of the multiprocessing.Process.join().
- * \param n_parallel Number of threads used to build in parallel.
- * \param build_func The name of registered build function.
+ * \param n_parallel The number of threads used to build in parallel.
+ * \param build_func The name of the registered build function.
*/
LocalBuilder(int timeout, int n_parallel, const String& build_func);
TVM_DEFINE_OBJECT_REF_METHODS(LocalBuilder, ProgramBuilder, LocalBuilderNode);
};
-/*! \brief LocalRunner that uses local CPU/GPU to measures the time cost of programs */
+/*! \brief LocalRunner that uses local CPU/GPU to measure the time cost of programs */
class LocalRunnerNode : public ProgramRunnerNode {
public:
Array<MeasureResult> Run(const Array<MeasureInput>& inputs,
String key;
/*! \brief The host address of the RPC Tracker. */
String host;
- /*! \brief The port of RPC Tracker. */
+ /*! \brief The port of the RPC Tracker. */
int port;
/*! \brief The priority of this run request, larger is more prior. */
int priority;
/*! \brief The number of tasks run in parallel. */
int n_parallel;
- /*! \brief The number of times to run the generated code for taking average. */
Array<MeasureResult> Run(const Array<MeasureInput>& inputs,
const Array<BuildResult>& build_results, int verbose) final;
class RPCRunner : public ProgramRunner {
public:
/*!
- * \brief The constructor.
+ * \brief The constructor. See the corresponding class in python/tvm/auto_scheduler/measure.py
+ * for more detailed parameter explanation.
* \param key The key of the device registered in the RPC tracker.
* \param host The host address of the RPC Tracker.
- * \param prot The port of RPC Tracker.
+ * \param port The port of RPC Tracker.
* \param priority The priority of this run request, larger is more prior.
* \param n_parallel The number of tasks run in parallel.
* \param timeout Timeout of a run.
/*!
* \brief Measurer that measures the time costs of tvm programs
- * This class combines ProgramBuilder and ProgramRunner, and provides a simpler API */
+ * This class combines ProgramBuilder and ProgramRunner and provides a simpler API */
class ProgramMeasurerNode : public Object {
public:
/*! \brief Measured programs counter. */
* \param callbacks MeasureCallback to be called after each measure batch.
* \param verbose Verbosity level. 0 for silent, 1 to output information during program
* measuring.
- * \param max_continous_error The number of max continuous error.
+ * \param max_continous_error The number of allowed maximum continuous error.
*/
ProgramMeasurer(ProgramBuilder builder, ProgramRunner runner,
Optional<Array<MeasureCallback>> callbacks, int verbose,
*/
/*!
- * \file auto_scheduler/measure_record.h
- * \brief Json serialization format for dumping and loading tuning records.
+ * \file tvm/auto_scheduler/measure_record.h
+ * \brief Json serialization format for dumping and loading measurement records.
*/
#ifndef TVM_AUTO_SCHEDULER_MEASURE_RECORD_H_
#define TVM_AUTO_SCHEDULER_MEASURE_RECORD_H_
+#include <tvm/auto_scheduler/measure.h>
+
#include <fstream>
#include <string>
#include <utility>
-#include "measure.h"
-
namespace tvm {
namespace auto_scheduler {
/*! \brief Callback for logging the input and results of measurements to file */
class RecordToFileNode : public MeasureCallbackNode {
public:
- /*! \brief File name for this callback to write log to. */
+ /*! \brief The name of output file. */
String filename;
void Callback(const SearchPolicy& policy, const Array<MeasureInput>& inputs,
public:
/*!
* \brief The constructor.
- * \param filename File name for this callback to write log.
+ * \param filename The name of output file
*/
explicit RecordToFile(String filename);
/*! \brief Log reader to load step logs from a file.*/
class RecordReaderNode : public Object {
public:
- /*! \brief File name for this reader to load log from. */
+ /*! \brief The name of input file. */
String filename;
/*! \brief The reading file stream. */
std::ifstream infile;
TVM_DECLARE_FINAL_OBJECT_INFO(RecordReaderNode, Object);
private:
- /*! \brief A string object to store the next line. */
+ /*! \brief A string storing the current line. */
std::string cur_line_;
};
public:
/*!
* \brief The constructor.
- * \param filename File name for this callback to write log.
+ * \param filename The name of input file
*/
explicit RecordReader(String filename);
};
/*!
- * \brief Write measure records to an output stream.
+ * \brief Append measure records to an output stream.
* \param os A pointer to a output stream.
* \param inputs The MeasureInputs to be written.
* \param results The MeasureResults to be written.
/*!
* \brief Read one measure record from a string.
- * \param str The record string to be extract.
- * \param inp A pointer to a MeasureInputNode, this is used as output.
- * \param res A pointer to a MeasureResultNode, this is used as output.
- * \param log_version A pointer to a log version string.
+ * \param str The record string to be parsed.
+ * \param inp A pointer to a MeasureInputNode used to store the return value.
+ * \param res A pointer to a MeasureResultNode used to store the return value.
+ * \param log_version A pointer to a string used to store the log version.
*/
void ReadMeasureRecord(const std::string& str, MeasureInputNode* inp, MeasureResultNode* res,
std::string* log_version);
*/
/*!
- * \file auto_scheduler/search_policy/search_policy.h
+ * \file tvm/auto_scheduler/search_policy.h
* \brief The base class of search policies, including the abstract definition of search policy and
* other supporting data structures.
*
- * The basic schedule search process for TVM Auto-scheduler is design to be:
+ * The basic schedule search process for the auto-scheduler is designed to be:
* `Program sampling` -> `Performance Tuning`.
*
* In `Program sampling`, we use some predefined precise or heuristic rules to generate several
*
* Candidate schedules are measured against the specific hardware target.
*
- * \note Adding a new search policy.
+ * \note How to add a new search policy.
* In design, there's no need for users to implement their own search policy, our formal search
* policy(will be brought later) should be enough to cover most use cases. Meanwhile, a custom rule
* mechanism will be provided to enable user-defined template search to serve the same functionality
* during the search process.
*/
-#ifndef TVM_AUTO_SCHEDULER_SEARCH_POLICY_SEARCH_POLICY_H_
-#define TVM_AUTO_SCHEDULER_SEARCH_POLICY_SEARCH_POLICY_H_
+#ifndef TVM_AUTO_SCHEDULER_SEARCH_POLICY_H_
+#define TVM_AUTO_SCHEDULER_SEARCH_POLICY_H_
+#include <tvm/auto_scheduler/search_task.h>
#include <tvm/node/node.h>
#include <unordered_set>
#include <vector>
-#include "../search_task.h"
-
namespace tvm {
namespace auto_scheduler {
/*!
* \brief Do schedule search for a task. Takes the SearchTask as input and returns the best state
- * get during the search process.
- * \param task The SearchTask or workload key for the computation declaration
- * \param num_measure_trials Total schedules to be tried during this search.
- * \param early_stopping Early stop if no better schedule is found.
- * \param num_measures_per_round Max measure batch in one search round.
+ * found during the search.
+ * \param task The SearchTask for the computation declaration
+ * \param num_measure_trials The number of total measurement trials.
+ * \param early_stopping Stops the tuning early if no improvement after n measurements.
+ * \param num_measures_per_round The number of programs to be measured at each search round.
* \param verbose Verbose level. 0 for silent, 1 to output information during schedule
* search.
- * \param measurer A ProgramMeasurer which packs ProgramBuilder & ProgramRunner inside.
+ * \param measurer A ProgramMeasurer to build and measure programs
* \param pre_search_callbacks SearchCallback to be called before schedule search.
- * \return The best state get.
+ * \return The best state found.
*/
virtual State Search(SearchTask task, int num_measure_trials, int early_stopping,
int num_measures_per_round, int verbose, ProgramMeasurer measurer,
protected:
/*!
* \brief The set of already measured states.
- * During the schedule search process, we may generate `equal states` through different search
- * branches. (Equal States: 1. the transform steps are totally the same; 2. even with different
- * steps, two states may still result in a same schedule. e.g. To split a axis with extent 512
- * to 3 parts [8, 16, 4]. We can split from inner to outter by factors [16, 4], while we can
- * get a same result to split from outter to inner by factors [8, 16])
* We store the string format of a state for redundancy check. This is used to make sure a
* measured state will never be measured again.
*/
std::unordered_set<String> measured_states_set_;
- /*! \brief The array of already measured states. This can be used in evolutionary search. */
+ /*! \brief The array of already measured states.
+ * The good states can be used as the initial population in evolutionary search. */
std::vector<State> measured_states_vector_;
/*! \brief The throughputs of already measured states */
std::vector<float> measured_states_throughputs_;
} // namespace auto_scheduler
} // namespace tvm
-#endif // TVM_AUTO_SCHEDULER_SEARCH_POLICY_SEARCH_POLICY_H_
+#endif // TVM_AUTO_SCHEDULER_SEARCH_POLICY_H_
#ifndef TVM_AUTO_SCHEDULER_SEARCH_TASK_H_
#define TVM_AUTO_SCHEDULER_SEARCH_TASK_H_
+#include <tvm/auto_scheduler/compute_dag.h>
#include <tvm/target/target.h>
-#include "compute_dag.h"
-
namespace tvm {
namespace auto_scheduler {
class HardwareParams;
-/*! \brief The parameters of target hardware used to guide the search process of SearchPolicy. */
+/*! \brief The parameters of target hardware used to guide the SearchPolicy. */
class HardwareParamsNode : public Object {
public:
/*! \brief The number of cores. */
/*!
* \file auto_scheduler/transform_step.h
- * \brief Transformation steps. For each schedule primitive, there is a corresponding transform
- * step.
+ * \brief Transformation steps. These steps are used to manipulate the LoopState.
+ * They are similar to the schedule primitives in te::Stage.
*
- * \note To add a new transform step:
+ * \note How to add a new transform step:
* Take fuse step for example:
* 1. Define class `FuseStepNode`, `FuseStep` in `transform_steps.h`, and implement its first
* construction function `FuseStep::FuseStep()` in `transform_steps.cc`.
#include <tvm/node/node.h>
#include <tvm/te/schedule.h>
-#include "utils.h"
-
namespace tvm {
namespace auto_scheduler {
* \param step The step to be applied to State.
* \param state A mutable pointer to State.
* \param dag The original ComputeDAG of this state.
- * \return The iterator result after annotate.
*/
void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag);
String StepPrintAsPythonAPI(const Step& step, Array<te::Stage>* stages,
StageToAxesMap* stage_to_axes);
-/********** Primitives working on single stage **********/
+/********** Steps working on single stage **********/
/*!
* \brief Annotation step that corresponds to vectorize, parallel, unroll and thread binding.
TVM_DEFINE_OBJECT_REF_METHODS(SplitStep, Step, SplitStepNode);
};
-/********** Primitives working on multiple stages **********/
+/********** Steps working on multiple stages **********/
/*! \brief Compute at step that corresponds to te::Stage::compute_at */
class ComputeAtStepNode : public StepNode {
@tvm._ffi.register_object("auto_scheduler.SearchTask")
class SearchTask(Object):
- """ The computation information and hardware parameters for a specific schedule search task.
+ """ The computation information and hardware parameters for a schedule search task.
Parameters
----------
def auto_schedule(task, search_policy='default', tuning_options=None):
""" Do auto scheduling for a computation declaration.
- The task parameter can be a `string` as workload_key, or directly
- passing a `SearchTask` as input.
-
Parameters
----------
task : SearchTask
Returns
-------
- workload_key : Str
+ workload_key : str
The workload key of the function.
"""
global WORKLOAD_FUNC_REGISTRY
* schedule after search process.
*/
-#include "auto_schedule.h"
-
+#include <tvm/auto_scheduler/auto_schedule.h>
#include <tvm/runtime/registry.h>
namespace tvm {
* \brief Compute declaration graph and its related analysis tools.
*/
-#include "compute_dag.h"
-
+#include <tvm/auto_scheduler/compute_dag.h>
+#include <tvm/auto_scheduler/loop_state.h>
#include <tvm/runtime/registry.h>
#include <tvm/te/operation.h>
#include <tvm/te/schedule.h>
#include <tvm/te/schedule_pass.h>
+#include <tvm/tir/builtin.h>
#include <tvm/tir/stmt_functor.h>
#include <algorithm>
#include <unordered_set>
#include <vector>
-#include "loop_state.h"
+#include "../arith/pattern_match.h"
#include "utils.h"
namespace tvm {
using namespace tvm::tir;
+template <class T>
+using OperationMap = AccessAnalyzerNode::OperationMap<T>;
+using OperationSet = std::unordered_set<te::Operation, ObjectHash, ObjectEqual>;
+
TVM_REGISTER_NODE_TYPE(ComputeDAGNode);
// Topo-sort ops from tensors according to their read-write relations.
return ops;
}
-// Estimate number of float operations in an expression
+// Collect every tensor read that occurs in an expr, and note whether the expr branches
+class ReadAccessExtractor : public StmtExprVisitor {
+ public:
+  // Entry point: walk `expr` and fill `read_access` / `has_branch`
+  void Extract(PrimExpr expr) { this->VisitExpr(expr); }
+
+  void VisitExpr_(const CallNode* op) final {
+    // A call to the if_then_else builtin counts as a branch
+    const bool is_if_then_else = op->op.same_as(builtin::if_then_else());
+    if (is_if_then_else) {
+      has_branch = true;
+    }
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitExpr_(const ProducerLoadNode* op) final {
+    // Record the indices of this load under the operation that produces the tensor
+    std::vector<std::vector<PrimExpr>>& accesses =
+        read_access[Downcast<te::Tensor>(op->producer)->op];
+    accesses.emplace_back(op->indices.begin(), op->indices.end());
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitStmt_(const IfThenElseNode* op) final {
+    has_branch = true;
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitExpr_(const SelectNode* op) final {
+    has_branch = true;
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  // All read accesses, keyed by the operation being read.
+  // The innermost vector stores the multi-dimensional indices of one access.
+  // The middle vector stores the (possibly multiple) accesses to the same operation.
+  OperationMap<std::vector<std::vector<PrimExpr>>> read_access;
+  // Whether the visited expression contains a branch (if_then_else / IfThenElse / Select)
+  bool has_branch{false};
+};
+
+// Returns whether the expr is the var itself or the var with a constant integer shift.
+// The matched forms are: var + c, var - c, c + var and var.
+bool IsConstShiftEqual(const Var& var, const PrimExpr& expr) {
+  arith::PVar<PrimExpr> base;
+  arith::PVar<IntImm> shift;
+
+  const bool matched = (base + shift).Match(expr) || (base - shift).Match(expr) ||
+                       (shift + base).Match(expr) || base.Match(expr);
+  // The matched base must be the very same variable object
+  return matched && base.Eval().same_as(var);
+}
+
+// Return whether the access to an operation is a simple access
+// (i.e. every non-constant index is just one axis variable with an optional constant shift).
+// For example, A[i][j], A[i+1][j] are simple accesses but A[i][j+i] is not.
+// On return, *axis_missing tells whether some axis of `op` is used by no index,
+// *axis_duplicated whether some axis is used by more than one index, and *same_order
+// whether the axes appear in the same order as in op->axis.
+bool IsSimpleAccess(const te::Operation& op, const std::vector<PrimExpr>& indices,
+                    bool* axis_missing, bool* axis_duplicated, bool* same_order) {
+  // Give the out-parameters defined values before any early return, so callers never
+  // observe indeterminate flags when the access turns out not to be simple.
+  *axis_missing = false;     // Some axes are missing
+  *axis_duplicated = false;  // Some axes appear more than once
+  *same_order = true;        // The axis order is the same as op->axis
+
+  auto cop = op.as<te::ComputeOpNode>();
+  if (cop == nullptr) {
+    return false;
+  }
+
+  // For every non-constant index: the position of the matching axis in cop->axis
+  std::vector<int> index_to_var_idx;
+  // For every axis: how many indices refer to it
+  std::vector<int> var_idx_ct(cop->axis.size(), 0);
+
+  for (const auto& expr : indices) {
+    if (is_const_int(expr)) {
+      continue;  // constant indices do not refer to any axis
+    }
+    bool found = false;
+    for (size_t i = 0; i < cop->axis.size(); ++i) {
+      if (IsConstShiftEqual(cop->axis[i]->var, expr)) {
+        index_to_var_idx.push_back(static_cast<int>(i));
+        var_idx_ct[i]++;
+        found = true;
+        break;
+      }
+    }
+    if (!found) {
+      return false;
+    }
+  }
+
+  for (int ct : var_idx_ct) {
+    if (ct == 0) {
+      *axis_missing = true;
+    } else if (ct > 1) {
+      *axis_duplicated = true;
+    }
+  }
+  for (size_t i = 1; i < index_to_var_idx.size(); ++i) {
+    if (index_to_var_idx[i] < index_to_var_idx[i - 1]) {
+      *same_order = false;
+      break;
+    }
+  }
+
+  return true;
+}
+
+// Insert every VarNode that appears in an expr into `vars`
+void GatherVars(const PrimExpr& expr, std::unordered_set<const VarNode*>* vars) {
+  auto collect = [&vars](const ObjectRef& obj) {
+    const VarNode* var = obj.as<VarNode>();
+    if (var != nullptr) {
+      vars->insert(var);
+    }
+  };
+  PostOrderVisit(expr, collect);
+}
+
+// Check whether an expr has expensive operations (e.g. exp).
+// Currently only a call to the op named "tir.exp" is treated as expensive.
+bool HasExpensiveOp(const PrimExpr& expr) {
+  bool found = false;
+  PostOrderVisit(expr, [&found](const ObjectRef& node) {
+    if (const CallNode* call = node.as<CallNode>()) {
+      // `as<OpNode>()` yields nullptr when the callee is not an Op, so it must be
+      // checked before reading `name`.
+      const OpNode* callee = call->op.as<OpNode>();
+      if (callee != nullptr && callee->name == "tir.exp") {
+        found = true;
+      }
+    }
+  });
+  return found;
+}
+
+// Build the access analysis for a compute declaration:
+// 1. sort the ops topologically,
+// 2. record which op reads which (read_from / read_by) and the number of common outer
+// iterators between each consumer/producer pair,
+// 3. derive the per-op flags (simple access, strict inlineable, multi-level tiling, output).
+AccessAnalyzer::AccessAnalyzer(const Array<te::Tensor>& tensors) {
+ auto node = make_object<AccessAnalyzerNode>();
+ OperationMap<bool> has_branch;
+
+ // Get all ops in topological order
+ node->ops_topo_order = TopoSortOps(tensors);
+
+ arith::Analyzer analyzer;
+
+ // Build read & write access map
+ for (const auto& op : node->ops_topo_order) {
+ if (op->IsInstance<te::PlaceholderOpNode>()) {
+ // A placeholder reads from nothing
+ node->read_from[op] = OperationMap<std::vector<std::vector<PrimExpr>>>();
+ } else if (auto cop = op.as<te::ComputeOpNode>()) {
+ ReadAccessExtractor extractor;
+ for (const auto& exp : cop->body) {
+ extractor.Extract(exp);
+ }
+
+ // read_by and read_from map
+ // (read_by must be filled before read_access is moved into read_from below)
+ for (const auto& iter : extractor.read_access) {
+ std::vector<std::vector<PrimExpr>>& accesses = node->read_by[iter.first][op];
+ accesses.insert(accesses.begin(), iter.second.begin(), iter.second.end());
+ }
+
+ node->read_from[op] = std::move(extractor.read_access);
+ has_branch[op] = extractor.has_branch;
+
+ // compute number of common outer iterators
+ for (const auto& pair : node->read_from[op]) {
+ const te::Operation& producer = pair.first;
+ const std::vector<std::vector<PrimExpr>>& access_list = pair.second;
+ const Array<PrimExpr>& output_shape = op->output_shape(0);
+ const Array<PrimExpr>& producer_shape = producer->output_shape(0);
+
+ // n_common counts the leading dimensions for which the two shapes are equal and
+ // every access indexes the producer with the consumer's axis var (plus an optional
+ // constant shift); counting stops at the first dimension that fails either test.
+ int n_common;
+ for (n_common = 0;
+ n_common < static_cast<int>(std::min(output_shape.size(), producer_shape.size()));
+ n_common++) {
+ if (!is_zero(analyzer.Simplify(output_shape[n_common] - producer_shape[n_common]))) {
+ break;
+ }
+
+ bool injective = true;
+ for (const auto& access : access_list) {
+ if (!IsConstShiftEqual(cop->axis[n_common]->var, access[n_common])) {
+ injective = false;
+ break;
+ }
+ }
+
+ if (!injective) {
+ break;
+ }
+ }
+
+ // Stored in both directions so that either op can be used as the first key
+ node->num_common_outer_iterators[op][producer] = n_common;
+ node->num_common_outer_iterators[producer][op] = n_common;
+ }
+ } else {
+ LOG(FATAL) << "Invalid op: " << op;
+ }
+ }
+
+ // Do some static analysis on ComputeOps
+ for (const auto& op : node->ops_topo_order) {
+ if (op->IsInstance<te::PlaceholderOpNode>()) {
+ // Fixed flags for placeholders
+ node->is_simple_access[op] = true;
+ node->needs_multi_level_tiling[op] = false;
+ node->is_strict_inlineable[op] = false;
+ node->is_output[op] = false;
+ } else if (auto cop = op.as<te::ComputeOpNode>()) {
+ // check whether this op is element-wise and strict-inlineable
+ bool is_simple_access = true;
+ bool is_strict_inlineable = true;
+
+ // Only read after IsSimpleAccess returned true (the false path breaks out first)
+ bool axis_missing, axis_duplicated, same_order;
+ for (const auto& pair : node->read_from[op]) {
+ const std::vector<std::vector<PrimExpr>>& access_list = pair.second;
+ for (const auto& access : access_list) {
+ if (!auto_scheduler::IsSimpleAccess(op, access, &axis_missing, &axis_duplicated,
+ &same_order)) {
+ is_simple_access = false;
+ is_strict_inlineable = false;
+ break;
+ }
+ if (!same_order || axis_duplicated) {
+ // do not strictly inline transpose
+ is_strict_inlineable = false;
+ }
+ }
+ if (!is_simple_access) {
+ break;
+ }
+ }
+
+ // don't strictly inline expensive op (e.g. exp)
+ bool has_expensive_op = false;
+ for (const auto& expr : cop->body) {
+ has_expensive_op |= HasExpensiveOp(expr);
+ }
+ if (has_expensive_op || has_branch[op]) {
+ is_strict_inlineable = false;
+ }
+
+ node->is_simple_access[op] = is_simple_access;
+ node->is_strict_inlineable[op] = is_strict_inlineable;
+
+ // check whether the op needs multi-level tiling
+ bool needs_multi_level_tiling = false;
+ int n_missing = 0;
+
+ for (const auto& pair : node->read_from[op]) {
+ const std::vector<std::vector<PrimExpr>>& access_list = pair.second;
+ // All variables used in any access to this producer
+ std::unordered_set<const VarNode*> vars;
+ for (const std::vector<PrimExpr>& access : access_list) {
+ for (const PrimExpr& expr : access) {
+ GatherVars(expr, &vars);
+ }
+ }
+
+ // A producer contributes at most one to n_missing (hence the break): it counts when
+ // some axis with extent > 1 is not used in any access to that producer.
+ for (const auto& axis : cop->axis) {
+ if (GetIntImm(axis->dom->extent) > 1 && vars.count(axis->var.get()) == 0) {
+ n_missing++;
+ break;
+ }
+ }
+
+ // Two such producers, or one together with a reduction axis, trigger tiling
+ if (n_missing >= 2 || (n_missing >= 1 && !cop->reduce_axis.empty())) {
+ needs_multi_level_tiling = true;
+ break;
+ }
+ }
+
+ node->needs_multi_level_tiling[op] = needs_multi_level_tiling;
+
+ // check whether the op is output
+ // (operator[] also creates an empty read_by entry for ops that nobody reads)
+ node->is_output[op] = node->read_by[op].empty();
+ } else {
+ LOG(FATAL) << "Invalid op" << op;
+ }
+ }
+
+ data_ = std::move(node);
+}
+
+bool AccessAnalyzer::NeedsMultiLevelTiling(const te::Operation& op) const {
+  // Look up the flag stored for `op` by the AccessAnalyzer constructor
+  const auto* self = operator->();
+  return self->needs_multi_level_tiling.at(op);
+}
+
+bool AccessAnalyzer::IsOutput(const te::Operation& op) const {
+  // Look up the flag stored for `op` by the AccessAnalyzer constructor
+  const auto* self = operator->();
+  return self->is_output.at(op);
+}
+
+bool AccessAnalyzer::IsSimpleAccess(const te::Operation& op) const {
+  // Look up the flag stored for `op` by the AccessAnalyzer constructor
+  const auto* self = operator->();
+  return self->is_simple_access.at(op);
+}
+
+bool AccessAnalyzer::IsStrictInlineable(const te::Operation& op) const {
+  // Look up the flag stored for `op` by the AccessAnalyzer constructor
+  const auto* self = operator->();
+  return self->is_strict_inlineable.at(op);
+}
+
+// Return the consumers of `op` under `state`. A reader whose stage is inlined in `state`
+// is not reported itself; its own readers are searched instead.
+OperationSet AccessAnalyzer::GetConsumers(const State& state, const te::Operation& op) const {
+  // Operations whose stage is marked as inlined in `state`
+  OperationSet inlined;
+  for (const auto& stage : state->stages) {
+    if (stage->compute_at == ComputeAtKind::kInlined) {
+      inlined.insert(stage->op);
+    }
+  }
+
+  OperationSet consumers;
+  std::function<void(const te::Operation&)> visit;
+  visit = [this, &visit, &inlined, &consumers](const te::Operation& cur) {
+    for (const auto& reader : operator->()->read_by.at(cur)) {
+      if (inlined.count(reader.first) == 0) {
+        consumers.insert(reader.first);
+      } else {
+        // Look through the inlined reader
+        visit(reader.first);
+      }
+    }
+  };
+
+  visit(op);
+  return consumers;
+}
+
+// Return the operations that `op` reads from directly, i.e. the keys of read_from[op]
+OperationSet AccessAnalyzer::GetDirectProducers(const te::Operation& op) const {
+  const auto& sources = operator->()->read_from.at(op);
+  OperationSet producers;
+  for (const auto& entry : sources) {
+    producers.insert(entry.first);
+  }
+  return producers;
+}
+
+// Return the producers of `op` under `state`. A producer whose stage is inlined in `state`
+// is not reported itself; its own producers are searched instead.
+OperationSet AccessAnalyzer::GetProducers(const State& state, const te::Operation& op) const {
+  // Operations whose stage is marked as inlined in `state`
+  OperationSet inlined;
+  for (const auto& stage : state->stages) {
+    if (stage->compute_at == ComputeAtKind::kInlined) {
+      inlined.insert(stage->op);
+    }
+  }
+
+  OperationSet producers;
+  std::function<void(const te::Operation&)> visit;
+  visit = [this, &visit, &inlined, &producers](const te::Operation& cur) {
+    for (const auto& source : operator->()->read_from.at(cur)) {
+      if (inlined.count(source.first) == 0) {
+        producers.insert(source.first);
+      } else {
+        // Look through the inlined producer
+        visit(source.first);
+      }
+    }
+  };
+
+  visit(op);
+  return producers;
+}
+
+// Return the number of common outer iterators between `op` and `target_op`: the minimum,
+// over every read_by path from `op` to `target_op`, of the per-edge counts along the path
+// (starting from the rank of op's first output). Returns 0 if `target_op` is never reached.
+int AccessAnalyzer::GetNumCommonOuterIterator(const te::Operation& op,
+                                              const te::Operation& target_op) const {
+  int best = INT32_MAX;
+  bool reached = false;
+
+  std::function<void(const te::Operation&, int)> walk;
+  walk = [this, &walk, &target_op, &best, &reached](const te::Operation& cur_op, int cur_num) {
+    if (cur_op == target_op) {
+      reached = true;
+      best = std::min(best, cur_num);
+      return;
+    }
+
+    for (const auto& reader : operator->()->read_by.at(cur_op)) {
+      const int shared = operator->()->num_common_outer_iterators.at(cur_op).at(reader.first);
+      walk(reader.first, std::min(cur_num, shared));
+    }
+  };
+
+  walk(op, op->output_shape(0).size());
+  return reached ? best : 0;
+}
+
+// Return whether `op` and `target_op` are connected by a chain of element-wise reads.
+// Starting from `op`, every op on the chain must have exactly one reader, all outputs of the
+// op and of its reader must have the same (constant) shape, and every read must be a simple
+// access with no missing, duplicated or reordered axis.
+bool AccessAnalyzer::ElementWiseMatch(const te::Operation& op,
+                                      const te::Operation& target_op) const {
+  te::Operation cur_op = op;
+  while (cur_op != target_op) {
+    const auto& readers = operator->()->read_by.at(cur_op);
+
+    // The chain has to be linear: exactly one reader
+    if (readers.size() != 1) {
+      return false;
+    }
+    te::Operation next_op = readers.begin()->first;
+
+    // Check condition 1: They have the same output size
+    auto p_cur = cur_op.as<te::ComputeOpNode>();
+    auto p_next = next_op.as<te::ComputeOpNode>();
+    if (p_cur == nullptr || p_next == nullptr) {
+      return false;
+    }
+
+    Array<PrimExpr> output_shape = p_cur->output_shape(0);
+    for (int i = 1; i < p_cur->num_outputs(); ++i) {
+      if (!IntArrayEqual(p_cur->output_shape(i), output_shape)) {
+        return false;
+      }
+    }
+    for (int i = 0; i < p_next->num_outputs(); ++i) {
+      if (!IntArrayEqual(p_next->output_shape(i), output_shape)) {
+        return false;
+      }
+    }
+
+    // Check condition 2: The read is elementwise
+    // Bind by reference: the access list lives in read_by, which is not modified here, so
+    // copying it on every hop of the chain is unnecessary.
+    const std::vector<std::vector<PrimExpr>>& reads = readers.begin()->second;
+    bool axis_missing, axis_duplicated, same_order;
+    for (const auto& read : reads) {
+      const bool is_simple_access = auto_scheduler::IsSimpleAccess(
+          next_op, read, &axis_missing, &axis_duplicated, &same_order);
+      // The flags are only inspected when the access is simple (short-circuit)
+      if (!is_simple_access || axis_missing || axis_duplicated || !same_order) {
+        return false;
+      }
+    }
+
+    cur_op = std::move(next_op);
+  }
+  return true;
+}
+
+// Estimate the number of float operations in an expression
class FlopEstimator : public ExprFunctor<double(const PrimExpr& n)> {
public:
double EstimateFlop(const Array<te::Operation>& ops) {
fail_ = true;
break;
}
+ cur_type_code_ = pop->output_dtype(0).code();
double op_per_element = 0;
for (const auto& x : pop->body) {
op_per_element += VisitExpr(x);
std::max(VisitExpr(op->true_value), VisitExpr(op->false_value));
}
-#define VisitBinary(Node) \
- double VisitExpr_(const Node* op) final { return 1.0 + VisitExpr(op->a) + VisitExpr(op->b); }
-#define VisitUnary(Node) \
- double VisitExpr_(const Node* op) final { return 1.0 + VisitExpr(op->a); }
+#define VisitBinary(Node) \
+ double VisitExpr_(const Node* op) final { \
+ double base = op->dtype.code() == cur_type_code_ ? 1.0 : 0.0; \
+ return base + VisitExpr(op->a) + VisitExpr(op->b); \
+ }
+
+#define VisitUnary(Node) \
+ double VisitExpr_(const Node* op) final { \
+ double base = op->dtype.code() == cur_type_code_ ? 1.0 : 0.0; \
+ return base + VisitExpr(op->a); \
+ }
VisitBinary(AddNode);
VisitBinary(SubNode);
private:
bool fail_{false};
+ int cur_type_code_;
};
ComputeDAG::ComputeDAG(Array<te::Tensor> tensors) {
auto node = make_object<ComputeDAGNode>();
node->tensors = std::move(tensors);
- node->ops = TopoSortOps(node->tensors);
+ node->access_analyzer = AccessAnalyzer(node->tensors);
+ node->ops = node->access_analyzer->ops_topo_order;
node->flop_ct = FlopEstimator().EstimateFlop(node->ops);
node->init_state = State(node->ops);
data_ = std::move(node);
+++ /dev/null
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file auto_scheduler/compute_dag.h
- * \brief The TVM Auto-scheduler computational graph and related program analyses.
- *
- * We convert a compute declaration described by `tvm.compute` (could be a single operator or a
- * subgraph) to a ComputeDAG. It keeps the input/output tensors of the compute declaration,
- * a list of all operations in the DAG as well as static analysis results for the DAG (e.g. the
- * total float operation count, consumer/producer relations of each operation stage, whether an
- * operation stage should be tiled/compute inlined ...). These analyses can help the search policy
- * to make decisions during search process.
- * ComputeDAG is also responsible for the interaction between TVM Auto-scheduler `LoopState` and
- * TVM schedule (e.g. applying the `LoopState` transform steps to TVM schedule, providing
- * `LoopState` with extra information got from TVM schedule ...).
- */
-
-#ifndef TVM_AUTO_SCHEDULER_COMPUTE_DAG_H_
-#define TVM_AUTO_SCHEDULER_COMPUTE_DAG_H_
-
-#include <tvm/te/schedule.h>
-
-#include <utility>
-
-#include "loop_state.h"
-
-namespace tvm {
-namespace auto_scheduler {
-
-/*! \brief The TVM Auto-scheduler computational graph and related program analyses. */
-class ComputeDAGNode : public Object {
- public:
- /*!
- * \brief Input and output tensors.
- * This is used as the input of `tvm.lower` or `tvm.build`.
- */
- Array<te::Tensor> tensors;
- /*! \brief All related operations in topo order. */
- Array<te::Operation> ops;
- /*! \brief Number of total float operations for this ComputeDAG. */
- double flop_ct;
- /*! \brief The initial state without any transform steps. */
- State init_state;
- // TODO(merrymercy): Add more analyses later.
-
- void VisitAttrs(tvm::AttrVisitor* v) {
- v->Visit("tensors", &tensors);
- v->Visit("ops", &ops);
- v->Visit("flop_ct", &flop_ct);
- v->Visit("init_state", &init_state);
- }
-
- static constexpr const char* _type_key = "auto_scheduler.ComputeDAG";
- TVM_DECLARE_FINAL_OBJECT_INFO(ComputeDAGNode, Object);
-};
-
-/*!
- * \brief Managed reference to ComputeDAGNode.
- * \sa ComputeDAGNode
- */
-class ComputeDAG : public ObjectRef {
- public:
- /*! \brief The constructor.
- * \param tensors `te::Tensor`s for a compute declaration.
- */
- explicit ComputeDAG(Array<te::Tensor> tensors);
-
- /*!
- * \brief Apply the history transform steps from a State to get a TVM schedule.
- * \param transform_steps Transform steps of a state.
- * \param stages A pointer to a `te::Stage` Array, default to be nullptr.
- * Pass a valid pointer if these information needs to be used outside this function.
- * \param stage_to_axes A pointer to a StageToAxesMap, default to be nullptr.
- * Pass a valid pointer if these information needs to be used outside this function.
- * \return A `te.schedule` and the a list of `te.Tensor` to be used in `tvm.lower` or `tvm.build`.
- */
- std::pair<te::Schedule, Array<te::Tensor>> ApplySteps(
- const Array<Step>& transform_steps, Array<te::Stage>* stages = nullptr,
- StageToAxesMap* stage_to_axes = nullptr) const;
-
- /*!
- * \brief Print transform steps as equivalent python schedule API.
- * This can be used for debugging.
- * \param transform_steps Transform steps of a state.
- * \return The Python schedule code.
- */
- String PrintStepsAsPython(const Array<Step>& transform_steps) const;
-
- /*!
- * \brief Fill the correct bound information for a given state by calling ir_pass::InferBound.
- * The states can lose complete bound information after some transform steps (e.g., compute_at).
- * We can call this function to infer and fill all the bound information.
- * This function calls TVM InferBound pass internally to get the bound.
- * The returned state of this function is guaranteed to have complete iterator extent information.
- * \param state The state to.
- * \return The State after inferbound.
- */
- State InferBound(const State& state) const;
-
- TVM_DEFINE_OBJECT_REF_METHODS(ComputeDAG, ObjectRef, ComputeDAGNode);
- TVM_DEFINE_OBJECT_REF_COW_METHOD(ComputeDAGNode);
-};
-
-} // namespace auto_scheduler
-} // namespace tvm
-
-#endif // TVM_AUTO_SCHEDULER_COMPUTE_DAG_H_
* see auto_scheduler/loop_state.h for more explanation.
*/
-#include "loop_state.h"
-
+#include <tvm/auto_scheduler/loop_state.h>
+#include <tvm/auto_scheduler/transform_step.h>
#include <tvm/runtime/registry.h>
#include <tvm/te/operation.h>
#include <utility>
-#include "transform_step.h"
#include "utils.h"
namespace tvm {
* \brief Distributed measurement infrastructure to measure the runtime costs of tensor programs.
*/
-#include "measure.h"
-
+#include <tvm/auto_scheduler/measure.h>
#include <tvm/runtime/registry.h>
#include <algorithm>
* \brief Json serialization format for dumping and loading tuning records.
*/
-#include "measure_record.h"
-
#include <dmlc/json.h>
+#include <tvm/auto_scheduler/loop_state.h>
+#include <tvm/auto_scheduler/measure_record.h>
+#include <tvm/auto_scheduler/transform_step.h>
#include <tvm/runtime/registry.h>
#include <fstream>
#include <utility>
#include <vector>
-#include "loop_state.h"
-#include "transform_step.h"
#include "utils.h"
// Json serialization handler for MeasureInput, MeasureResult
#include "empty_policy.h"
+#include <tvm/auto_scheduler/measure.h>
#include <tvm/runtime/registry.h>
-#include "../measure.h"
-
namespace tvm {
namespace auto_scheduler {
#ifndef TVM_AUTO_SCHEDULER_SEARCH_POLICY_EMPTY_POLICY_H_
#define TVM_AUTO_SCHEDULER_SEARCH_POLICY_EMPTY_POLICY_H_
-#include "../loop_state.h"
-#include "search_policy.h"
+#include <tvm/auto_scheduler/loop_state.h>
+#include <tvm/auto_scheduler/search_policy.h>
namespace tvm {
namespace auto_scheduler {
* \brief The base class of search policies.
*/
-#include "search_policy.h"
-
+#include <tvm/auto_scheduler/search_policy.h>
#include <tvm/runtime/registry.h>
namespace tvm {
* \brief Meta information and hardware parameters for a search task.
*/
-#include "search_task.h"
-
+#include <tvm/auto_scheduler/search_task.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/threading_backend.h>
/*!
* \file auto_scheduler/transform_step.cc
- * \brief Transformation steps. For each schedule primitive, there is a corresponding transform
- * step.
+ * \brief Transformation steps. These steps are used to manipulate the LoopState.
+ * They are similar to the schedule primitives in te::Stage.
*/
-#include "transform_step.h"
-
+#include <tvm/auto_scheduler/loop_state.h>
+#include <tvm/auto_scheduler/transform_step.h>
#include <tvm/runtime/registry.h>
#include <tvm/te/operation.h>
#include <utility>
#include <vector>
-#include "loop_state.h"
#include "utils.h"
namespace tvm {
}
void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag) {
+ // We need this runtime dispatcher because different steps have different function signatures
if (auto ps = step.as<AnnotationStepNode>()) {
ps->ApplyToState(state);
} else if (auto ps = step.as<FuseStepNode>()) {
void StepApplyToSchedule(const Step& step, Array<te::Stage>* stages,
StageToAxesMap* stage_to_axes) {
+ // We need this runtime dispatcher because different steps have different function signatures
if (auto ps = step.as<AnnotationStepNode>()) {
ps->ApplyToSchedule(stages, stage_to_axes);
} else if (auto ps = step.as<FuseStepNode>()) {
String StepPrintAsPythonAPI(const Step& step, Array<te::Stage>* stages,
StageToAxesMap* stage_to_axes) {
+ // We need this runtime dispatcher because different steps have different function signatures
if (auto ps = step.as<AnnotationStepNode>()) {
return ps->PrintAsPythonAPI(stages, stage_to_axes);
} else if (auto ps = step.as<FuseStepNode>()) {
return "";
}
-/********** Primitives working on single stage **********/
+/********** Steps working on single stage **********/
/********** Annotation **********/
AnnotationStep::AnnotationStep(int stage_id, int iter_id, IteratorAnnotation ann) {
return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, lengths, inner_to_outer);
}
-/********** Primitives working on multiple stages **********/
+/********** Steps working on multiple stages **********/
/********** Compute At **********/
ComputeAtStep::ComputeAtStep(int stage_id, int target_stage_id, int target_iter_id) {
return out;
}
+/*!
+ * \brief Return whether two int arrays have the same length and the same values.
+ * Every element of both arrays is CHECKed to be an IntImm.
+ */
+inline bool IntArrayEqual(const Array<PrimExpr>& arr1, const Array<PrimExpr>& arr2) {
+  const size_t length = arr1.size();
+  if (length != arr2.size()) {
+    return false;
+  }
+
+  for (size_t idx = 0; idx < length; ++idx) {
+    const IntImmNode* lhs = arr1[idx].as<IntImmNode>();
+    const IntImmNode* rhs = arr2[idx].as<IntImmNode>();
+    CHECK(lhs != nullptr);
+    CHECK(rhs != nullptr);
+    if (lhs->value != rhs->value) {
+      return false;
+    }
+  }
+  return true;
+}
+
/********** Utilities for TVM Containers / ByteArray **********/
/*! \brief Compute mean of a FloatImm array */
inline double FloatArrayMean(const Array<PrimExpr>& float_array) {
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <dmlc/logging.h>
+#include <gtest/gtest.h>
+#include <topi/nn.h>
+#include <tvm/auto_scheduler/compute_dag.h>
+#include <tvm/runtime/container.h>
+#include <tvm/te/operation.h>
+
+#include <set>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+// Compute declaration for test
+// Builds conv2d (NCHW) -> bias add -> batch-norm multiply -> batch-norm add -> relu and
+// returns the five placeholders followed by the final output tensor.
+// NOTE(review): `dilation` only enters the expected OH/OW computation below; it is not
+// passed to topi::conv2d_nchw in this function.
+tvm::Array<tvm::te::Tensor> conv2d_nchw_bn_relu_func(int N, int H, int W, int CI, int CO,
+ int kernel_size, int strides, int padding,
+ int dilation = 1) {
+ using namespace tvm;
+ using namespace tvm::te;
+
+ Tensor data = placeholder({N, CI, H, W}, DataType::Float(32), "Data");
+ Tensor kernel = placeholder({CO, CI, kernel_size, kernel_size}, DataType::Float(32), "Kernel");
+ Tensor bias = placeholder({CO, 1, 1}, DataType::Float(32), "Bias");
+ Tensor bn_scale = placeholder({CO, 1, 1}, DataType::Float(32), "Bn_scale");
+ Tensor bn_offset = placeholder({CO, 1, 1}, DataType::Float(32), "Bn_offset");
+
+ // Expected spatial size of the convolution output
+ int OH = (H + 2 * padding - (kernel_size - 1) * dilation - 1) / strides + 1;
+ int OW = (W + 2 * padding - (kernel_size - 1) * dilation - 1) / strides + 1;
+
+ const auto& conv = topi::conv2d_nchw(data, kernel, padding, padding, strides, strides);
+ // Sanity check: the shape produced by topi must agree with OH/OW computed above
+ CHECK(conv->shape[2].as<IntImmNode>()->value == OH);
+ CHECK(conv->shape[3].as<IntImmNode>()->value == OW);
+
+ // The three stages below index the per-channel parameter with the channel axis `j` only
+ const auto& bias_add = compute(
+ {N, CO, OH, OW}, [&](Var i, Var j, Var k, Var l) { return conv[i][j][k][l] + bias[j][0][0]; },
+ "Bias_add");
+ const auto& bn_mul = compute(
+ {N, CO, OH, OW},
+ [&](Var i, Var j, Var k, Var l) { return bias_add[i][j][k][l] * bn_scale[j][0][0]; },
+ "Bn_mul");
+ const auto& bn_add = compute(
+ {N, CO, OH, OW},
+ [&](Var i, Var j, Var k, Var l) { return bn_mul[i][j][k][l] + bn_offset[j][0][0]; },
+ "Bn_add");
+ const auto& out = topi::relu<float>(bn_add);
+
+ return {data, kernel, bias, bn_scale, bn_offset, out};
+}
+
+using namespace tvm::auto_scheduler;
+
+// Test Access Analyzer
+// Checks the per-op flags, the common-outer-iterator counts, ElementWiseMatch and the
+// consumer/producer queries on the conv2d + bn + relu declaration, before and after
+// inlining some stages.
+TEST(ComputeDAG, AccessAnalyzer) {
+ const auto& tensors = conv2d_nchw_bn_relu_func(1, 224, 224, 3, 64, 7, 2, 3);
+ const auto& dag = tvm::auto_scheduler::ComputeDAG(tensors);
+ State s0 = dag->init_state;
+
+ // Expected stage ids; they are used as indices into dag->ops and s0->stages below
+ int data = 0, padding = 1, kernel = 2, conv = 3, bias = 4, bias_add = 5;
+ int bn_scale = 6, bn_mul = 7, bn_offset = 8, bn_add = 9, relu = 10;
+
+ // Each of the four loops below checks one flag for every op: ops listed in the set must
+ // have the flag set, all other ops must not.
+ std::set<int> needs_multi_level_tiling = {conv};
+ for (size_t stage_id = 0; stage_id < dag->ops.size(); stage_id++) {
+ if (needs_multi_level_tiling.count(stage_id)) {
+ CHECK(dag->access_analyzer.NeedsMultiLevelTiling(dag->ops[stage_id]));
+ } else {
+ CHECK(!dag->access_analyzer.NeedsMultiLevelTiling(dag->ops[stage_id]));
+ }
+ }
+
+ std::set<int> is_simple_access = {data, padding, kernel, bias, bias_add,
+ bn_scale, bn_mul, bn_offset, bn_add, relu};
+ for (size_t stage_id = 0; stage_id < dag->ops.size(); stage_id++) {
+ if (is_simple_access.count(stage_id)) {
+ CHECK(dag->access_analyzer.IsSimpleAccess(dag->ops[stage_id]));
+ } else {
+ CHECK(!dag->access_analyzer.IsSimpleAccess(dag->ops[stage_id]));
+ }
+ }
+
+ std::set<int> is_strictly_inlinable = {bias_add, bn_mul, bn_add, relu};
+ for (size_t stage_id = 0; stage_id < dag->ops.size(); stage_id++) {
+ if (is_strictly_inlinable.count(stage_id)) {
+ CHECK(dag->access_analyzer.IsStrictInlineable(dag->ops[stage_id]));
+ } else {
+ CHECK(!dag->access_analyzer.IsStrictInlineable(dag->ops[stage_id]));
+ }
+ }
+
+ std::set<int> is_output = {relu};
+ for (size_t stage_id = 0; stage_id < dag->ops.size(); stage_id++) {
+ if (is_output.count(stage_id)) {
+ CHECK(dag->access_analyzer.IsOutput(dag->ops[stage_id]));
+ } else {
+ CHECK(!dag->access_analyzer.IsOutput(dag->ops[stage_id]));
+ }
+ }
+
+ CHECK_EQ(dag->access_analyzer.GetNumCommonOuterIterator(dag->ops[conv], dag->ops[bias_add]), 4);
+ CHECK_EQ(dag->access_analyzer.GetNumCommonOuterIterator(dag->ops[conv], dag->ops[relu]), 4);
+ CHECK_EQ(dag->access_analyzer.GetNumCommonOuterIterator(dag->ops[data], dag->ops[relu]), 1);
+
+ CHECK(dag->access_analyzer.ElementWiseMatch(dag->ops[conv], dag->ops[bias_add]));
+ CHECK(dag->access_analyzer.ElementWiseMatch(dag->ops[conv], dag->ops[relu]));
+ CHECK(!dag->access_analyzer.ElementWiseMatch(dag->ops[data], dag->ops[padding]));
+
+ std::unordered_set<tvm::te::Operation, tvm::ObjectHash, tvm::ObjectEqual> op_set;
+ // Phase 1: nothing has been inlined yet
+ {
+ // {stage, its single expected consumer}
+ std::vector<std::pair<int, int>> consumer_list = {
+ {data, padding}, {padding, conv}, {kernel, conv}, {conv, bias_add},
+ {bias, bias_add}, {bias_add, bn_mul}, {bn_scale, bn_mul}, {bn_mul, bn_add},
+ {bn_offset, bn_add}, {bn_add, relu}};
+ for (const auto& pair : consumer_list) {
+ op_set = dag->access_analyzer.GetConsumers(s0, s0->stages[pair.first]->op);
+ CHECK_EQ(op_set.size(), 1);
+ CHECK_EQ((*op_set.begin()), s0->stages[pair.second]->op);
+ }
+ // {stage, all of its expected producers}
+ std::vector<std::pair<int, std::vector<int>>> producer_list = {{padding, {data}},
+ {conv, {padding, kernel}},
+ {bias_add, {conv, bias}},
+ {bn_mul, {bias_add, bn_scale}},
+ {bn_add, {bn_mul, bn_offset}},
+ {relu, {bn_add}}};
+ for (const auto& pair : producer_list) {
+ op_set = dag->access_analyzer.GetProducers(s0, s0->stages[pair.first]->op);
+ CHECK_EQ(op_set.size(), pair.second.size());
+ for (const auto& target : pair.second) {
+ CHECK(op_set.count(s0->stages[target]->op));
+ }
+ }
+ }
+
+ // Phase 2: inline the three element-wise stages and the padding stage
+ s0.compute_inline(bn_add);
+ s0.compute_inline(bn_mul);
+ s0.compute_inline(bias_add);
+ s0.compute_inline(padding);
+ {
+ // GetConsumers is now expected to look through the inlined stages
+ std::vector<std::pair<int, int>> consumer_list = {{data, conv}, {kernel, conv}, {conv, relu}};
+ for (const auto& pair : consumer_list) {
+ op_set = dag->access_analyzer.GetConsumers(s0, s0->stages[pair.first]->op);
+ CHECK_EQ(op_set.size(), 1);
+ CHECK_EQ((*op_set.begin()), s0->stages[pair.second]->op);
+ }
+ // GetDirectProducers takes no state, so the expected list is the same as in phase 1
+ std::vector<std::pair<int, std::vector<int>>> producer_list = {{padding, {data}},
+ {conv, {padding, kernel}},
+ {bias_add, {conv, bias}},
+ {bn_mul, {bias_add, bn_scale}},
+ {bn_add, {bn_mul, bn_offset}},
+ {relu, {bn_add}}};
+ for (const auto& pair : producer_list) {
+ op_set = dag->access_analyzer.GetDirectProducers(s0->stages[pair.first]->op);
+ CHECK_EQ(op_set.size(), pair.second.size());
+ for (const auto& target : pair.second) {
+ CHECK(op_set.count(s0->stages[target]->op));
+ }
+ }
+ }
+}
+
+// Test entry point: initialize GoogleTest from the command line and run every test
+int main(int argc, char** argv) {
+  testing::InitGoogleTest(&argc, argv);
+  testing::FLAGS_gtest_death_test_style = "threadsafe";
+  const int status = RUN_ALL_TESTS();
+  return status;
+}
"""Test ComputeDAG (replay, infer bound)"""
-import tvm
+import tvm, topi
from tvm import auto_scheduler, te
-from test_auto_scheduler_common import get_tiled_matmul
+from test_auto_scheduler_common import get_tiled_matmul, matmul_auto_scheduler_test
def test_apply_steps():
def test_estimate_flop():
- dag, s = get_tiled_matmul()
- assert abs(dag.flop_ct - 2 * 512 ** 3) < 0.5
+ # plain matmul: the expected count is 2 * N^3
+ N = 512
+ A, B, C = matmul_auto_scheduler_test(N, N, N)
+ dag = auto_scheduler.ComputeDAG([A, B, C])
+ assert abs(dag.flop_ct - 2 * N ** 3) < 0.5
+
+ # relu on the N x N result is expected to add N * N to the count
+ D = topi.nn.relu(C)
+ dag = auto_scheduler.ComputeDAG([A, B, D])
+ assert abs(dag.flop_ct - 2 * N ** 3 - N * N) < 0.5
+
+ # should not count the comparison operations in padding
+ D = topi.nn.pad(C, [1, 1])
+ dag = auto_scheduler.ComputeDAG([A, B, D])
+ assert abs(dag.flop_ct - 2 * N ** 3) < 0.5
if __name__ == "__main__":