* [REFACTOR] introduce top - Tensor Operation DSL.
Historically we placed Tensor, Schedule and compute under the root tvm namespace.
This is no longer a good idea now that the project's scope has grown beyond
the tensor operation DSL.
This PR introduces top -- a namespace for tensor operation
DSL concepts such as schedule, tensor and compute.
We moved the related files into the new top subfolder.
* Move relevant files into include/tvm/top and src/top
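For downstream C++ code, the user-visible change is that these types now resolve
under tvm::top. The following is a minimal sketch of the migration, assuming the
moved headers land under include/tvm/top/ as described above (tvm/top/operation.h
is an assumed path; the placeholder/compute/create_schedule calls are the
pre-existing APIs visible in the deleted headers further down this diff, simply
requalified with the new namespace):

#include <tvm/top/operation.h>   // previously "tvm/operation.h" (assumed new path)
#include <tvm/top/schedule.h>    // previously "tvm/schedule.h"

void NamespaceMigrationSketch() {
  // Before this PR: tvm::Tensor, tvm::Schedule, tvm::compute, ...
  // After this PR: the same names live under tvm::top.
  tvm::top::Tensor A =
      tvm::top::placeholder({16, 16}, tvm::DataType::Float(32), "A");
  tvm::top::Tensor B = tvm::top::compute(
      {16, 16},
      [&](tvm::Var i, tvm::Var j) { return A(i, j) + 1; },
      "B");
  tvm::top::Schedule s = tvm::top::create_schedule({B->op});
  (void)s;
}
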
assign_source_group("Include" ${GROUP_INCLUDE})
# Source file lists
-file(GLOB COMPILER_SRCS
+file(GLOB_RECURSE COMPILER_SRCS
src/node/*.cc
src/ir/*.cc
src/target/*.cc
- src/api/*.cc
src/arith/*.cc
+ src/top/*.cc
+ src/api/*.cc
src/autotvm/*.cc
- src/codegen/*.cc
src/lang/*.cc
src/pass/*.cc
- src/op/*.cc
- src/schedule/*.cc
)
+file(GLOB CODEGEN_SRCS
+ src/codegen/*.cc
+ )
+
+list(APPEND COMPILER_SRCS ${CODEGEN_SRCS})
+
file(GLOB_RECURSE RELAY_OP_SRCS
src/relay/op/*.cc
)
namespace tvm {
// forward declare Tensor
+namespace top {
class Tensor;
+}
namespace arith {
/*!
* \param consider_provides If provides (write) are considered.
* \return The domain that covers all the calls or provides within the given statement.
*/
-Domain DomainTouched(Stmt body, const Tensor &tensor, bool consider_calls, bool consider_provides);
+Domain DomainTouched(Stmt body,
+ const top::Tensor &tensor,
+ bool consider_calls,
+ bool consider_provides);
} // namespace arith
} // namespace tvm
#include <tvm/target/target.h>
#include <tvm/support/with.h>
+#include <tvm/top/schedule_pass.h>
+
#include <string>
#include <vector>
#include <utility>
#include <unordered_set>
#include "runtime/packed_func.h"
-#include "schedule_pass.h"
+
#include "lowered_func.h"
namespace tvm {
* \param config The build configuration.
* \return The lowered function.
*/
-TVM_DLL Array<LoweredFunc> lower(Schedule sch,
- const Array<Tensor>& args,
+TVM_DLL Array<LoweredFunc> lower(top::Schedule sch,
+ const Array<top::Tensor>& args,
const std::string& name,
- const std::unordered_map<Tensor, Buffer>& binds,
+ const std::unordered_map<top::Tensor, Buffer>& binds,
const BuildConfig& config);
/*!
 * \brief Split host/device functions and run necessary passes before build
#ifndef TVM_IR_PASS_H_
#define TVM_IR_PASS_H_
+#include <tvm/top/schedule.h>
+
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include <string>
#include "expr.h"
#include "buffer.h"
-#include "schedule.h"
#include "lowered_func.h"
namespace tvm {
* \return Transformed stmt.
*/
Stmt StorageFlatten(Stmt stmt,
- Map<Tensor, Buffer> extern_buffer,
+ Map<top::Tensor, Buffer> extern_buffer,
int cache_line_size,
bool create_bound_attribute = false);
* \return Transformed stmt.
*/
Stmt RewriteForTensorCore(Stmt stmt,
- Schedule schedule,
- Map<Tensor, Buffer> extern_buffer);
+ top::Schedule schedule,
+ Map<top::Tensor, Buffer> extern_buffer);
/*!
* \brief Verify if there is any argument bound to compact buffer.
#ifndef TVM_LOWERED_FUNC_H_
#define TVM_LOWERED_FUNC_H_
+#include <tvm/top/tensor.h>
+
#include <string>
#include "expr.h"
-#include "tensor.h"
#include "tvm/node/container.h"
namespace tvm {
+++ /dev/null
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file tvm/operation.h
- * \brief Operation node can generate one or multiple Tensors
- */
-#ifndef TVM_OPERATION_H_
-#define TVM_OPERATION_H_
-
-#include <tvm/arith/analyzer.h>
-
-#include <string>
-#include <vector>
-#include <unordered_map>
-
-#include "expr.h"
-#include "expr_operator.h"
-#include "tensor.h"
-#include "schedule.h"
-#include "buffer.h"
-
-namespace tvm {
-
-using arith::IntSet;
-
-/*!
- * \brief Temporary data structure to store union
- * of bounds of each axis of Tensor.
- */
-struct TensorDom {
- // constructor
- explicit TensorDom(int ndim)
- : data(ndim) {}
- /*! \brief The domain data */
- std::vector<std::vector<IntSet> > data;
-};
-
-/*!
- * \brief Base class of all operation nodes
- */
-class OperationNode : public ir::FunctionBaseNode {
- public:
- /*! \brief optional name of the operation */
- std::string name;
- /*! \brief optional tag of the operation */
- std::string tag;
- /*! \brief additional attributes of the operation*/
- Map<std::string, ObjectRef> attrs;
- /*! \return name of the operation */
- const std::string& func_name() const final {
- return name;
- }
- /*!
- * \return The list of iteration variable at root
- * \note root_iter_vars decides the shape of the outputs.
- */
- virtual Array<IterVar> root_iter_vars() const = 0;
- /*!
- * \brief Get data type. i-th output tensor.
- * \param i The output index.
- * \return type of i-th output.
- */
- virtual DataType output_dtype(size_t i) const = 0;
- /*!
- * \brief Get shape of i-th output tensor.
- * \param i The output index.
- * \return shape of i-th output.
- */
- virtual Array<PrimExpr> output_shape(size_t i) const = 0;
- /*!
- * \brief List all the input Tensors.
- * \return List of input tensors.
- */
- virtual Array<Tensor> InputTensors() const = 0;
- /*!
- * \brief Replace the input of the operation by pattern specified by rmap.
- *
- * \param self The reference to self.
- * \param rmap The replacement map.
- * \return self if nothing is replaced, otherwise return replaced op.
- */
- virtual Operation ReplaceInputs(
- const Operation& self,
- const std::unordered_map<Tensor, Tensor>& rmap) const = 0;
- /*!
- * \brief Propagate the bounds to inputs
- * \param self The reference to self.
- * \param analyzer The analyzer to be used in the function.
- * \param dom_map the domain map of Variables(corresponds to root_iter_vars)
- * \param out_dom_map The output domain.
- * The function is only asked to fill the bounds for Tensors that
- * is already in the out_dom_map
- */
- virtual void PropBoundToInputs(
- const Operation& self,
- arith::Analyzer* analyzer,
- const std::unordered_map<const VarNode*, IntSet>& dom_map,
- std::unordered_map<Tensor, TensorDom>* out_dom_map) const = 0;
- /*!
- * \brief Gather the bound from output tensor.
- * Set the range of each root_iter_vars in the op to out_dom_map
- *
- * \param self The reference to self.
- * \param tensor_dom Domain map of Tensor->access set of each dimension.
- * \param out_dom_map The output domain map of each IterVar to be setted.
- */
- virtual void GatherBound(
- const Operation& self,
- const std::unordered_map<Tensor, TensorDom>& tensor_dom,
- std::unordered_map<IterVar, Range>* out_dom_map) const = 0;
- /*!
- * \brief Build the Realize statement that realizes
- * the op's output tensors.
- * \param stage the op's stage.
- * \param realize_map The realization domain map of the operators.
- * \param body The body that is going to get
- * \return A realization statement that wraps body.
- */
- virtual Stmt BuildRealize(
- const Stage& stage,
- const std::unordered_map<IterVar, Range>& realize_map,
- const Stmt& body) const = 0;
- /*!
- * \brief Build the statement that provide the output tensors.
- * \param stage The schedule stage of the op.
- * \param dom_map The domain map of all iteration domains.
- * \param debug_keep_trivial_loop Whether keep trivial loops with extent of 1
- * \return A statement that add production and wraps consumer.
- */
- virtual Stmt BuildProvide(
- const Stage& stage,
- const std::unordered_map<IterVar, Range>& dom_map,
- bool debug_keep_trivial_loop) const = 0;
-
- static constexpr const char* _type_key = "Operation";
-
- TVM_DECLARE_BASE_OBJECT_INFO(OperationNode, Object);
-};
-
-/*!
- * \brief A placeholder op represents an input placeholder.
- */
-class PlaceholderOpNode : public OperationNode {
- public:
- /*! \brief The shape of the input */
- Array<PrimExpr> shape;
- /*! \brief The data type of the input. */
- DataType dtype;
- // override behavior.
- int num_outputs() const final;
- Array<IterVar> root_iter_vars() const final;
- DataType output_dtype(size_t i) const final;
- Array<PrimExpr> output_shape(size_t i) const final;
- Array<Tensor> InputTensors() const final;
- Operation ReplaceInputs(
- const Operation& self,
- const std::unordered_map<Tensor, Tensor>& rmap) const final;
- void PropBoundToInputs(
- const Operation& self,
- arith::Analyzer* analyzer,
- const std::unordered_map<const VarNode*, IntSet>& dom_map,
- std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
- void GatherBound(
- const Operation& self,
- const std::unordered_map<Tensor, TensorDom>& tensor_dom,
- std::unordered_map<IterVar, Range>* out_dom_map) const final;
- Stmt BuildRealize(
- const Stage& stage,
- const std::unordered_map<IterVar, Range>& realize_map,
- const Stmt& body) const final;
- Stmt BuildProvide(
- const Stage& stage,
- const std::unordered_map<IterVar, Range>& dom_map,
- bool debug_keep_trivial_loop) const final;
-
- void VisitAttrs(AttrVisitor* v) {
- v->Visit("name", &name);
- v->Visit("tag", &tag);
- v->Visit("attrs", &attrs);
- v->Visit("shape", &shape);
- v->Visit("dtype", &dtype);
- }
- static Operation make(std::string name,
- Array<PrimExpr> shape,
- DataType dtype);
-
- static constexpr const char* _type_key = "PlaceholderOp";
- TVM_DECLARE_FINAL_OBJECT_INFO(PlaceholderOpNode, OperationNode);
-};
-
-/*!
- * \brief A Compute op that compute a tensor on certain domain.
- * This is the base class for ComputeOp (operating on a scalar at a time) and
- * TensorComputeOp (operating on a TensorSlice at a time)
- */
-class TVM_DLL BaseComputeOpNode : public OperationNode {
- public:
- /*! \brief IterVar on each axis */
- Array<IterVar> axis;
- /*! \brief IterVar on each reduction axis, if the body is a Reduce */
- Array<IterVar> reduce_axis;
- // override functions
- Array<IterVar> root_iter_vars() const final;
- Array<PrimExpr> output_shape(size_t idx) const final;
- void GatherBound(
- const Operation& self,
- const std::unordered_map<Tensor, TensorDom>& tensor_dom,
- std::unordered_map<IterVar, Range>* out_dom_map) const final;
- Stmt BuildRealize(
- const Stage& stage,
- const std::unordered_map<IterVar, Range>& realize_map,
- const Stmt& body) const final;
- virtual size_t num_schedulable_dims() const = 0;
-
- static constexpr const char* _type_key = "BaseComputeOp";
- TVM_DECLARE_BASE_OBJECT_INFO(BaseComputeOpNode, OperationNode);
-};
-
-
-/*!
- * \brief A Compute op that compute a tensor on certain domain.
- */
-class TVM_DLL ComputeOpNode : public BaseComputeOpNode {
- public:
- /*! \brief the compute expression */
- Array<PrimExpr> body;
- /*! \brief constructor */
- ComputeOpNode() {}
- // override functions
- int num_outputs() const final;
- DataType output_dtype(size_t i) const final;
- Array<Tensor> InputTensors() const final;
- Operation ReplaceInputs(
- const Operation& self,
- const std::unordered_map<Tensor, Tensor>& rmap) const final;
- void PropBoundToInputs(
- const Operation& self,
- arith::Analyzer* analyzer,
- const std::unordered_map<const VarNode*, IntSet>& dom_map,
- std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
- Stmt BuildProvide(
- const Stage& stage,
- const std::unordered_map<IterVar, Range>& dom_map,
- bool debug_keep_trivial_loop) const final;
- size_t num_schedulable_dims() const final;
-
- void VisitAttrs(AttrVisitor* v) {
- v->Visit("name", &name);
- v->Visit("tag", &tag);
- v->Visit("attrs", &attrs);
- v->Visit("axis", &axis);
- v->Visit("reduce_axis", &reduce_axis);
- v->Visit("body", &body);
- }
- static Operation make(std::string name,
- std::string tag,
- Map<std::string, ObjectRef> attrs,
- Array<IterVar> axis,
- Array<PrimExpr> body);
-
- static constexpr const char* _type_key = "ComputeOp";
- TVM_DECLARE_FINAL_OBJECT_INFO(ComputeOpNode, BaseComputeOpNode);
-};
-
-/*!
- * \brief A TenorCompute op that compute a tensor with an tensor intrinsic.
- */
-class TensorComputeOpNode : public BaseComputeOpNode {
- public:
- /*! \brief number of axes that can be scheduled */
- int schedulable_ndim;
- /*! \brief TensorIntrin used to compute */
- TensorIntrin intrin;
- /*! \brief input tensors of intrin */
- Array<Tensor> inputs;
- /*! \brief region of input tensors */
- Array<Region> input_regions;
- /*! \brief scalar expression inputs */
- Array<PrimExpr> scalar_inputs;
- /*! \brief constructor */
- TensorComputeOpNode() {}
- // override functions
- int num_outputs() const final;
- DataType output_dtype(size_t i) const final;
- Array<Tensor> InputTensors() const final;
- Operation ReplaceInputs(
- const Operation& self,
- const std::unordered_map<Tensor, Tensor>& rmap) const final;
- void PropBoundToInputs(
- const Operation& self,
- arith::Analyzer* analyzer,
- const std::unordered_map<const VarNode*, IntSet>& dom_map,
- std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
- Stmt BuildProvide(
- const Stage& stage,
- const std::unordered_map<IterVar, Range>& dom_map,
- bool debug_keep_trivial_loop) const final;
- size_t num_schedulable_dims() const final;
-
- void VisitAttrs(AttrVisitor* v) {
- v->Visit("name", &name);
- v->Visit("tag", &tag);
- v->Visit("axis", &axis);
- v->Visit("reduce_axis", &reduce_axis);
- v->Visit("schedulable_ndim", &schedulable_ndim);
- v->Visit("intrin", &intrin);
- v->Visit("inputs", &inputs);
- v->Visit("input_regions", &input_regions);
- v->Visit("scalar_inputs", &scalar_inputs);
- }
- static Operation make(std::string name,
- std::string tag,
- Array<IterVar> axis,
- Array<IterVar> reduce_axis,
- int schedulable_ndim,
- TensorIntrin intrin,
- Array<Tensor> tensors,
- Array<Region> regions,
- Array<PrimExpr> scalar_inputs);
-
- static constexpr const char* _type_key = "TensorComputeOp";
- TVM_DECLARE_FINAL_OBJECT_INFO(TensorComputeOpNode, BaseComputeOpNode);
-};
-
-/*!
- * \brief Symbolic scan.
- */
-class ScanOpNode : public OperationNode {
- public:
- /*! \brief IterVar to scan over */
- IterVar scan_axis;
- /*! \brief the initialization tensors */
- Array<Tensor> init;
- /*! \brief the update function represented by tensor */
- Array<Tensor> update;
- /*! \brief The placeholder to refer as states in update. */
- Array<Tensor> state_placeholder;
- /*!
- * \brief the inputs to the scan, these are optionally provided
- * But they can be helpful to provide hints to speedup get of scan body.
- */
- Array<Tensor> inputs;
- /*!
- * \brief Spatial axis to indicate spatial dimension of each output.
- * They corresponds to flattened spatial axis of the outputs.
- *
- * [output[0].axis[1], output[0].axis[2]... output[k].axis[j]...]
- * These are auxiliary data structure for storing result of bound inference.
- * They do not corresponds to splittable iterations, thus the name comes
- * with underscore.
- */
- Array<IterVar> spatial_axis_;
- /*! \brief constructor */
- ScanOpNode() {}
- // override behavior.
- int num_outputs() const final;
- Array<IterVar> root_iter_vars() const final;
- DataType output_dtype(size_t i) const final;
- Array<PrimExpr> output_shape(size_t i) const final;
- Array<Tensor> InputTensors() const final;
- Operation ReplaceInputs(
- const Operation& self,
- const std::unordered_map<Tensor, Tensor>& rmap) const final;
- void PropBoundToInputs(
- const Operation& self,
- arith::Analyzer* analyzer,
- const std::unordered_map<const VarNode*, IntSet>& dom_map,
- std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
- void GatherBound(
- const Operation& self,
- const std::unordered_map<Tensor, TensorDom>& tensor_dom,
- std::unordered_map<IterVar, Range>* out_dom_map) const final;
- Stmt BuildRealize(
- const Stage& stage,
- const std::unordered_map<IterVar, Range>& realize_map,
- const Stmt& body) const final;
- Stmt BuildProvide(
- const Stage& stage,
- const std::unordered_map<IterVar, Range>& dom_map,
- bool debug_keep_trivial_loop) const final;
-
- void VisitAttrs(AttrVisitor* v) {
- v->Visit("name", &name);
- v->Visit("tag", &tag);
- v->Visit("attrs", &attrs);
- v->Visit("scan_axis", &scan_axis);
- v->Visit("init", &init);
- v->Visit("update", &update);
- v->Visit("state_placeholder", &state_placeholder);
- v->Visit("inputs", &inputs);
- v->Visit("spatial_axis_", &spatial_axis_);
- }
- static Operation make(std::string name,
- std::string tag,
- Map<std::string, ObjectRef> attrs,
- IterVar axis,
- Array<Tensor> init,
- Array<Tensor> update,
- Array<Tensor> state_placeholder,
- Array<Tensor> input);
-
- static constexpr const char* _type_key = "ScanOp";
- TVM_DECLARE_FINAL_OBJECT_INFO(ScanOpNode, OperationNode);
-};
-
-/*!
- * \brief External computation that cannot be splitted.
- */
-class ExternOpNode : public OperationNode {
- public:
- /*! \brief The input tensors */
- Array<Tensor> inputs;
- /*! \brief Symbolic placeholder representation of inputs */
- Array<Buffer> input_placeholders;
- /*! \brief Symbolic placeholder representation of outputs */
- Array<Buffer> output_placeholders;
- /*! \brief the statement that generates the computation. */
- Stmt body;
-
- /*! \brief constructor */
- ExternOpNode() {}
- // override functions
- int num_outputs() const final;
- Array<IterVar> root_iter_vars() const final;
- DataType output_dtype(size_t i) const final;
- Array<PrimExpr> output_shape(size_t i) const final;
- Array<Tensor> InputTensors() const final;
- Operation ReplaceInputs(
- const Operation& self,
- const std::unordered_map<Tensor, Tensor>& rmap) const final;
- void PropBoundToInputs(
- const Operation& self,
- arith::Analyzer* analyzer,
- const std::unordered_map<const VarNode*, IntSet>& dom_map,
- std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
- void GatherBound(
- const Operation& self,
- const std::unordered_map<Tensor, TensorDom>& tensor_dom,
- std::unordered_map<IterVar, Range>* out_dom_map) const final;
- Stmt BuildRealize(
- const Stage& stage,
- const std::unordered_map<IterVar, Range>& realize_map,
- const Stmt& body) const final;
- Stmt BuildProvide(
- const Stage& stage,
- const std::unordered_map<IterVar, Range>& dom_map,
- bool debug_keep_trivial_loop) const final;
-
- void VisitAttrs(AttrVisitor* v) {
- v->Visit("name", &name);
- v->Visit("tag", &tag);
- v->Visit("attrs", &attrs);
- v->Visit("inputs", &inputs);
- v->Visit("input_placeholders", &input_placeholders);
- v->Visit("output_placeholders", &output_placeholders);
- v->Visit("body", &body);
- }
- TVM_DLL static Operation make(std::string name,
- std::string tag,
- Map<std::string, ObjectRef> attrs,
- Array<Tensor> inputs,
- Array<Buffer> input_placeholders,
- Array<Buffer> output_placeholders,
- Stmt body);
-
- static constexpr const char* _type_key = "ExternOp";
- TVM_DECLARE_FINAL_OBJECT_INFO(ExternOpNode, OperationNode);
-};
-
-/*!
- * \brief A computation operator that generated by hybrid script.
- */
-class HybridOpNode : public OperationNode {
- public:
- /*! \brief The input tensors */
- Array<Tensor> inputs;
- /*! \brief Symbolic placeholder representation of outputs */
- Array<Tensor> outputs;
- /*! \brief The axis of iterations */
- Array<IterVar> axis;
- /*! \brief the statement that generates the computation. This is
- * slightly different from the body in ExternOpNode. All the output
- * tensors keep its own name specified by users in the script.
- * However, when compilation, these tensors will be placed by those
- * actual output tensors. */
- Stmt body;
-
- /*! \brief constructor */
- HybridOpNode() {}
- // override functions
- int num_outputs() const final;
- Array<IterVar> root_iter_vars() const final;
- DataType output_dtype(size_t i) const final;
- Array<PrimExpr> output_shape(size_t i) const final;
- Array<Tensor> InputTensors() const final;
- Operation ReplaceInputs(
- const Operation& self,
- const std::unordered_map<Tensor, Tensor>& rmap) const final;
- void PropBoundToInputs(
- const Operation& self,
- arith::Analyzer* analyzer,
- const std::unordered_map<const VarNode*, IntSet>& dom_map,
- std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
- void GatherBound(
- const Operation& self,
- const std::unordered_map<Tensor, TensorDom>& tensor_dom,
- std::unordered_map<IterVar, Range>* out_dom_map) const final;
- Stmt BuildRealize(
- const Stage& stage,
- const std::unordered_map<IterVar, Range>& realize_map,
- const Stmt& body) const final;
- Stmt BuildProvide(
- const Stage& stage,
- const std::unordered_map<IterVar, Range>& dom_map,
- bool debug_keep_trivial_loop) const final;
-
- void VisitAttrs(AttrVisitor* v) {
- v->Visit("name", &name);
- v->Visit("tag", &tag);
- v->Visit("attrs", &attrs);
- v->Visit("inputs", &inputs);
- v->Visit("outputs", &outputs);
- v->Visit("axis", &axis);
- v->Visit("body", &body);
- }
- TVM_DLL static Operation make(std::string name,
- std::string tag,
- Map<std::string, ObjectRef> attrs,
- Array<Tensor> inputs,
- Array<Tensor> outputs,
- Stmt body);
-
- static constexpr const char* _type_key = "HybridOp";
- TVM_DECLARE_FINAL_OBJECT_INFO(HybridOpNode, OperationNode);
-};
-
-/*! \brief The compute function to specify the input source of a Tensor */
-using FCompute = std::function<PrimExpr (const Array<Var>& i)>;
-
-/*! \brief The compute function to specify the inputs source of Tensors */
-using FBatchCompute = std::function<Array<PrimExpr> (const Array<Var>& i)>;
-
-/*!
- * \brief create a place holder tensor.
- * \param shape The shape of the tensor.
- * \param dtype the data type of the tensor.
- * \param name The name of the Tensor.
- */
-TVM_DLL Tensor placeholder(Array<PrimExpr> shape,
- DataType dtype = DataType::Float(32),
- std::string name = "placeholder");
-
-/*!
- * \brief Construct a new tensor by computing over shape,
- * using the computation rule: result_tensor[axis] = fcompute(axis)
- * \param shape Shape of the tensor.
- * \param fcompute The compute function to create the tensor.
- * \param name The optional name of the tensor.
- * \param tag The optional tag of the tensor.
- * \param attrs Optional additional attributes of the compute.
- */
-TVM_DLL Tensor compute(Array<PrimExpr> shape,
- FCompute fcompute,
- std::string name = "tensor",
- std::string tag = "",
- Map<std::string, ObjectRef> attrs = {});
-
-/*!
- * \brief Construct a new tensor by computing over shape,
- * using the computation rule: result_tensor[axis] = fcompute(axis)
- * \param shape Shape of the tensor.
- * \param fcompute The compute function to create the tensors.
- * \param name The optional name of the tensor.
- * \param tag The optional tag of the tensor.
- * \param attrs Optional additional attributes of the compute.
- */
-TVM_DLL Array<Tensor> compute(Array<PrimExpr> shape,
- FBatchCompute fcompute,
- std::string name = "tensor",
- std::string tag = "",
- Map<std::string, ObjectRef> attrs = {});
-
-/*!
- * \brief Construct new tensors by scan.
- *
- * \param init The intialize tensor of first K steps.
- * \param update The update tensor indicated the updated result after each timestamp.
- * \param state_placeholder The placeholder for the states.
- * \param inputs The inputs to the scan body, this is optional,
- * but recommended to provide concrete information about scan body.
- * \param name The optional name of the tensor.
- * \param tag The optional tag of the tensor.
- * \param attrs Optional additional attributes of the compute.
- */
-TVM_DLL Array<Tensor> scan(Array<Tensor> init,
- Array<Tensor> update,
- Array<Tensor> state_placeholder,
- Array<Tensor> inputs = Array<Tensor>(),
- std::string name = "scan",
- std::string tag = "",
- Map<std::string, ObjectRef> attrs = {});
-
-// same as compute, specialized for different fcompute function
-inline Tensor compute(Array<PrimExpr> shape,
- std::function<PrimExpr(Var)> f,
- std::string name = "tensor",
- std::string tag = "",
- Map<std::string, ObjectRef> attrs = {}) {
- FCompute fc = [f] (const Array<Var>& i) { return f(i[0]); };
- return compute(shape, fc, name, tag, attrs);
-}
-inline Tensor compute(Array<PrimExpr> shape,
- std::function<PrimExpr(Var, Var)> f,
- std::string name = "tensor",
- std::string tag = "",
- Map<std::string, ObjectRef> attrs = {}) {
- FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1]); };
- return compute(shape, fc, name, tag, attrs);
-}
-inline Tensor compute(Array<PrimExpr> shape,
- std::function<PrimExpr(Var, Var, Var)> f,
- std::string name = "tensor",
- std::string tag = "",
- Map<std::string, ObjectRef> attrs = {}) {
- FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1], i[2]); };
- return compute(shape, fc, name, tag, attrs);
-}
-inline Tensor compute(Array<PrimExpr> shape,
- std::function<PrimExpr(Var, Var, Var, Var)> f,
- std::string name = "tensor",
- std::string tag = "",
- Map<std::string, ObjectRef> attrs = {}) {
- FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1], i[2], i[3]); };
- return compute(shape, fc, name, tag, attrs);
-}
-
-// inline function.
-inline const OperationNode* Operation::operator->() const {
- return static_cast<const OperationNode*>(get());
-}
-} // namespace tvm
-#endif // TVM_OPERATION_H_
#ifndef TVM_PACKED_FUNC_EXT_H_
#define TVM_PACKED_FUNC_EXT_H_
+#include <tvm/top/tensor.h>
+
#include <string>
#include <memory>
#include <limits>
#include <type_traits>
#include "expr.h"
-#include "tensor.h"
#include "runtime/packed_func.h"
namespace tvm {
if (ptr->IsInstance<IterVarNode>()) {
return IterVar(ObjectPtr<Object>(ptr))->var;
}
- if (ptr->IsInstance<TensorNode>()) {
- return Tensor(ObjectPtr<Object>(ptr))();
+ if (ptr->IsInstance<top::TensorNode>()) {
+ return top::Tensor(ObjectPtr<Object>(ptr))();
}
CHECK(ObjectTypeChecker<PrimExpr>::Check(ptr))
<< "Expect type " << ObjectTypeChecker<PrimExpr>::TypeName()
#ifndef TVM_RELAY_OP_ATTR_TYPES_H_
#define TVM_RELAY_OP_ATTR_TYPES_H_
-#include <tvm/tensor.h>
-#include <tvm/schedule.h>
+#include <tvm/top/tensor.h>
+#include <tvm/top/schedule.h>
#include <tvm/build_module.h>
#include <tvm/relay/type.h>
#include <tvm/relay/expr.h>
* \return The output compute description of the operator.
*/
using FTVMCompute = runtime::TypedPackedFunc<
- Array<Tensor>(const Attrs& attrs,
- const Array<Tensor>& inputs,
- const Type& out_type,
- const Target& target)>;
+ Array<top::Tensor>(const Attrs& attrs,
+ const Array<top::Tensor>& inputs,
+ const Type& out_type,
+ const Target& target)>;
/*!
* \brief Build the computation schedule for
* \return schedule The computation schedule.
*/
using FTVMSchedule = runtime::TypedPackedFunc<
- Schedule(const Attrs& attrs,
- const Array<Tensor>& outs,
- const Target& target)>;
+ top::Schedule(const Attrs& attrs,
+ const Array<top::Tensor>& outs,
+ const Target& target)>;
/*!
* \brief Alternate the layout of operators or replace the
using FTVMAlterOpLayout = runtime::TypedPackedFunc<
Expr(const Attrs& attrs,
const Array<Expr>& args,
- const Array<Tensor>& tinfos)>;
+ const Array<top::Tensor>& tinfos)>;
/*!
* \brief Convert the layout of operators or replace the
using FTVMConvertOpLayout = runtime::TypedPackedFunc<
Expr(const Attrs& attrs,
const Array<Expr>& args,
- const Array<Tensor>& tinfos,
+ const Array<top::Tensor>& tinfos,
const std::string& desired_layout)>;
/*!
* \brief Legalizes an expression with another expression. This function will be
using Shape = Array<IndexExpr>;
using FShapeFunc = runtime::TypedPackedFunc<
- Array<Tensor>(const Attrs& attrs,
- const Array<Tensor>& inputs,
- const Array<IndexExpr>& out_ndims)>;
+ Array<top::Tensor>(const Attrs& attrs,
+ const Array<top::Tensor>& inputs,
+ const Array<IndexExpr>& out_ndims)>;
} // namespace relay
} // namespace tvm
+++ /dev/null
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file tvm/schedule.h
- * \brief Define a schedule.
- */
-// Acknowledgement: Many schedule primitives originate from Halide and Loopy.
-#ifndef TVM_SCHEDULE_H_
-#define TVM_SCHEDULE_H_
-
-#include <string>
-#include <unordered_map>
-#include "expr.h"
-#include "tensor.h"
-#include "tensor_intrin.h"
-
-namespace tvm {
-
-// Node container for Stage
-class StageNode;
-// Node container for Schedule
-class ScheduleNode;
-// Node container for IterVarRelation
-class IterVarRelationNode;
-// Attribute of itervar.
-class IterVarAttrNode;
-
-/*! \brief the attachment type */
-enum AttachType : int {
- kGroupRoot = 1,
- kInline = 2,
- kInlinedAlready = 3,
- kScope = 4,
- kScanUpdate = 5
-};
-
-/*! \brief Stage, contains scheduling for a stage of computation. */
-class Stage : public ObjectRef {
- public:
- Stage() {}
- explicit Stage(ObjectPtr<Object> n) : ObjectRef(n) {}
- /*!
- * \brief create a new schedule for op.
- * \param op The operator in the schedule
- */
- explicit Stage(Operation op);
- /*!
- * \brief access the internal node container
- * \return the pointer to the internal node container
- */
- inline const StageNode* operator->() const;
- /*!
- * \brief access the internal node container
- * \return the pointer to the internal node container
- */
- inline StageNode* operator->();
- /*!
- * \brief set the memory scope of the stage
- * \param scope The memory scope.
- */
- TVM_DLL Stage& set_scope(std::string scope); // NOLINT(*)
- /*!
- * \brief specify the schedule to be computed at the parent schedule's scope.
- * \param parent The parent schedule.
- * \param scope The iteration point to carry the schedule.
- * \return reference to self.
- */
- TVM_DLL Stage& compute_at(Stage parent, IterVar scope); // NOLINT(*)
- /*!
- * \brief Compute the function inline.
- * \return reference to self.
- */
- TVM_DLL Stage& compute_inline(); // NOLINT(*)
- /*!
- * \brief Compute the function at group root.
- * \return reference to self.
- */
- TVM_DLL Stage& compute_root(); // NOLINT(*)
- /*!
- * \brief Bind the IterVar to thread index.
- *
- * \param ivar The IterVar to be bound.
- * \param thread_ivar The thread axis to be bound.
- * \return reference to self.
- */
- TVM_DLL Stage& bind(IterVar ivar, IterVar thread_ivar);
- /*!
- * \brief Set the predicate to determine whether a store to the array should be performed.
- * Use this when there are multiple threads performing the same store and we only
- * need one of them to do the store.
- *
- * \note This is a dangerous scheduling primitive that can change behavior of program.
- * Only do when we are certain that thare are duplicated stores.
- * \param predicate The condition to be checked.
- * \return reference to self.
- */
- TVM_DLL Stage& set_store_predicate(PrimExpr predicate);
- /*!
- * \brief Specify environment threads that launched around the group's scope.
- * This can only be used in group stage.
- * \param threads The threads to be launched around the scope.
- * \note Each thread can only appear in one env_threads.
- * This is a beta feature.
- * \return reference to self.
- */
- TVM_DLL Stage& env_threads(Array<IterVar> threads);
- /*!
- * \brief Split the parent by factor, generate
- * \param parent The parent iteration domain.
- * \param factor The split factor of the loop.
- * \param p_outer The result outer domain
- * \param p_inner The result inner domain.
- * \return reference to self.
- */
- TVM_DLL Stage& split(IterVar parent, PrimExpr factor, IterVar* p_outer, IterVar* p_inner); // NOLINT(*)
- /*!
- * \brief Split the iteration with given number of parts.
- *
- * \param parent The parent domain.
- * \param nparts The number of parts in the outer domain.
- * \param p_outer The result outer domain.
- * \param p_inner The result inner domain.
- * \return reference to self.
- */
- TVM_DLL Stage& split_by_nparts(IterVar parent, PrimExpr nparts, IterVar* p_outer, IterVar* p_inner); // NOLINT(*)
- /*!
- * \brief Fuse the inner outer domain to the target
- * \param outer The outer domain to be fused.
- * \param inner The inner domain to be fused
- * \param p_target The result target domain.
- * \return reference to self.
- */
- TVM_DLL Stage& fuse(IterVar outer, IterVar inner, IterVar* p_target); // NOLINT(*)
- /*!
- * \brief Fuse all the axes together into a single axis.
- *
- * \param axes All the axes to be fused.
- * \param p_target The result target domain.
- *
- * \note axes can be an empty array,
- * in that case, a singleton IterVar is created and
- * inserted to the outermost loop.
- * The fuse of empty array is used to support zero-dimension tensors.
- *
- * \return reference to self.
- */
- TVM_DLL Stage& fuse(const Array<IterVar>& axes, IterVar* p_target); // NOLINT(*)
- /*!
- * \brief Reorder the iteration
- * \param order The order of iteration variable.
- * \return reference to self.
- */
- TVM_DLL Stage& reorder(const Array<IterVar>& order); // NOLINT(*)
- /*!
- * \brief Perform tiling on two dimensions
- * The final loop order from outmost to inner most are
- * [x_outer, y_outer, x_inner, y_inner]
- *
- * \param x_parent The original x dimension
- * \param y_parent The original y dimension
- * \param x_factor The stride factor on x axis
- * \param y_factor The stride factor on y axis
- * \param p_x_outer Outer axis of x dimension
- * \param p_y_outer Outer axis of y dimension
- * \param p_x_inner Inner axis of x dimension
- * \param p_y_inner Inner axis of y dimension
- * \return reference to self.
- */
- TVM_DLL Stage& tile(IterVar x_parent, IterVar y_parent, // NOLINT(*)
- PrimExpr x_factor, PrimExpr y_factor,
- IterVar* p_x_outer, IterVar* p_y_outer,
- IterVar* p_x_inner, IterVar* p_y_inner);
- /*!
- * \brief Vectorize iteration.
- * \param var The axis to be vectorized.
- * \return reference to self.
- */
- TVM_DLL Stage& vectorize(IterVar var); // NOLINT(*)
- /*!
- * \brief Replace computation of the current stage by tensor intrinsic f.
- * \param var The axis marks beginning of tensorization.
- * Every operations inside the axis(include axis itself is tensorized).
- * \param f The Tensor compute intrinsics.
- * \return reference to self.
- */
- TVM_DLL Stage& tensorize(IterVar var, TensorIntrin f); // NOLINT(*)
- /*!
- * \brief Unroll iteration.
- * \param var The axis to be unrolled.
- * \return reference to self.
- */
- TVM_DLL Stage& unroll(IterVar var); // NOLINT(*)
- /*!
- * \brief Parallelize iteration.
- * \param var The axis to be parallelized.
- * \return reference to self.
- */
- TVM_DLL Stage& parallel(IterVar var); // NOLINT(*)
- /*!
- * \brief Annotate the iteration with pragma
- *
- * \param var The axis to be parallelized.
- * \param pragma_type The pragma type.
- * \param pragma_value The pragma value
- *
- * \return reference to self.
- */
- TVM_DLL Stage& pragma(IterVar var,
- const std::string& pragma_type,
- const PrimExpr& pragma_value = PrimExpr()); // NOLINT(*)
- /*!
- * \brief Fetch data in advance.
- * \param domain the tensor to be prefetched
- * \param var the iteration point at which to apply prefetching
- * \param offset the number of iterations be to fetched in advance
- * \return reference to self
- */
- TVM_DLL Stage& prefetch(const Tensor &domain, IterVar var, PrimExpr offset); //NOLINT(*)
- /*!
- * \brief Set alignment requirement for specific dimension.
- *
- * Such that stride[axis] == k * factor + offset for some k.
- *
- * \param axis The dimension to be specified for alignment.
- * \param factor The factor multiple of alignment
- * \param offset The required offset factor.
- * \return reference to self
- */
- TVM_DLL Stage& storage_align(IterVar axis, int factor, int offset); //NOLINT(*)
- /*!
- * \brief Compute current stage with double buffering.
- * \return reference to self.
- */
- TVM_DLL Stage& double_buffer(); // NOLINT(*)
- /*!
- * \brief Schedule for OpenGL fragment shader.
- * \return reference to self.
- */
- Stage& opengl(); // NOLINT(*)
- /*!
- * \brief whether the stage has been scheduled.
- * \return whether the stage has been scheduled.
- */
- bool is_scheduled() const;
- /*!
- * \brief Get attachment spec of current stage.
- * If the stage compute at Group root, this function
- * will traverse the group function to get the
- * final spec from the group.
- * \return A stage representing the attach spec of the group.
- */
- Stage GetAttachSpec() const;
- // declare container type
- using ContainerType = StageNode;
-};
-
-/*!
- * \brief Global schedule container
- * For operations and all the operations they depend on.
- * The schedule per Operation is named as stage.
- */
-class Schedule : public ObjectRef {
- public:
- Schedule() {}
- explicit Schedule(ObjectPtr<Object> n) : ObjectRef(n) {}
- /*!
- * \brief Get a copy of current schedule.
- * \return The copied schedule.
- */
- Schedule copy() const;
- /*!
- * \brief Get the stage corresponds to the op
- * \param op The operation.
- */
- TVM_DLL Stage operator[](const Operation& op);
- /*!
- * \brief Short hand for getting the stage of tensor's operation.
- * \param tensor The tensor
- * \return The stage corresponding to the tensor's op
- */
- TVM_DLL Stage operator[](const Tensor& tensor) {
- return this->operator[](tensor->op);
- }
- /*!
- * \brief Create a new stage group for all intermediate
- * operations between inputs and outputs.
- *
- * \param outputs The output boundary of the group.
- * \param inputs The input boundary of the group.
- * \param include_inputs Whether include inputs if they are reachable from outputs.
- * \return The new grouped stage.
- */
- TVM_DLL Stage create_group(const Array<Tensor>& outputs,
- const Array<Tensor>& inputs,
- bool include_inputs = false);
- /*!
- * \brief create a cache read of original tensor for readers.
- * This will mutate the body of the readers.
- * A new stage will be created for the tensor.
- * \param tensor The tensor cached.
- * \param scope The scope of the cache.
- * \param readers The readers to redirect to the tensor.
- * \return The created tensor.
- */
- TVM_DLL Tensor cache_read(const Tensor& tensor,
- const std::string& scope,
- const Array<Operation>& readers);
- /*!
- * \brief Create a cache write tensor for producing tensor.
- * The the tensor will take over body of original tensor op.
- *
- * This function can be used to do data layout transformation.
- * If there is a split/fuse/reorder on the data parallel axis of tensor
- * before cache_write is called. The intermediate cache stores
- * the data in the layout as the iteration order of leave axis.
- * The data will be transformed back to the original layout in the original tensor.
- * User can further call compute_inline to inline the original layout and keep
- * the data stored in the transformed layout.
- *
- * \param tensor The tensors to be produced.
- * \param scope The scope of the storage.
- * \return The created tensor.
- */
- TVM_DLL Array<Tensor> cache_write(const Array<Tensor>& tensor, const std::string& scope);
- /*!
- * \brief Create a cache write tensor for producing tensor.
- * The the tensor will take over body of original tensor op.
- *
- * This function can be used to do data layout transformation.
- * If there is a split/fuse/reorder on the data parallel axis of tensor
- * before cache_write is called. The intermediate cache stores
- * the data in the layout as the iteration order of leave axis.
- * The data will be transformed back to the original layout in the original tensor.
- * User can further call compute_inline to inline the original layout and keep
- * the data stored in the transformed layout.
- *
- * \param tensor The tensor to be produced.
- * \param scope The scope of the storage.
- * \return The created tensor.
- */
- TVM_DLL Tensor cache_write(const Tensor& tensor, const std::string& scope);
- /*!
- * \brief Factor a reduction axis in tensor's schedule to be an explicit axis.
- * This will create a new stage that generated the new tensor with axis
- * as the first dimension. The tensor's body will be rewritten as a reduction
- * over the factored tensor.
- *
- * P. Suriana, A. Adams and S. Kamil. Parallel associative reductions in halide. CGO'17
- *
- * \param tensor The tensor to be factored.
- * \param axis The reduction axis in tensor's schedule to be factored.
- * \param factor_axis The position where the new axis is placed.
- * \return The created factored tensors.
- */
- TVM_DLL Array<Tensor> rfactor(const Tensor& tensor,
- const IterVar& axis,
- int factor_axis = 0);
- /*!
- * \brief Normalize the schedule.
- * This is needed before bound inference.
- * Insert necessary RebaseNode to make sure all leaf_iter_vars
- * are in form [0, extent)
- *
- * \return A normalized schedule, can be same as current one.
- */
- Schedule normalize();
- /*!
- * \brief access the internal node container
- * \return the pointer to the internal node container
- */
- inline const ScheduleNode* operator->() const;
- /*!
- * \brief access the internal node container
- * \return the pointer to the internal node container
- */
- inline ScheduleNode* operator->();
- // declare container type
- using ContainerType = ScheduleNode;
-};
-
-/*!
- * \brief The schedule relation between IterVars
- * can be Split, Fuse.
- */
-class IterVarRelation : public ObjectRef {
- public:
- IterVarRelation() {}
- explicit IterVarRelation(ObjectPtr<Object> n) : ObjectRef(n) {}
- /*!
- * \brief access the internal node container
- * \return the pointer to the internal node container
- */
- inline const IterVarRelationNode* operator->() const;
-};
-
-/*!
- * \brief Additional scheduable attributes about IterVar.
- */
-class IterVarAttr : public ObjectRef {
- public:
- IterVarAttr() {}
- explicit IterVarAttr(ObjectPtr<Object> n) : ObjectRef(n) {}
- /*!
- * \brief access the internal node container
- * \return the pointer to the internal node container
- */
- inline const IterVarAttrNode* operator->() const;
-};
-
-/*!
- * \brief represents a stage.
- *
- * relations form a Directed acylic hypergraph in bipartite manner.
- * With each node is represented by a IterVar,
- * and each hyper-edge is represented by a IterVarRelation.
- * The relations connects the IterVars in the graph.
- *
- * Besides typical stage that corresponds to operations.
- * There is also group stage, which groups stages together.
- * Each stage's group(given by group) represent an constraint,
- * the stage can only be attached to stages within the group.
- *
- * The group stage node can be attached to IterVars as in normal stage.
- */
-class StageNode : public Object {
- public:
- /*!
- * \brief The operation of stage, can be different from original op.
- * If it is null, then this stage is a group stage.
- */
- Operation op;
- /*!
- * \brief The original operator.
- * The op field can change during schedule to alternate the dataflow,
- * while origin_op remains fixed.
- */
- Operation origin_op;
- /*! \brief All the nodes in the iter var */
- Array<IterVar> all_iter_vars;
- /*! \brief The current active leaf iter vars in the stage. */
- Array<IterVar> leaf_iter_vars;
- /*!
- * \brief Specify threads to be launched at the stage.
- * This is only valid for composite ops such as Scan.
- * \note Experimental primitive: used for thread persistence.
- */
- Array<IterVar> env_threads;
- /*!
- * \brief The predicate under which store can happen
- * Use this when there can be duplicated threads doing the same store.
- * \note Experimental primitive: used by cross thread-reduction.
- */
- PrimExpr store_predicate;
- /*! \brief The relation bwteen of IterVars */
- Array<IterVarRelation> relations;
- /*! \brief additional attributes about iter var. */
- Map<IterVar, IterVarAttr> iter_var_attrs;
- /*! \brief The attachment type of the schedule */
- AttachType attach_type{kGroupRoot};
- /*! \brief The attach point of this schedule. */
- IterVar attach_ivar;
- /*! \brief The stage this node attaches to */
- Stage attach_stage;
- /*! \brief The thread storage scope level of the stage */
- std::string scope;
- /*! \brief Whether this is an output stage */
- bool is_output{false};
- /*! \brief Whether this is an OpenGL stage */
- bool is_opengl{false};
- /*! \brief Whether apply double buffer optimization to this stage */
- bool double_buffer{false};
- /*!
- * \brief The parent group of the current stage.
- * The stage cannot be assigned to stages outside the group.
- */
- Stage group;
- /*! \brief Number of direct child stages, only used for group stage.*/
- int num_child_stages{0};
-
- void VisitAttrs(AttrVisitor* v) {
- v->Visit("op", &op);
- v->Visit("origin_op", &origin_op);
- v->Visit("all_iter_vars", &all_iter_vars);
- v->Visit("leaf_iter_vars", &leaf_iter_vars);
- v->Visit("env_threads", &env_threads);
- v->Visit("relations", &relations);
- v->Visit("iter_var_attrs", &iter_var_attrs);
- v->Visit("attach_type", &attach_type);
- v->Visit("attach_ivar", &attach_ivar);
- v->Visit("attach_stage", &attach_stage);
- v->Visit("scope", &scope);
- v->Visit("is_output", &is_output);
- v->Visit("is_opengl", &is_opengl);
- v->Visit("double_buffer", &double_buffer);
- v->Visit("group", &group);
- v->Visit("num_child_stages", &num_child_stages);
- }
-
- static constexpr const char* _type_key = "Stage";
- TVM_DECLARE_FINAL_OBJECT_INFO(StageNode, Object);
-};
-
-/*! \brief node container for schedule */
-class ScheduleNode : public Object {
- public:
- /*! \brief The output operations in original data flow graph */
- Array<Operation> outputs;
- /*!
- * \brief list of all stages for ops.
- * The stages are sorted in dependency order.
- */
- Array<Stage> stages;
- /*!
- * \brief List of all stage groups.
- */
- Array<Stage> groups;
- /*! \brief map of original operation to the stages */
- Map<Operation, Stage> stage_map;
- /*!
- * \brief Internal stage map to map internal ops to stages.
- * This is created on demand and can be invalidated.
- */
- std::unordered_map<const Object*, Stage> op2stage_cache_;
-
- void VisitAttrs(AttrVisitor* v) {
- v->Visit("outputs", &outputs);
- v->Visit("stages", &stages);
- v->Visit("groups", &groups);
- v->Visit("stage_map", &stage_map);
- }
-
- /*! \brief Initialize temp cache. */
- void InitCache();
- /*! \brief Invalidate temp cache. */
- void InvalidateCache();
-
- /*!
- * \brief Check if the schedule contains an Operation.
- * \param op The candidate Operation.
- * \return true if the schedule has the Operation. Otherwise, false.
- */
- TVM_DLL bool Contain(const Operation& op) const;
-
- /*!
- * \brief Check if the schedule contains a Tensor.
- * \param tensor The candidate tensor.
- * \return true if the schedule has the tensor. Otherwise, false.
- */
- TVM_DLL bool Contain(const Tensor& tensor) const {
- return Contain(tensor->op);
- }
-
- /*!
- * \brief Create a schedule for array of ops(and their dependencies).
- * \param ops The ops to be scheduled.
- * \return sch The created Schedule.
- */
- TVM_DLL static Schedule make(Array<Operation> ops);
-
- static constexpr const char* _type_key = "Schedule";
- TVM_DECLARE_FINAL_OBJECT_INFO(ScheduleNode, Object);
-};
-
-/*!
- * \brief Create a schedule for array of ops(and their dependencies).
- * \param ops The ops to be scheduled.
- * \return sch The created Schedule.
- */
-inline Schedule create_schedule(Array<Operation> ops) {
- return ScheduleNode::make(ops);
-}
-
-/*! \brief node container for IterVar attr */
-class IterVarAttrNode : public Object {
- public:
- /*! \brief The iteration type. */
- IterVarType iter_type{kDataPar};
- /*! \brief The thread this iter Var binds, can be null */
- IterVar bind_thread;
- /*! \brief List of tensor to be prefetched in this loop */
- Array<Tensor> prefetch_data;
- /*! \brief The offset used in each prefetch */
- Array<PrimExpr> prefetch_offset;
- /*!
- * \brief Tensor intrinsic used in tensorization,
- * when the axis is marked as Tensorized
- */
- TensorIntrin tensor_intrin;
- /*! \brief Alignment factor of buffer dimension */
- int dim_align_factor{0};
- /*! \brief Alignment offset of buffer dimension */
- int dim_align_offset{0};
- /*!
- * \brief Additional pragma keys, array of StringImm
- */
- Array<PrimExpr> pragma_keys;
- /*!
- * \brief Additional values of pragma, if any
- */
- Array<PrimExpr> pragma_values;
-
- void VisitAttrs(AttrVisitor* v) {
- v->Visit("iter_type", &iter_type);
- v->Visit("bind_thread", &bind_thread);
- v->Visit("prefetch_data", &prefetch_data);
- v->Visit("prefetch_offset", &prefetch_offset);
- v->Visit("tensor_intrin", &tensor_intrin);
- v->Visit("dim_align_factor", &dim_align_factor);
- v->Visit("dim_align_offset", &dim_align_offset);
- v->Visit("pragma_keys", &pragma_keys);
- v->Visit("pragma_values", &pragma_values);
- }
-
- static constexpr const char* _type_key = "IterVarAttr";
- TVM_DECLARE_FINAL_OBJECT_INFO(IterVarAttrNode, Object);
-};
-
-/*! \brief base node of iteration var */
-class IterVarRelationNode : public Object {
- public:
- static constexpr const char* _type_key = "IterVarRelation";
- TVM_DECLARE_BASE_OBJECT_INFO(IterVarRelationNode, Object);
-};
-
-/*!
- * \brief Split the parent domain into product of
- * outer and iter.
- */
-class SplitNode : public IterVarRelationNode {
- public:
- /*! \brief The parent domain */
- IterVar parent;
- /*! \brief The outer domain */
- IterVar outer;
- /*! \brief The inner domain */
- IterVar inner;
- /*! \brief The split factor */
- PrimExpr factor;
- /*! \brief Number of parts, only factor or nparts can be given */
- PrimExpr nparts;
-
- void VisitAttrs(AttrVisitor* v) {
- v->Visit("parent", &parent);
- v->Visit("outer", &outer);
- v->Visit("inner", &inner);
- v->Visit("factor", &factor);
- v->Visit("nparts", &nparts);
- }
-
- static IterVarRelation make(IterVar parent,
- IterVar outer,
- IterVar inner,
- PrimExpr factor,
- PrimExpr nparts);
-
- static constexpr const char* _type_key = "Split";
- TVM_DECLARE_FINAL_OBJECT_INFO(SplitNode, IterVarRelationNode);
-};
-
-/*!
- * \brief Fuse two domains into one domain.
- */
-class FuseNode : public IterVarRelationNode {
- public:
- /*! \brief The outer domain */
- IterVar outer;
- /*! \brief The inner domain */
- IterVar inner;
- /*! \brief The target domain */
- IterVar fused;
-
- void VisitAttrs(AttrVisitor* v) {
- v->Visit("outer", &outer);
- v->Visit("inner", &inner);
- v->Visit("fused", &fused);
- }
-
- static IterVarRelation make(
- IterVar outer, IterVar inner, IterVar fused);
-
- static constexpr const char* _type_key = "Fuse";
- TVM_DECLARE_FINAL_OBJECT_INFO(FuseNode, IterVarRelationNode);
-};
-
-/*!
- * \brief Rebase the iteration to make min to be 0.
- * This is useful to normalize the Schedule
- * to make every leaf variable's min to be 0.
- */
-class RebaseNode : public IterVarRelationNode {
- public:
- /*! \brief The parent domain */
- IterVar parent;
- /*! \brief The inner domain */
- IterVar rebased;
-
- void VisitAttrs(AttrVisitor* v) {
- v->Visit("parent", &parent);
- v->Visit("rebased", &rebased);
- }
-
- static IterVarRelation make(IterVar parent, IterVar rebased);
-
- static constexpr const char* _type_key = "Rebase";
- TVM_DECLARE_FINAL_OBJECT_INFO(RebaseNode, IterVarRelationNode);
-};
-
-
-/*!
- * \brief Singleton iterator [0, 1)
- */
-class SingletonNode : public IterVarRelationNode {
- public:
- /*! \brief The singleton iterator */
- IterVar iter;
-
- void VisitAttrs(AttrVisitor* v) {
- v->Visit("iter", &iter);
- }
-
- static IterVarRelation make(IterVar iter);
-
- static constexpr const char* _type_key = "Singleton";
- TVM_DECLARE_FINAL_OBJECT_INFO(SingletonNode, IterVarRelationNode);
-};
-
-
-// implementations
-inline const StageNode* Stage::operator->() const {
- return static_cast<const StageNode*>(get());
-}
-inline StageNode* Stage::operator->() {
- return static_cast<StageNode*>(get_mutable());
-}
-
-inline const ScheduleNode* Schedule::operator->() const {
- return static_cast<const ScheduleNode*>(get());
-}
-inline ScheduleNode* Schedule::operator->() {
- return static_cast<ScheduleNode*>(get_mutable());
-}
-
-inline const IterVarRelationNode* IterVarRelation::operator->() const {
- return static_cast<const IterVarRelationNode*>(get());
-}
-
-inline const IterVarAttrNode* IterVarAttr::operator->() const {
- return static_cast<const IterVarAttrNode*>(get());
-}
-} // namespace tvm
-#endif // TVM_SCHEDULE_H_
+++ /dev/null
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file tvm/schedule_pass.h
- * \brief Collection of Schedule pass functions.
- *
- * These passes works on the schedule hyper-graph
- * and infers information such as bounds, check conditions
- * read/write dependencies between the IterVar
- */
-#ifndef TVM_SCHEDULE_PASS_H_
-#define TVM_SCHEDULE_PASS_H_
-
-#include "schedule.h"
-
-namespace tvm {
-namespace schedule {
-
-/*!
- * \brief Infer the bound of all iteration variables relates to the schedule.
- *
- * \param sch The root schedule to infer all the bounds.
- * \return the result bound of the iteration Variable
- */
-Map<IterVar, Range> InferBound(const Schedule& sch);
-
-/*!
- * \brief Schedule s' dependent operations.
- *
- * \param s The schedule to be realized
- * \param dom_map The domain of each iter vars.
- * \param debug_keep_trivial_loop Whether keep trivial loops with extent of 1 during lowering.
- * This is a debug feature for dataflow/axis analysis.
- * Note: If this is true, The lowered IR may be incorrect,
- * because we will also delete the init part of reduction
- * \return the result Stmt
- */
-Stmt ScheduleOps(Schedule s, Map<IterVar, Range> dom_map, bool debug_keep_trivial_loop);
-
-/*!
- * \brief To automatically inline the element-wise operations.
- *
- * \param sch The schedule to be inlined.
- */
-void AutoInlineElemWise(Schedule sch);
-
-/*!
- * \brief To automatically inline operations with injective writes
- * (i.e. writes without reduction or sequential loops). Note
- * that in this case, guarantees about contiguity, transpose, stride,
- * alignemnt and memory footprint in general do not hold.
- *
- * \param sch The schedule to be inlined.
- */
-TVM_DLL void AutoInlineInjective(Schedule sch);
-
-} // namespace schedule
-} // namespace tvm
-#endif // TVM_SCHEDULE_PASS_H_
+++ /dev/null
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file tvm/tensor.h
- * \brief Dataflow tensor object
- */
-#ifndef TVM_TENSOR_H_
-#define TVM_TENSOR_H_
-
-#include <tvm/node/container.h>
-#include <tvm/arith/bound.h>
-
-#include <string>
-#include <vector>
-#include <utility>
-#include <type_traits>
-
-#include "expr.h"
-#include "expr_operator.h"
-
-namespace tvm {
-
-// Internal node container of Tensor
-class TensorNode;
-// internal node container for Operation
-class OperationNode;
-
-/*!
- * \brief Tensor structure representing a possible input,
- * or intermediate computation result.
- */
-class Tensor : public ObjectRef {
- public:
- /*! \brief default constructor, used internally */
- Tensor() {}
- explicit Tensor(ObjectPtr<Object> n) : ObjectRef(n) {}
- /*!
- * \brief access the internal node container
- * \return the pointer to the internal node container
- */
- inline const TensorNode* operator->() const;
- /*!
- * \brief check if two tensors equals each other.
- * \param other tensor to be checked.
- * \return whether the two tensors equals each other.
- */
- inline bool operator==(const Tensor& other) const;
- /*!
- * \brief check if two tensors are different.
- * \param other tensor to be checked.
- * \return whether the two tensors are different.
- */
- inline bool operator!=(const Tensor& other) const;
- /*! \return The dimension of the tensor */
- inline size_t ndim() const;
- /*!
- * \brief Take elements from the tensor
- * \param args The indices
- * \return the result expression representing tensor read.
- */
- template<typename... Args>
- inline PrimExpr operator()(Args&& ...args) const {
- Array<PrimExpr> indices{std::forward<Args>(args)...};
- return operator()(indices);
- }
- /*!
- * \brief Take elements from the tensor
- * \param indices the indices.
- * \return the result expression representing tensor read.
- */
- TVM_DLL PrimExpr operator()(Array<PrimExpr> indices) const;
- /*!
- * \brief Take elements from the tensor
- * \param indices the indices.
- * \return the result expression representing tensor read.
- */
- TVM_DLL PrimExpr operator()(Array<Var> indices) const;
- /*!
- * \brief data structure to represent a slice that fixes first k coordinates.
- * This is used to enable syntax sugar of Tensor[x][y][z] to get the element.
- */
- class Slice {
- public:
- // construct via tensor and indices
- Slice(const Tensor& tensor, std::vector<PrimExpr> indices)
- : tensor_(tensor), indices_(indices) {}
- /*!
- * \brief get i-th slice from the current slice.
- * \param i the index of the coordinate
- * \return the subsequent slice.
- */
- inline Slice operator[](PrimExpr i) {
- std::vector<PrimExpr> other = indices_;
- other.emplace_back(i);
- return Slice(tensor_, other);
- }
- /*!
- * \brief Convert slice to expression.
- * This is only valid when all the coordinates are fully specified.
- * \return the corresponding expression of this slice.
- */
- inline operator PrimExpr() const {
- return tensor_(indices_);
- }
-
- private:
- const Tensor& tensor_;
- std::vector<PrimExpr> indices_;
- };
- /*!
- * \brief get i-th slice from the current Tensor.
- * \param i the index of the coordinate
- * \return the subsequent slice.
- */
- inline Slice operator[](PrimExpr i) const {
- return Slice(*this, {i});
- }
- /*! \brief specify container node */
- using ContainerType = TensorNode;
-};
-
-/*! \brief Operation that produces tensors */
-class Operation : public ir::FunctionRef {
- public:
- /*! \brief default constructor */
- Operation() {}
- explicit Operation(ObjectPtr<Object> n) : FunctionRef(n) {}
- /*!
- * \brief access the internal node container
- * \return the pointer to the internal node container
- */
- inline const OperationNode* operator->() const;
- /*!
- * \brief get the i-th output of the operation.
- * \param i the output index.
- * \return The i-th output.
- */
- TVM_DLL Tensor output(size_t i) const;
- /*! \brief specify container node */
- using ContainerType = OperationNode;
-};
-
-/*! \brief Node to represent a tensor */
-class TensorNode : public Object {
- public:
- /*! \brief The shape of the tensor */
- Array<PrimExpr> shape;
- /*! \brief data type in the content of the tensor */
- DataType dtype;
- /*! \brief the source operation, can be None */
- Operation op;
- /*! \brief the output index from source operation */
- int value_index{0};
- /*! \brief constructor */
- TensorNode() {}
-
- void VisitAttrs(AttrVisitor* v) {
- v->Visit("shape", &shape);
- v->Visit("dtype", &dtype);
- v->Visit("op", &op);
- v->Visit("value_index", &value_index);
- }
- TVM_DLL static Tensor make(Array<PrimExpr> shape,
- DataType dtype,
- Operation op,
- int value_index);
-
- static constexpr const char* _type_key = "Tensor";
- TVM_DECLARE_FINAL_OBJECT_INFO(TensorNode, Object);
-};
-
-
-// Implementations of inline functions
-inline const TensorNode* Tensor::operator->() const {
- return static_cast<const TensorNode*>(get());
-}
-
-inline size_t Tensor::ndim() const {
- return (*this)->shape.size();
-}
-
-inline bool Tensor::operator==(const Tensor& other) const {
- if (get() == other.get()) return true;
- if (get() == nullptr || other.get() == nullptr) return false;
- if ((*this)->op.defined() || other->op.defined()) {
- return (*this)->op == other->op &&
- (*this)->value_index == other->value_index;
- } else {
- return false;
- }
-}
-
-inline bool Tensor::operator!=(const Tensor& other) const {
- return !(*this == other);
-}
-
-// macro to turn every operation of slice to expression
-#define DEFINE_OVERLOAD_SLICE_UNARY_OP(Op) \
- inline PrimExpr operator Op (const Tensor::Slice& a) { \
- return Op a.operator PrimExpr() ; \
- } \
-
-#define DEFINE_OVERLOAD_SLICE_BINARY_OP(Op) \
- template<typename T> \
- inline PrimExpr operator Op (const Tensor::Slice& a, const T& b) { \
- return a.operator PrimExpr() Op b; \
- } \
- template<typename T> \
- inline PrimExpr operator Op (const T& a, const Tensor::Slice& b) { \
- return a Op b.operator PrimExpr(); \
- } \
- inline PrimExpr operator Op (const Tensor::Slice& a, const Tensor::Slice& b) { \
- return a.operator PrimExpr() Op b.operator PrimExpr(); \
- }
-
-DEFINE_OVERLOAD_SLICE_UNARY_OP(!);
-DEFINE_OVERLOAD_SLICE_UNARY_OP(-);
-DEFINE_OVERLOAD_SLICE_BINARY_OP(+);
-DEFINE_OVERLOAD_SLICE_BINARY_OP(-);
-DEFINE_OVERLOAD_SLICE_BINARY_OP(*);
-DEFINE_OVERLOAD_SLICE_BINARY_OP(==);
-DEFINE_OVERLOAD_SLICE_BINARY_OP(<=);
-DEFINE_OVERLOAD_SLICE_BINARY_OP(>=);
-DEFINE_OVERLOAD_SLICE_BINARY_OP(!=);
-DEFINE_OVERLOAD_SLICE_BINARY_OP(&&);
-DEFINE_OVERLOAD_SLICE_BINARY_OP(||);
-DEFINE_OVERLOAD_SLICE_BINARY_OP(>>);
-DEFINE_OVERLOAD_SLICE_BINARY_OP(<<);
-DEFINE_OVERLOAD_SLICE_BINARY_OP(>); // NOLINT(*)
-DEFINE_OVERLOAD_SLICE_BINARY_OP(<); // NOLINT(*)
-
-} // namespace tvm
-
-namespace std {
-template <>
-struct hash<::tvm::Operation> : public ::tvm::ObjectHash {
-};
-
-template <>
-struct hash<::tvm::Tensor> {
- std::size_t operator()(const ::tvm::Tensor& k) const {
- ::tvm::ObjectHash hasher;
- if (k.defined() && k->op.defined()) {
- return hasher(k->op);
- } else{
- return hasher(k);
- }
- }
-};
-} // namespace std
-#endif // TVM_TENSOR_H_
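Since Tensor moves verbatim into include/tvm/top, the read syntax documented above is unchanged apart from the namespace. A small sketch of the two equivalent forms (the functional-call form and the Slice sugar), assuming the post-move tvm::top::Tensor; the function name is illustrative.

#include <tvm/top/tensor.h>

// Sketch only: two equivalent ways to read one element of a tensor.
tvm::PrimExpr ReadSketch(const tvm::top::Tensor& A, tvm::Var i, tvm::Var j) {
  tvm::PrimExpr call_form = A(i, j);    // Tensor::operator()
  tvm::PrimExpr slice_form = A[i][j];   // Tensor::Slice sugar, converts to PrimExpr
  return call_form + slice_form;
}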
+++ /dev/null
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file tvm/tensor_intrin.h
- * \brief Tensor intrinsic operations.
- */
-#ifndef TVM_TENSOR_INTRIN_H_
-#define TVM_TENSOR_INTRIN_H_
-
-#include <string>
-#include "tensor.h"
-#include "buffer.h"
-
-namespace tvm {
-
-// Internal node container of tensor intrinsics.
-class TensorIntrinNode;
-
-/*! \brief Tensor intrinsic node. */
-class TensorIntrin : public ObjectRef {
- public:
- TensorIntrin() {}
- explicit TensorIntrin(ObjectPtr<Object> n) : ObjectRef(n) {}
- /*!
- * \brief access the internal node container
- * \return the pointer to the internal node container
- */
- inline const TensorIntrinNode* operator->() const;
-
- /*! \brief specify container node */
- using ContainerType = TensorIntrinNode;
-};
-
-/*! \brief Node to represent a Tensor intrinsic operator */
-class TensorIntrinNode : public Object {
- public:
- /*! \brief The name of the intrinsic */
- std::string name;
- /*! \brief The operation this intrinsics is carrying out */
- Operation op;
- /*! \brief List of inputs of operator, placeholder in postdfs order */
- Array<Tensor> inputs;
- /*!
- * \brief Symbolic buffers of each output/input tensor
- * buffers[0:len(inputs)] are buffers of the inputs.
- * buffers[len(inputs):] are buffers of each output.
- *
- * \note When a field in Buffer is Var, it means we can be flexible
- * wrt that field and Var can occur in body.
- * When it is a constant, it means we can only take data in that shape.
- */
- Array<Buffer> buffers;
- /*! \brief List of scalar variables, used in body. These placeholders
- * will be bound to expressions passed in when the TensorIntrin is called
- * from a TensorComputeOp.
- */
- Array<Var> scalar_params;
- /*! \brief The normal statement to execute the intrinsic */
- Stmt body;
- /*!
- * \brief Special statement for reduction op, can be None
- * reset the value of output buffer to identity value.
- */
- Stmt reduce_init;
- /*!
- * \brief Special statement for reduction op, can be None
- * Reduce: do a reduction of current output buffer with the result.
- */
- Stmt reduce_update;
- /*! \brief constructor */
- TensorIntrinNode() {}
-
- void VisitAttrs(AttrVisitor* v) {
- v->Visit("name", &name);
- v->Visit("op", &op);
- v->Visit("inputs", &inputs);
- v->Visit("buffers", &buffers);
- v->Visit("scalar_params", &scalar_params);
- v->Visit("body", &body);
- v->Visit("reduce_init", &reduce_init);
- v->Visit("reduce_update", &reduce_update);
- }
-
- TVM_DLL static TensorIntrin make(std::string name,
- Operation op,
- Array<Tensor> inputs,
- Array<Buffer> buffers,
- Array<Var> scalar_params,
- Stmt body,
- Stmt reduce_init,
- Stmt reduce_update);
-
- static constexpr const char* _type_key = "TensorIntrin";
- TVM_DECLARE_FINAL_OBJECT_INFO(TensorIntrinNode, Object);
-};
-
-inline const TensorIntrinNode* TensorIntrin::operator->() const {
- return static_cast<const TensorIntrinNode*>(get());
-}
-
-// Internal node container of tensor intrinsic calling.
-class TensorIntrinCallNode;
-
-/*! \brief Tensor intrinsic calling node. */
-class TensorIntrinCall : public ObjectRef {
- public:
- TensorIntrinCall() {}
- explicit TensorIntrinCall(ObjectPtr<Object> n) : ObjectRef(n) {}
- /*!
- * \brief access the internal node container
- * \return the pointer to the internal node container
- */
- inline const TensorIntrinCallNode* operator->() const;
-
- /*! \brief specify container node */
- using ContainerType = TensorIntrinCallNode;
-};
-
-class TensorIntrinCallNode : public Object {
- public:
- /*! \brief the tensor intrinsic */
- TensorIntrin intrin;
- /*! \brief input tensors of the intrinsic */
- Array<Tensor> tensors;
- /*! \brief regions of input tensors */
- Array<Region> regions;
-
-
- /*!
- * \brief IterVar on each reduction axis, if the
- * intrin will use the reduce axis
- */
- Array<IterVar> reduce_axis;
-
- /*! \brief scalar expression inputs */
- Array<PrimExpr> scalar_inputs;
-
- void VisitAttrs(AttrVisitor* v) {
- v->Visit("intrin", &intrin);
- v->Visit("tensors", &tensors);
- v->Visit("regions", ®ions);
- v->Visit("reduce_axis", &reduce_axis);
- v->Visit("scalar_inputs", &scalar_inputs);
- }
- static TensorIntrinCall make(TensorIntrin intrin,
- Array<Tensor> tensors,
- Array<Region> regions,
- Array<IterVar> reduce_axis,
- Array<PrimExpr> scalar_inputs);
-
- static constexpr const char* _type_key = "TensorIntrinCall";
- TVM_DECLARE_FINAL_OBJECT_INFO(TensorIntrinCallNode, Object);
-};
-
-inline const TensorIntrinCallNode* TensorIntrinCall::operator->() const {
- return static_cast<const TensorIntrinCallNode*>(get());
-}
-
-} // namespace tvm
-#endif // TVM_TENSOR_INTRIN_H_
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/top/operation.h
+ * \brief Operation node can generate one or multiple Tensors
+ */
+#ifndef TVM_TOP_OPERATION_H_
+#define TVM_TOP_OPERATION_H_
+
+#include <tvm/arith/analyzer.h>
+#include <tvm/top/tensor.h>
+#include <tvm/top/schedule.h>
+
+#include <tvm/expr.h>
+#include <tvm/expr_operator.h>
+#include <tvm/buffer.h>
+
+#include <string>
+#include <vector>
+#include <unordered_map>
+
+
+
+namespace tvm {
+namespace top {
+
+using arith::IntSet;
+
+/*!
+ * \brief Temporary data structure to store union
+ * of bounds of each axis of Tensor.
+ */
+struct TensorDom {
+ // constructor
+ explicit TensorDom(int ndim)
+ : data(ndim) {}
+ /*! \brief The domain data */
+ std::vector<std::vector<IntSet> > data;
+};
+
+/*!
+ * \brief Base class of all operation nodes
+ */
+class OperationNode : public ir::FunctionBaseNode {
+ public:
+ /*! \brief optional name of the operation */
+ std::string name;
+ /*! \brief optional tag of the operation */
+ std::string tag;
+ /*! \brief additional attributes of the operation*/
+ Map<std::string, ObjectRef> attrs;
+ /*! \return name of the operation */
+ const std::string& func_name() const final {
+ return name;
+ }
+ /*!
+ * \return The list of iteration variables at root
+ * \note root_iter_vars decides the shape of the outputs.
+ */
+ virtual Array<IterVar> root_iter_vars() const = 0;
+ /*!
+ * \brief Get the data type of the i-th output tensor.
+ * \param i The output index.
+ * \return type of i-th output.
+ */
+ virtual DataType output_dtype(size_t i) const = 0;
+ /*!
+ * \brief Get shape of i-th output tensor.
+ * \param i The output index.
+ * \return shape of i-th output.
+ */
+ virtual Array<PrimExpr> output_shape(size_t i) const = 0;
+ /*!
+ * \brief List all the input Tensors.
+ * \return List of input tensors.
+ */
+ virtual Array<Tensor> InputTensors() const = 0;
+ /*!
+ * \brief Replace the input of the operation by pattern specified by rmap.
+ *
+ * \param self The reference to self.
+ * \param rmap The replacement map.
+ * \return self if nothing is replaced, otherwise return replaced op.
+ */
+ virtual Operation ReplaceInputs(
+ const Operation& self,
+ const std::unordered_map<Tensor, Tensor>& rmap) const = 0;
+ /*!
+ * \brief Propagate the bounds to inputs
+ * \param self The reference to self.
+ * \param analyzer The analyzer to be used in the function.
+ * \param dom_map the domain map of Variables (corresponding to root_iter_vars)
+ * \param out_dom_map The output domain.
+ * The function is only asked to fill the bounds for Tensors that
+ * are already in the out_dom_map.
+ */
+ virtual void PropBoundToInputs(
+ const Operation& self,
+ arith::Analyzer* analyzer,
+ const std::unordered_map<const VarNode*, IntSet>& dom_map,
+ std::unordered_map<Tensor, TensorDom>* out_dom_map) const = 0;
+ /*!
+ * \brief Gather the bound from output tensor.
+ * Set the range of each root_iter_vars in the op to out_dom_map
+ *
+ * \param self The reference to self.
+ * \param tensor_dom Domain map of Tensor->access set of each dimension.
+ * \param out_dom_map The output domain map of each IterVar to be set.
+ */
+ virtual void GatherBound(
+ const Operation& self,
+ const std::unordered_map<Tensor, TensorDom>& tensor_dom,
+ std::unordered_map<IterVar, Range>* out_dom_map) const = 0;
+ /*!
+ * \brief Build the Realize statement that realizes
+ * the op's output tensors.
+ * \param stage the op's stage.
+ * \param realize_map The realization domain map of the operators.
+ * \param body The body that is going to get wrapped.
+ * \return A realization statement that wraps body.
+ */
+ virtual Stmt BuildRealize(
+ const Stage& stage,
+ const std::unordered_map<IterVar, Range>& realize_map,
+ const Stmt& body) const = 0;
+ /*!
+ * \brief Build the statement that provide the output tensors.
+ * \param stage The schedule stage of the op.
+ * \param dom_map The domain map of all iteration domains.
+ * \param debug_keep_trivial_loop Whether keep trivial loops with extent of 1
+ * \return A statement that adds production and wraps the consumer.
+ */
+ virtual Stmt BuildProvide(
+ const Stage& stage,
+ const std::unordered_map<IterVar, Range>& dom_map,
+ bool debug_keep_trivial_loop) const = 0;
+
+ static constexpr const char* _type_key = "Operation";
+
+ TVM_DECLARE_BASE_OBJECT_INFO(OperationNode, Object);
+};
+
+/*!
+ * \brief A placeholder op represents an input placeholder.
+ */
+class PlaceholderOpNode : public OperationNode {
+ public:
+ /*! \brief The shape of the input */
+ Array<PrimExpr> shape;
+ /*! \brief The data type of the input. */
+ DataType dtype;
+ // override behavior.
+ int num_outputs() const final;
+ Array<IterVar> root_iter_vars() const final;
+ DataType output_dtype(size_t i) const final;
+ Array<PrimExpr> output_shape(size_t i) const final;
+ Array<Tensor> InputTensors() const final;
+ Operation ReplaceInputs(
+ const Operation& self,
+ const std::unordered_map<Tensor, Tensor>& rmap) const final;
+ void PropBoundToInputs(
+ const Operation& self,
+ arith::Analyzer* analyzer,
+ const std::unordered_map<const VarNode*, IntSet>& dom_map,
+ std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
+ void GatherBound(
+ const Operation& self,
+ const std::unordered_map<Tensor, TensorDom>& tensor_dom,
+ std::unordered_map<IterVar, Range>* out_dom_map) const final;
+ Stmt BuildRealize(
+ const Stage& stage,
+ const std::unordered_map<IterVar, Range>& realize_map,
+ const Stmt& body) const final;
+ Stmt BuildProvide(
+ const Stage& stage,
+ const std::unordered_map<IterVar, Range>& dom_map,
+ bool debug_keep_trivial_loop) const final;
+
+ void VisitAttrs(AttrVisitor* v) {
+ v->Visit("name", &name);
+ v->Visit("tag", &tag);
+ v->Visit("attrs", &attrs);
+ v->Visit("shape", &shape);
+ v->Visit("dtype", &dtype);
+ }
+ static Operation make(std::string name,
+ Array<PrimExpr> shape,
+ DataType dtype);
+
+ static constexpr const char* _type_key = "PlaceholderOp";
+ TVM_DECLARE_FINAL_OBJECT_INFO(PlaceholderOpNode, OperationNode);
+};
+
+/*!
+ * \brief A Compute op that computes a tensor on a certain domain.
+ * This is the base class for ComputeOp (operating on a scalar at a time) and
+ * TensorComputeOp (operating on a TensorSlice at a time)
+ */
+class TVM_DLL BaseComputeOpNode : public OperationNode {
+ public:
+ /*! \brief IterVar on each axis */
+ Array<IterVar> axis;
+ /*! \brief IterVar on each reduction axis, if the body is a Reduce */
+ Array<IterVar> reduce_axis;
+ // override functions
+ Array<IterVar> root_iter_vars() const final;
+ Array<PrimExpr> output_shape(size_t idx) const final;
+ void GatherBound(
+ const Operation& self,
+ const std::unordered_map<Tensor, TensorDom>& tensor_dom,
+ std::unordered_map<IterVar, Range>* out_dom_map) const final;
+ Stmt BuildRealize(
+ const Stage& stage,
+ const std::unordered_map<IterVar, Range>& realize_map,
+ const Stmt& body) const final;
+ virtual size_t num_schedulable_dims() const = 0;
+
+ static constexpr const char* _type_key = "BaseComputeOp";
+ TVM_DECLARE_BASE_OBJECT_INFO(BaseComputeOpNode, OperationNode);
+};
+
+
+/*!
+ * \brief A Compute op that computes a tensor on a certain domain.
+ */
+class TVM_DLL ComputeOpNode : public BaseComputeOpNode {
+ public:
+ /*! \brief the compute expression */
+ Array<PrimExpr> body;
+ /*! \brief constructor */
+ ComputeOpNode() {}
+ // override functions
+ int num_outputs() const final;
+ DataType output_dtype(size_t i) const final;
+ Array<Tensor> InputTensors() const final;
+ Operation ReplaceInputs(
+ const Operation& self,
+ const std::unordered_map<Tensor, Tensor>& rmap) const final;
+ void PropBoundToInputs(
+ const Operation& self,
+ arith::Analyzer* analyzer,
+ const std::unordered_map<const VarNode*, IntSet>& dom_map,
+ std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
+ Stmt BuildProvide(
+ const Stage& stage,
+ const std::unordered_map<IterVar, Range>& dom_map,
+ bool debug_keep_trivial_loop) const final;
+ size_t num_schedulable_dims() const final;
+
+ void VisitAttrs(AttrVisitor* v) {
+ v->Visit("name", &name);
+ v->Visit("tag", &tag);
+ v->Visit("attrs", &attrs);
+ v->Visit("axis", &axis);
+ v->Visit("reduce_axis", &reduce_axis);
+ v->Visit("body", &body);
+ }
+ static Operation make(std::string name,
+ std::string tag,
+ Map<std::string, ObjectRef> attrs,
+ Array<IterVar> axis,
+ Array<PrimExpr> body);
+
+ static constexpr const char* _type_key = "ComputeOp";
+ TVM_DECLARE_FINAL_OBJECT_INFO(ComputeOpNode, BaseComputeOpNode);
+};
+
+/*!
+ * \brief A TensorCompute op that computes a tensor with a tensor intrinsic.
+ */
+class TensorComputeOpNode : public BaseComputeOpNode {
+ public:
+ /*! \brief number of axes that can be scheduled */
+ int schedulable_ndim;
+ /*! \brief TensorIntrin used to compute */
+ TensorIntrin intrin;
+ /*! \brief input tensors of intrin */
+ Array<Tensor> inputs;
+ /*! \brief region of input tensors */
+ Array<Region> input_regions;
+ /*! \brief scalar expression inputs */
+ Array<PrimExpr> scalar_inputs;
+ /*! \brief constructor */
+ TensorComputeOpNode() {}
+ // override functions
+ int num_outputs() const final;
+ DataType output_dtype(size_t i) const final;
+ Array<Tensor> InputTensors() const final;
+ Operation ReplaceInputs(
+ const Operation& self,
+ const std::unordered_map<Tensor, Tensor>& rmap) const final;
+ void PropBoundToInputs(
+ const Operation& self,
+ arith::Analyzer* analyzer,
+ const std::unordered_map<const VarNode*, IntSet>& dom_map,
+ std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
+ Stmt BuildProvide(
+ const Stage& stage,
+ const std::unordered_map<IterVar, Range>& dom_map,
+ bool debug_keep_trivial_loop) const final;
+ size_t num_schedulable_dims() const final;
+
+ void VisitAttrs(AttrVisitor* v) {
+ v->Visit("name", &name);
+ v->Visit("tag", &tag);
+ v->Visit("axis", &axis);
+ v->Visit("reduce_axis", &reduce_axis);
+ v->Visit("schedulable_ndim", &schedulable_ndim);
+ v->Visit("intrin", &intrin);
+ v->Visit("inputs", &inputs);
+ v->Visit("input_regions", &input_regions);
+ v->Visit("scalar_inputs", &scalar_inputs);
+ }
+ static Operation make(std::string name,
+ std::string tag,
+ Array<IterVar> axis,
+ Array<IterVar> reduce_axis,
+ int schedulable_ndim,
+ TensorIntrin intrin,
+ Array<Tensor> tensors,
+ Array<Region> regions,
+ Array<PrimExpr> scalar_inputs);
+
+ static constexpr const char* _type_key = "TensorComputeOp";
+ TVM_DECLARE_FINAL_OBJECT_INFO(TensorComputeOpNode, BaseComputeOpNode);
+};
+
+/*!
+ * \brief Symbolic scan.
+ */
+class ScanOpNode : public OperationNode {
+ public:
+ /*! \brief IterVar to scan over */
+ IterVar scan_axis;
+ /*! \brief the initialization tensors */
+ Array<Tensor> init;
+ /*! \brief the update function represented by tensor */
+ Array<Tensor> update;
+ /*! \brief The placeholders that refer to the states in update. */
+ Array<Tensor> state_placeholder;
+ /*!
+ * \brief the inputs to the scan; these are optional,
+ * but they can provide helpful hints to speed up constructing the scan body.
+ */
+ Array<Tensor> inputs;
+ /*!
+ * \brief Spatial axis to indicate spatial dimension of each output.
+ * They correspond to the flattened spatial axes of the outputs.
+ *
+ * [output[0].axis[1], output[0].axis[2]... output[k].axis[j]...]
+ * These are auxiliary data structures for storing the result of bound inference.
+ * They do not correspond to splittable iterations, hence the trailing
+ * underscore in the name.
+ */
+ Array<IterVar> spatial_axis_;
+ /*! \brief constructor */
+ ScanOpNode() {}
+ // override behavior.
+ int num_outputs() const final;
+ Array<IterVar> root_iter_vars() const final;
+ DataType output_dtype(size_t i) const final;
+ Array<PrimExpr> output_shape(size_t i) const final;
+ Array<Tensor> InputTensors() const final;
+ Operation ReplaceInputs(
+ const Operation& self,
+ const std::unordered_map<Tensor, Tensor>& rmap) const final;
+ void PropBoundToInputs(
+ const Operation& self,
+ arith::Analyzer* analyzer,
+ const std::unordered_map<const VarNode*, IntSet>& dom_map,
+ std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
+ void GatherBound(
+ const Operation& self,
+ const std::unordered_map<Tensor, TensorDom>& tensor_dom,
+ std::unordered_map<IterVar, Range>* out_dom_map) const final;
+ Stmt BuildRealize(
+ const Stage& stage,
+ const std::unordered_map<IterVar, Range>& realize_map,
+ const Stmt& body) const final;
+ Stmt BuildProvide(
+ const Stage& stage,
+ const std::unordered_map<IterVar, Range>& dom_map,
+ bool debug_keep_trivial_loop) const final;
+
+ void VisitAttrs(AttrVisitor* v) {
+ v->Visit("name", &name);
+ v->Visit("tag", &tag);
+ v->Visit("attrs", &attrs);
+ v->Visit("scan_axis", &scan_axis);
+ v->Visit("init", &init);
+ v->Visit("update", &update);
+ v->Visit("state_placeholder", &state_placeholder);
+ v->Visit("inputs", &inputs);
+ v->Visit("spatial_axis_", &spatial_axis_);
+ }
+ static Operation make(std::string name,
+ std::string tag,
+ Map<std::string, ObjectRef> attrs,
+ IterVar axis,
+ Array<Tensor> init,
+ Array<Tensor> update,
+ Array<Tensor> state_placeholder,
+ Array<Tensor> input);
+
+ static constexpr const char* _type_key = "ScanOp";
+ TVM_DECLARE_FINAL_OBJECT_INFO(ScanOpNode, OperationNode);
+};
+
+/*!
+ * \brief External computation that cannot be split.
+ */
+class ExternOpNode : public OperationNode {
+ public:
+ /*! \brief The input tensors */
+ Array<Tensor> inputs;
+ /*! \brief Symbolic placeholder representation of inputs */
+ Array<Buffer> input_placeholders;
+ /*! \brief Symbolic placeholder representation of outputs */
+ Array<Buffer> output_placeholders;
+ /*! \brief the statement that generates the computation. */
+ Stmt body;
+
+ /*! \brief constructor */
+ ExternOpNode() {}
+ // override functions
+ int num_outputs() const final;
+ Array<IterVar> root_iter_vars() const final;
+ DataType output_dtype(size_t i) const final;
+ Array<PrimExpr> output_shape(size_t i) const final;
+ Array<Tensor> InputTensors() const final;
+ Operation ReplaceInputs(
+ const Operation& self,
+ const std::unordered_map<Tensor, Tensor>& rmap) const final;
+ void PropBoundToInputs(
+ const Operation& self,
+ arith::Analyzer* analyzer,
+ const std::unordered_map<const VarNode*, IntSet>& dom_map,
+ std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
+ void GatherBound(
+ const Operation& self,
+ const std::unordered_map<Tensor, TensorDom>& tensor_dom,
+ std::unordered_map<IterVar, Range>* out_dom_map) const final;
+ Stmt BuildRealize(
+ const Stage& stage,
+ const std::unordered_map<IterVar, Range>& realize_map,
+ const Stmt& body) const final;
+ Stmt BuildProvide(
+ const Stage& stage,
+ const std::unordered_map<IterVar, Range>& dom_map,
+ bool debug_keep_trivial_loop) const final;
+
+ void VisitAttrs(AttrVisitor* v) {
+ v->Visit("name", &name);
+ v->Visit("tag", &tag);
+ v->Visit("attrs", &attrs);
+ v->Visit("inputs", &inputs);
+ v->Visit("input_placeholders", &input_placeholders);
+ v->Visit("output_placeholders", &output_placeholders);
+ v->Visit("body", &body);
+ }
+ TVM_DLL static Operation make(std::string name,
+ std::string tag,
+ Map<std::string, ObjectRef> attrs,
+ Array<Tensor> inputs,
+ Array<Buffer> input_placeholders,
+ Array<Buffer> output_placeholders,
+ Stmt body);
+
+ static constexpr const char* _type_key = "ExternOp";
+ TVM_DECLARE_FINAL_OBJECT_INFO(ExternOpNode, OperationNode);
+};
+
+/*!
+ * \brief A computation operator generated by a hybrid script.
+ */
+class HybridOpNode : public OperationNode {
+ public:
+ /*! \brief The input tensors */
+ Array<Tensor> inputs;
+ /*! \brief Symbolic placeholder representation of outputs */
+ Array<Tensor> outputs;
+ /*! \brief The axis of iterations */
+ Array<IterVar> axis;
+ /*! \brief the statement that generates the computation. This is
+ * slightly different from the body in ExternOpNode. All the output
+ * tensors keep their own names as specified by users in the script.
+ * However, during compilation these tensors will be replaced by the
+ * actual output tensors. */
+ Stmt body;
+
+ /*! \brief constructor */
+ HybridOpNode() {}
+ // override functions
+ int num_outputs() const final;
+ Array<IterVar> root_iter_vars() const final;
+ DataType output_dtype(size_t i) const final;
+ Array<PrimExpr> output_shape(size_t i) const final;
+ Array<Tensor> InputTensors() const final;
+ Operation ReplaceInputs(
+ const Operation& self,
+ const std::unordered_map<Tensor, Tensor>& rmap) const final;
+ void PropBoundToInputs(
+ const Operation& self,
+ arith::Analyzer* analyzer,
+ const std::unordered_map<const VarNode*, IntSet>& dom_map,
+ std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
+ void GatherBound(
+ const Operation& self,
+ const std::unordered_map<Tensor, TensorDom>& tensor_dom,
+ std::unordered_map<IterVar, Range>* out_dom_map) const final;
+ Stmt BuildRealize(
+ const Stage& stage,
+ const std::unordered_map<IterVar, Range>& realize_map,
+ const Stmt& body) const final;
+ Stmt BuildProvide(
+ const Stage& stage,
+ const std::unordered_map<IterVar, Range>& dom_map,
+ bool debug_keep_trivial_loop) const final;
+
+ void VisitAttrs(AttrVisitor* v) {
+ v->Visit("name", &name);
+ v->Visit("tag", &tag);
+ v->Visit("attrs", &attrs);
+ v->Visit("inputs", &inputs);
+ v->Visit("outputs", &outputs);
+ v->Visit("axis", &axis);
+ v->Visit("body", &body);
+ }
+ TVM_DLL static Operation make(std::string name,
+ std::string tag,
+ Map<std::string, ObjectRef> attrs,
+ Array<Tensor> inputs,
+ Array<Tensor> outputs,
+ Stmt body);
+
+ static constexpr const char* _type_key = "HybridOp";
+ TVM_DECLARE_FINAL_OBJECT_INFO(HybridOpNode, OperationNode);
+};
+
+/*! \brief The compute function to specify the input source of a Tensor */
+using FCompute = std::function<PrimExpr (const Array<Var>& i)>;
+
+/*! \brief The compute function to specify the input sources of multiple Tensors */
+using FBatchCompute = std::function<Array<PrimExpr> (const Array<Var>& i)>;
+
+/*!
+ * \brief Create a placeholder tensor.
+ * \param shape The shape of the tensor.
+ * \param dtype the data type of the tensor.
+ * \param name The name of the Tensor.
+ */
+TVM_DLL Tensor placeholder(Array<PrimExpr> shape,
+ DataType dtype = DataType::Float(32),
+ std::string name = "placeholder");
+
+/*!
+ * \brief Construct a new tensor by computing over shape,
+ * using the computation rule: result_tensor[axis] = fcompute(axis)
+ * \param shape Shape of the tensor.
+ * \param fcompute The compute function to create the tensor.
+ * \param name The optional name of the tensor.
+ * \param tag The optional tag of the tensor.
+ * \param attrs Optional additional attributes of the compute.
+ */
+TVM_DLL Tensor compute(Array<PrimExpr> shape,
+ FCompute fcompute,
+ std::string name = "tensor",
+ std::string tag = "",
+ Map<std::string, ObjectRef> attrs = {});
+
+/*!
+ * \brief Construct a new tensor by computing over shape,
+ * using the computation rule: result_tensor[axis] = fcompute(axis)
+ * \param shape Shape of the tensor.
+ * \param fcompute The compute function to create the tensors.
+ * \param name The optional name of the tensor.
+ * \param tag The optional tag of the tensor.
+ * \param attrs Optional additional attributes of the compute.
+ */
+TVM_DLL Array<Tensor> compute(Array<PrimExpr> shape,
+ FBatchCompute fcompute,
+ std::string name = "tensor",
+ std::string tag = "",
+ Map<std::string, ObjectRef> attrs = {});
+
+/*!
+ * \brief Construct new tensors by scan.
+ *
+ * \param init The initial tensors for the first K steps.
+ * \param update The update tensors indicating the updated result after each time step.
+ * \param state_placeholder The placeholder for the states.
+ * \param inputs The inputs to the scan body; this is optional,
+ * but recommended to provide concrete information about the scan body.
+ * \param name The optional name of the tensor.
+ * \param tag The optional tag of the tensor.
+ * \param attrs Optional additional attributes of the compute.
+ */
+TVM_DLL Array<Tensor> scan(Array<Tensor> init,
+ Array<Tensor> update,
+ Array<Tensor> state_placeholder,
+ Array<Tensor> inputs = Array<Tensor>(),
+ std::string name = "scan",
+ std::string tag = "",
+ Map<std::string, ObjectRef> attrs = {});
+
+// same as compute, specialized for different fcompute function
+inline Tensor compute(Array<PrimExpr> shape,
+ std::function<PrimExpr(Var)> f,
+ std::string name = "tensor",
+ std::string tag = "",
+ Map<std::string, ObjectRef> attrs = {}) {
+ FCompute fc = [f] (const Array<Var>& i) { return f(i[0]); };
+ return compute(shape, fc, name, tag, attrs);
+}
+inline Tensor compute(Array<PrimExpr> shape,
+ std::function<PrimExpr(Var, Var)> f,
+ std::string name = "tensor",
+ std::string tag = "",
+ Map<std::string, ObjectRef> attrs = {}) {
+ FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1]); };
+ return compute(shape, fc, name, tag, attrs);
+}
+inline Tensor compute(Array<PrimExpr> shape,
+ std::function<PrimExpr(Var, Var, Var)> f,
+ std::string name = "tensor",
+ std::string tag = "",
+ Map<std::string, ObjectRef> attrs = {}) {
+ FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1], i[2]); };
+ return compute(shape, fc, name, tag, attrs);
+}
+inline Tensor compute(Array<PrimExpr> shape,
+ std::function<PrimExpr(Var, Var, Var, Var)> f,
+ std::string name = "tensor",
+ std::string tag = "",
+ Map<std::string, ObjectRef> attrs = {}) {
+ FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1], i[2], i[3]); };
+ return compute(shape, fc, name, tag, attrs);
+}
+
+// inline function.
+inline const OperationNode* Operation::operator->() const {
+ return static_cast<const OperationNode*>(get());
+}
+} // namespace top
+} // namespace tvm
+#endif // TVM_TOP_OPERATION_H_
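For readers skimming the diff, a minimal sketch of the compute DSL entry points declared in this header, written against the new tvm::top namespace; the shapes and names are illustrative, not part of the change.

#include <tvm/top/operation.h>

// Sketch only: declare placeholders and an element-wise compute op.
void ComputeSketch() {
  using namespace tvm;
  using namespace tvm::top;
  Tensor A = placeholder({128, 128}, DataType::Float(32), "A");
  Tensor B = placeholder({128, 128}, DataType::Float(32), "B");
  // result_tensor[i, j] = fcompute(i, j), per the documentation above.
  Tensor C = compute({128, 128},
                     [&](Var i, Var j) { return A(i, j) + B(i, j); }, "C");
  (void)C;
}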
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/top/schedule.h
+ * \brief Define a schedule.
+ */
+// Acknowledgement: Many schedule primitives originate from Halide and Loopy.
+#ifndef TVM_TOP_SCHEDULE_H_
+#define TVM_TOP_SCHEDULE_H_
+
+#include <tvm/expr.h>
+#include <tvm/top/tensor.h>
+#include <tvm/top/tensor_intrin.h>
+
+
+#include <string>
+#include <unordered_map>
+
+
+namespace tvm {
+namespace top {
+// Node container for Stage
+class StageNode;
+// Node container for Schedule
+class ScheduleNode;
+// Node container for IterVarRelation
+class IterVarRelationNode;
+// Attribute of itervar.
+class IterVarAttrNode;
+
+/*! \brief the attachment type */
+enum AttachType : int {
+ kGroupRoot = 1,
+ kInline = 2,
+ kInlinedAlready = 3,
+ kScope = 4,
+ kScanUpdate = 5
+};
+
+/*! \brief Stage, contains scheduling for a stage of computation. */
+class Stage : public ObjectRef {
+ public:
+ Stage() {}
+ explicit Stage(ObjectPtr<Object> n) : ObjectRef(n) {}
+ /*!
+ * \brief create a new schedule for op.
+ * \param op The operator in the schedule
+ */
+ explicit Stage(Operation op);
+ /*!
+ * \brief access the internal node container
+ * \return the pointer to the internal node container
+ */
+ inline const StageNode* operator->() const;
+ /*!
+ * \brief access the internal node container
+ * \return the pointer to the internal node container
+ */
+ inline StageNode* operator->();
+ /*!
+ * \brief set the memory scope of the stage
+ * \param scope The memory scope.
+ */
+ TVM_DLL Stage& set_scope(std::string scope); // NOLINT(*)
+ /*!
+ * \brief specify the schedule to be computed at the parent schedule's scope.
+ * \param parent The parent schedule.
+ * \param scope The iteration point to carry the schedule.
+ * \return reference to self.
+ */
+ TVM_DLL Stage& compute_at(Stage parent, IterVar scope); // NOLINT(*)
+ /*!
+ * \brief Compute the function inline.
+ * \return reference to self.
+ */
+ TVM_DLL Stage& compute_inline(); // NOLINT(*)
+ /*!
+ * \brief Compute the function at group root.
+ * \return reference to self.
+ */
+ TVM_DLL Stage& compute_root(); // NOLINT(*)
+ /*!
+ * \brief Bind the IterVar to thread index.
+ *
+ * \param ivar The IterVar to be bound.
+ * \param thread_ivar The thread axis to be bound.
+ * \return reference to self.
+ */
+ TVM_DLL Stage& bind(IterVar ivar, IterVar thread_ivar);
+ /*!
+ * \brief Set the predicate to determine whether a store to the array should be performed.
+ * Use this when there are multiple threads performing the same store and we only
+ * need one of them to do the store.
+ *
+ * \note This is a dangerous scheduling primitive that can change the behavior of the program.
+ * Only do this when we are certain that there are duplicated stores.
+ * \param predicate The condition to be checked.
+ * \return reference to self.
+ */
+ TVM_DLL Stage& set_store_predicate(PrimExpr predicate);
+ /*!
+ * \brief Specify environment threads that are launched around the group's scope.
+ * This can only be used in a group stage.
+ * \param threads The threads to be launched around the scope.
+ * \note Each thread can only appear in one env_threads.
+ * This is a beta feature.
+ * \return reference to self.
+ */
+ TVM_DLL Stage& env_threads(Array<IterVar> threads);
+ /*!
+ * \brief Split the parent by factor, generating an outer and an inner domain.
+ * \param parent The parent iteration domain.
+ * \param factor The split factor of the loop.
+ * \param p_outer The result outer domain
+ * \param p_inner The result inner domain.
+ * \return reference to self.
+ */
+ TVM_DLL Stage& split(IterVar parent, PrimExpr factor, IterVar* p_outer, IterVar* p_inner); // NOLINT(*)
+ /*!
+ * \brief Split the iteration with given number of parts.
+ *
+ * \param parent The parent domain.
+ * \param nparts The number of parts in the outer domain.
+ * \param p_outer The result outer domain.
+ * \param p_inner The result inner domain.
+ * \return reference to self.
+ */
+ TVM_DLL Stage& split_by_nparts(IterVar parent, PrimExpr nparts, IterVar* p_outer, IterVar* p_inner); // NOLINT(*)
+ /*!
+ * \brief Fuse the inner outer domain to the target
+ * \param outer The outer domain to be fused.
+ * \param inner The inner domain to be fused
+ * \param p_target The result target domain.
+ * \return reference to self.
+ */
+ TVM_DLL Stage& fuse(IterVar outer, IterVar inner, IterVar* p_target); // NOLINT(*)
+ /*!
+ * \brief Fuse all the axes together into a single axis.
+ *
+ * \param axes All the axes to be fused.
+ * \param p_target The result target domain.
+ *
+ * \note axes can be an empty array,
+ * in that case, a singleton IterVar is created and
+ * inserted to the outermost loop.
+ * The fuse of empty array is used to support zero-dimension tensors.
+ *
+ * \return reference to self.
+ */
+ TVM_DLL Stage& fuse(const Array<IterVar>& axes, IterVar* p_target); // NOLINT(*)
+ /*!
+ * \brief Reorder the iteration
+ * \param order The order of iteration variable.
+ * \return reference to self.
+ */
+ TVM_DLL Stage& reorder(const Array<IterVar>& order); // NOLINT(*)
+ /*!
+ * \brief Perform tiling on two dimensions
+ * The final loop order from outermost to innermost is
+ * [x_outer, y_outer, x_inner, y_inner]
+ *
+ * \param x_parent The original x dimension
+ * \param y_parent The original y dimension
+ * \param x_factor The stride factor on x axis
+ * \param y_factor The stride factor on y axis
+ * \param p_x_outer Outer axis of x dimension
+ * \param p_y_outer Outer axis of y dimension
+ * \param p_x_inner Inner axis of x dimension
+ * \param p_y_inner Inner axis of y dimension
+ * \return reference to self.
+ */
+ TVM_DLL Stage& tile(IterVar x_parent, IterVar y_parent, // NOLINT(*)
+ PrimExpr x_factor, PrimExpr y_factor,
+ IterVar* p_x_outer, IterVar* p_y_outer,
+ IterVar* p_x_inner, IterVar* p_y_inner);
+ /*!
+ * \brief Vectorize iteration.
+ * \param var The axis to be vectorized.
+ * \return reference to self.
+ */
+ TVM_DLL Stage& vectorize(IterVar var); // NOLINT(*)
+ /*!
+ * \brief Replace computation of the current stage by tensor intrinsic f.
+ * \param var The axis that marks the beginning of tensorization.
+ * Every operation inside the axis (including the axis itself) is tensorized.
+ * \param f The Tensor compute intrinsics.
+ * \return reference to self.
+ */
+ TVM_DLL Stage& tensorize(IterVar var, TensorIntrin f); // NOLINT(*)
+ /*!
+ * \brief Unroll iteration.
+ * \param var The axis to be unrolled.
+ * \return reference to self.
+ */
+ TVM_DLL Stage& unroll(IterVar var); // NOLINT(*)
+ /*!
+ * \brief Parallelize iteration.
+ * \param var The axis to be parallelized.
+ * \return reference to self.
+ */
+ TVM_DLL Stage& parallel(IterVar var); // NOLINT(*)
+ /*!
+ * \brief Annotate the iteration with pragma
+ *
+ * \param var The axis to be annotated with the pragma.
+ * \param pragma_type The pragma type.
+ * \param pragma_value The pragma value
+ *
+ * \return reference to self.
+ */
+ TVM_DLL Stage& pragma(IterVar var,
+ const std::string& pragma_type,
+ const PrimExpr& pragma_value = PrimExpr()); // NOLINT(*)
+ /*!
+ * \brief Fetch data in advance.
+ * \param domain the tensor to be prefetched
+ * \param var the iteration point at which to apply prefetching
+ * \param offset the number of iterations to be fetched in advance
+ * \return reference to self
+ */
+ TVM_DLL Stage& prefetch(const Tensor &domain, IterVar var, PrimExpr offset); //NOLINT(*)
+ /*!
+ * \brief Set alignment requirement for specific dimension.
+ *
+ * Such that stride[axis] == k * factor + offset for some k.
+ *
+ * \param axis The dimension to be specified for alignment.
+ * \param factor The factor multiple of alignment
+ * \param offset The required offset factor.
+ * \return reference to self
+ */
+ TVM_DLL Stage& storage_align(IterVar axis, int factor, int offset); //NOLINT(*)
+ /*!
+ * \brief Compute current stage with double buffering.
+ * \return reference to self.
+ */
+ TVM_DLL Stage& double_buffer(); // NOLINT(*)
+ /*!
+ * \brief Schedule for OpenGL fragment shader.
+ * \return reference to self.
+ */
+ Stage& opengl(); // NOLINT(*)
+ /*!
+ * \brief whether the stage has been scheduled.
+ * \return whether the stage has been scheduled.
+ */
+ bool is_scheduled() const;
+ /*!
+ * \brief Get attachment spec of current stage.
+ * If the stage is computed at the group root, this function
+ * will traverse the group to get the
+ * final spec from the group.
+ * \return A stage representing the attach spec of the group.
+ */
+ Stage GetAttachSpec() const;
+ // declare container type
+ using ContainerType = StageNode;
+};
+
+/*!
+ * \brief Global schedule container
+ * For operations and all the operations they depend on.
+ * The schedule per Operation is named as stage.
+ */
+class Schedule : public ObjectRef {
+ public:
+ Schedule() {}
+ explicit Schedule(ObjectPtr<Object> n) : ObjectRef(n) {}
+ /*!
+ * \brief Get a copy of current schedule.
+ * \return The copied schedule.
+ */
+ Schedule copy() const;
+ /*!
+ * \brief Get the stage corresponding to the op
+ * \param op The operation.
+ */
+ TVM_DLL Stage operator[](const Operation& op);
+ /*!
+ * \brief Shorthand for getting the stage of the tensor's operation.
+ * \param tensor The tensor
+ * \return The stage corresponding to the tensor's op
+ */
+ TVM_DLL Stage operator[](const Tensor& tensor) {
+ return this->operator[](tensor->op);
+ }
+ /*!
+ * \brief Create a new stage group for all intermediate
+ * operations between inputs and outputs.
+ *
+ * \param outputs The output boundary of the group.
+ * \param inputs The input boundary of the group.
+ * \param include_inputs Whether to include inputs if they are reachable from outputs.
+ * \return The new grouped stage.
+ */
+ TVM_DLL Stage create_group(const Array<Tensor>& outputs,
+ const Array<Tensor>& inputs,
+ bool include_inputs = false);
+ /*!
+ * \brief create a cache read of original tensor for readers.
+ * This will mutate the body of the readers.
+ * A new stage will be created for the tensor.
+ * \param tensor The tensor cached.
+ * \param scope The scope of the cache.
+ * \param readers The readers to redirect to the tensor.
+ * \return The created tensor.
+ */
+ TVM_DLL Tensor cache_read(const Tensor& tensor,
+ const std::string& scope,
+ const Array<Operation>& readers);
+ /*!
+ * \brief Create a cache write tensor for producing tensor.
+ * The tensor will take over the body of the original tensor's op.
+ *
+ * This function can be used to do data layout transformation.
+ * If there is a split/fuse/reorder on the data parallel axis of the tensor
+ * before cache_write is called, the intermediate cache stores
+ * the data in the layout given by the iteration order of the leaf axes.
+ * The data will be transformed back to the original layout in the original tensor.
+ * The user can further call compute_inline to inline the original layout and keep
+ * the data stored in the transformed layout.
+ *
+ * \param tensor The tensors to be produced.
+ * \param scope The scope of the storage.
+ * \return The created tensor.
+ */
+ TVM_DLL Array<Tensor> cache_write(const Array<Tensor>& tensor, const std::string& scope);
+ /*!
+ * \brief Create a cache write tensor for producing tensor.
+ * The tensor will take over the body of the original tensor's op.
+ *
+ * This function can be used to do data layout transformation.
+ * If there is a split/fuse/reorder on the data parallel axis of the tensor
+ * before cache_write is called, the intermediate cache stores
+ * the data in the layout given by the iteration order of the leaf axes.
+ * The data will be transformed back to the original layout in the original tensor.
+ * The user can further call compute_inline to inline the original layout and keep
+ * the data stored in the transformed layout.
+ *
+ * \param tensor The tensor to be produced.
+ * \param scope The scope of the storage.
+ * \return The created tensor.
+ */
+ TVM_DLL Tensor cache_write(const Tensor& tensor, const std::string& scope);
+ /*!
+ * \brief Factor a reduction axis in tensor's schedule to be an explicit axis.
+ * This will create a new stage that generates the new tensor with axis
+ * as the first dimension. The tensor's body will be rewritten as a reduction
+ * over the factored tensor.
+ *
+ * P. Suriana, A. Adams and S. Kamil. Parallel associative reductions in halide. CGO'17
+ *
+ * \param tensor The tensor to be factored.
+ * \param axis The reduction axis in tensor's schedule to be factored.
+ * \param factor_axis The position where the new axis is placed.
+ * \return The created factored tensors.
+ */
+ TVM_DLL Array<Tensor> rfactor(const Tensor& tensor,
+ const IterVar& axis,
+ int factor_axis = 0);
+ /*!
+ * \brief Normalize the schedule.
+ * This is needed before bound inference.
+ * Insert the necessary RebaseNodes to make sure all leaf_iter_vars
+ * are in the form [0, extent).
+ *
+ * \return A normalized schedule, can be same as current one.
+ */
+ Schedule normalize();
+ /*!
+ * \brief access the internal node container
+ * \return the pointer to the internal node container
+ */
+ inline const ScheduleNode* operator->() const;
+ /*!
+ * \brief access the internal node container
+ * \return the pointer to the internal node container
+ */
+ inline ScheduleNode* operator->();
+ // declare container type
+ using ContainerType = ScheduleNode;
+};
+
+/*!
+ * \brief The schedule relation between IterVars;
+ * it can be Split, Fuse, etc.
+ */
+class IterVarRelation : public ObjectRef {
+ public:
+ IterVarRelation() {}
+ explicit IterVarRelation(ObjectPtr<Object> n) : ObjectRef(n) {}
+ /*!
+ * \brief access the internal node container
+ * \return the pointer to the internal node container
+ */
+ inline const IterVarRelationNode* operator->() const;
+};
+
+/*!
+ * \brief Additional schedulable attributes of an IterVar.
+ */
+class IterVarAttr : public ObjectRef {
+ public:
+ IterVarAttr() {}
+ explicit IterVarAttr(ObjectPtr<Object> n) : ObjectRef(n) {}
+ /*!
+ * \brief access the internal node container
+ * \return the pointer to the internal node container
+ */
+ inline const IterVarAttrNode* operator->() const;
+};
+
+/*!
+ * \brief Represents a stage.
+ *
+ * The relations form a directed acyclic hypergraph in a bipartite manner,
+ * where each node is represented by an IterVar
+ * and each hyper-edge is represented by an IterVarRelation.
+ * The relations connect the IterVars in the graph.
+ *
+ * Besides the typical stages that correspond to operations,
+ * there are also group stages, which group stages together.
+ * Each stage's group (given by the group field) represents a constraint:
+ * the stage can only be attached to stages within the group.
+ *
+ * The group stage node can be attached to IterVars as in a normal stage.
+ */
+class StageNode : public Object {
+ public:
+ /*!
+ * \brief The operation of stage, can be different from original op.
+ * If it is null, then this stage is a group stage.
+ */
+ Operation op;
+ /*!
+ * \brief The original operator.
+ * The op field can change during scheduling to alter the dataflow,
+ * while origin_op remains fixed.
+ */
+ Operation origin_op;
+ /*! \brief All the nodes in the iter var */
+ Array<IterVar> all_iter_vars;
+ /*! \brief The current active leaf iter vars in the stage. */
+ Array<IterVar> leaf_iter_vars;
+ /*!
+ * \brief Specify threads to be launched at the stage.
+ * This is only valid for composite ops such as Scan.
+ * \note Experimental primitive: used for thread persistence.
+ */
+ Array<IterVar> env_threads;
+ /*!
+ * \brief The predicate under which store can happen
+ * Use this when there can be duplicated threads doing the same store.
+ * \note Experimental primitive: used by cross thread-reduction.
+ */
+ PrimExpr store_predicate;
+ /*! \brief The relations between IterVars */
+ Array<IterVarRelation> relations;
+ /*! \brief additional attributes about iter var. */
+ Map<IterVar, IterVarAttr> iter_var_attrs;
+ /*! \brief The attachment type of the schedule */
+ AttachType attach_type{kGroupRoot};
+ /*! \brief The attach point of this schedule. */
+ IterVar attach_ivar;
+ /*! \brief The stage this node attaches to */
+ Stage attach_stage;
+ /*! \brief The thread storage scope level of the stage */
+ std::string scope;
+ /*! \brief Whether this is an output stage */
+ bool is_output{false};
+ /*! \brief Whether this is an OpenGL stage */
+ bool is_opengl{false};
+ /*! \brief Whether apply double buffer optimization to this stage */
+ bool double_buffer{false};
+ /*!
+ * \brief The parent group of the current stage.
+ * The stage cannot be assigned to stages outside the group.
+ */
+ Stage group;
+ /*! \brief Number of direct child stages, only used for group stage.*/
+ int num_child_stages{0};
+
+ void VisitAttrs(AttrVisitor* v) {
+ v->Visit("op", &op);
+ v->Visit("origin_op", &origin_op);
+ v->Visit("all_iter_vars", &all_iter_vars);
+ v->Visit("leaf_iter_vars", &leaf_iter_vars);
+ v->Visit("env_threads", &env_threads);
+ v->Visit("relations", &relations);
+ v->Visit("iter_var_attrs", &iter_var_attrs);
+ v->Visit("attach_type", &attach_type);
+ v->Visit("attach_ivar", &attach_ivar);
+ v->Visit("attach_stage", &attach_stage);
+ v->Visit("scope", &scope);
+ v->Visit("is_output", &is_output);
+ v->Visit("is_opengl", &is_opengl);
+ v->Visit("double_buffer", &double_buffer);
+ v->Visit("group", &group);
+ v->Visit("num_child_stages", &num_child_stages);
+ }
+
+ static constexpr const char* _type_key = "Stage";
+ TVM_DECLARE_FINAL_OBJECT_INFO(StageNode, Object);
+};
+
+/*! \brief node container for schedule */
+class ScheduleNode : public Object {
+ public:
+ /*! \brief The output operations in original data flow graph */
+ Array<Operation> outputs;
+ /*!
+ * \brief list of all stages for ops.
+ * The stages are sorted in dependency order.
+ */
+ Array<Stage> stages;
+ /*!
+ * \brief List of all stage groups.
+ */
+ Array<Stage> groups;
+ /*! \brief map of original operation to the stages */
+ Map<Operation, Stage> stage_map;
+ /*!
+ * \brief Internal stage map to map internal ops to stages.
+ * This is created on demand and can be invalidated.
+ */
+ std::unordered_map<const Object*, Stage> op2stage_cache_;
+
+ void VisitAttrs(AttrVisitor* v) {
+ v->Visit("outputs", &outputs);
+ v->Visit("stages", &stages);
+ v->Visit("groups", &groups);
+ v->Visit("stage_map", &stage_map);
+ }
+
+ /*! \brief Initialize temp cache. */
+ void InitCache();
+ /*! \brief Invalidate temp cache. */
+ void InvalidateCache();
+
+ /*!
+ * \brief Check if the schedule contains an Operation.
+ * \param op The candidate Operation.
+ * \return true if the schedule has the Operation. Otherwise, false.
+ */
+ TVM_DLL bool Contain(const Operation& op) const;
+
+ /*!
+ * \brief Check if the schedule contains a Tensor.
+ * \param tensor The candidate tensor.
+ * \return true if the schedule has the tensor. Otherwise, false.
+ */
+ TVM_DLL bool Contain(const Tensor& tensor) const {
+ return Contain(tensor->op);
+ }
+
+ /*!
+ * \brief Create a schedule for an array of ops (and their dependencies).
+ * \param ops The ops to be scheduled.
+ * \return sch The created Schedule.
+ */
+ TVM_DLL static Schedule make(Array<Operation> ops);
+
+ static constexpr const char* _type_key = "Schedule";
+ TVM_DECLARE_FINAL_OBJECT_INFO(ScheduleNode, Object);
+};
+
+/*!
+ * \brief Create a schedule for an array of ops (and their dependencies).
+ * \param ops The ops to be scheduled.
+ * \return sch The created Schedule.
+ */
+inline Schedule create_schedule(Array<Operation> ops) {
+ return ScheduleNode::make(ops);
+}
+
+/*! \brief node container for IterVar attr */
+class IterVarAttrNode : public Object {
+ public:
+ /*! \brief The iteration type. */
+ IterVarType iter_type{kDataPar};
+ /*! \brief The thread this iter Var binds, can be null */
+ IterVar bind_thread;
+ /*! \brief List of tensor to be prefetched in this loop */
+ Array<Tensor> prefetch_data;
+ /*! \brief The offset used in each prefetch */
+ Array<PrimExpr> prefetch_offset;
+ /*!
+ * \brief Tensor intrinsic used in tensorization,
+ * when the axis is marked as Tensorized
+ */
+ TensorIntrin tensor_intrin;
+ /*! \brief Alignment factor of buffer dimension */
+ int dim_align_factor{0};
+ /*! \brief Alignment offset of buffer dimension */
+ int dim_align_offset{0};
+ /*!
+ * \brief Additional pragma keys, array of StringImm
+ */
+ Array<PrimExpr> pragma_keys;
+ /*!
+ * \brief Additional values of pragma, if any
+ */
+ Array<PrimExpr> pragma_values;
+
+ void VisitAttrs(AttrVisitor* v) {
+ v->Visit("iter_type", &iter_type);
+ v->Visit("bind_thread", &bind_thread);
+ v->Visit("prefetch_data", &prefetch_data);
+ v->Visit("prefetch_offset", &prefetch_offset);
+ v->Visit("tensor_intrin", &tensor_intrin);
+ v->Visit("dim_align_factor", &dim_align_factor);
+ v->Visit("dim_align_offset", &dim_align_offset);
+ v->Visit("pragma_keys", &pragma_keys);
+ v->Visit("pragma_values", &pragma_values);
+ }
+
+ static constexpr const char* _type_key = "IterVarAttr";
+ TVM_DECLARE_FINAL_OBJECT_INFO(IterVarAttrNode, Object);
+};
+
+/*! \brief base node of iteration var */
+class IterVarRelationNode : public Object {
+ public:
+ static constexpr const char* _type_key = "IterVarRelation";
+ TVM_DECLARE_BASE_OBJECT_INFO(IterVarRelationNode, Object);
+};
+
+/*!
+ * \brief Split the parent domain into product of
+ * outer and inner.
+ */
+class SplitNode : public IterVarRelationNode {
+ public:
+ /*! \brief The parent domain */
+ IterVar parent;
+ /*! \brief The outer domain */
+ IterVar outer;
+ /*! \brief The inner domain */
+ IterVar inner;
+ /*! \brief The split factor */
+ PrimExpr factor;
+ /*! \brief Number of parts, only factor or nparts can be given */
+ PrimExpr nparts;
+
+ void VisitAttrs(AttrVisitor* v) {
+ v->Visit("parent", &parent);
+ v->Visit("outer", &outer);
+ v->Visit("inner", &inner);
+ v->Visit("factor", &factor);
+ v->Visit("nparts", &nparts);
+ }
+
+ static IterVarRelation make(IterVar parent,
+ IterVar outer,
+ IterVar inner,
+ PrimExpr factor,
+ PrimExpr nparts);
+
+ static constexpr const char* _type_key = "Split";
+ TVM_DECLARE_FINAL_OBJECT_INFO(SplitNode, IterVarRelationNode);
+};
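+
+// Semantics sketch: splitting a parent axis of extent 16 with factor 4
+// introduces outer/inner axes such that
+//   parent = outer * 4 + inner,  0 <= inner < 4,  0 <= outer < 16 / 4.
+// Only one of `factor` or `nparts` is given; the other is derived from the
+// parent extent during bound inference.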
+
+/*!
+ * \brief Fuse two domains into one domain.
+ */
+class FuseNode : public IterVarRelationNode {
+ public:
+ /*! \brief The outer domain */
+ IterVar outer;
+ /*! \brief The inner domain */
+ IterVar inner;
+ /*! \brief The target domain */
+ IterVar fused;
+
+ void VisitAttrs(AttrVisitor* v) {
+ v->Visit("outer", &outer);
+ v->Visit("inner", &inner);
+ v->Visit("fused", &fused);
+ }
+
+ static IterVarRelation make(
+ IterVar outer, IterVar inner, IterVar fused);
+
+ static constexpr const char* _type_key = "Fuse";
+ TVM_DECLARE_FINAL_OBJECT_INFO(FuseNode, IterVarRelationNode);
+};
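+
+// Semantics sketch: fusing `outer` (extent m) and `inner` (extent n) yields
+// `fused` with extent m * n, where outer = fused / n and inner = fused % n.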
+
+/*!
+ * \brief Rebase the iteration so that its min becomes 0.
+ * This is useful for normalizing the Schedule
+ * so that every leaf variable's min is 0.
+ */
+class RebaseNode : public IterVarRelationNode {
+ public:
+ /*! \brief The parent domain */
+ IterVar parent;
+ /*! \brief The rebased domain */
+ IterVar rebased;
+
+ void VisitAttrs(AttrVisitor* v) {
+ v->Visit("parent", &parent);
+ v->Visit("rebased", &rebased);
+ }
+
+ static IterVarRelation make(IterVar parent, IterVar rebased);
+
+ static constexpr const char* _type_key = "Rebase";
+ TVM_DECLARE_FINAL_OBJECT_INFO(RebaseNode, IterVarRelationNode);
+};
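+
+// Semantics sketch: a parent iterator over [min, min + extent) maps to
+// `rebased` over [0, extent), with parent = rebased + min.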
+
+
+/*!
+ * \brief Singleton iterator [0, 1)
+ */
+class SingletonNode : public IterVarRelationNode {
+ public:
+ /*! \brief The singleton iterator */
+ IterVar iter;
+
+ void VisitAttrs(AttrVisitor* v) {
+ v->Visit("iter", &iter);
+ }
+
+ static IterVarRelation make(IterVar iter);
+
+ static constexpr const char* _type_key = "Singleton";
+ TVM_DECLARE_FINAL_OBJECT_INFO(SingletonNode, IterVarRelationNode);
+};
+
+
+// implementations
+inline const StageNode* Stage::operator->() const {
+ return static_cast<const StageNode*>(get());
+}
+inline StageNode* Stage::operator->() {
+ return static_cast<StageNode*>(get_mutable());
+}
+
+inline const ScheduleNode* Schedule::operator->() const {
+ return static_cast<const ScheduleNode*>(get());
+}
+inline ScheduleNode* Schedule::operator->() {
+ return static_cast<ScheduleNode*>(get_mutable());
+}
+
+inline const IterVarRelationNode* IterVarRelation::operator->() const {
+ return static_cast<const IterVarRelationNode*>(get());
+}
+
+inline const IterVarAttrNode* IterVarAttr::operator->() const {
+ return static_cast<const IterVarAttrNode*>(get());
+}
+} // namespace top
+} // namespace tvm
+#endif // TVM_TOP_SCHEDULE_H_
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/top/schedule_pass.h
+ * \brief Collection of Schedule pass functions.
+ *
+ * These passes work on the schedule hyper-graph
+ * and infer information such as bounds, check conditions,
+ * and read/write dependencies between the IterVars.
+ */
+#ifndef TVM_TOP_SCHEDULE_PASS_H_
+#define TVM_TOP_SCHEDULE_PASS_H_
+
+#include <tvm/top/schedule.h>
+
+namespace tvm {
+namespace top {
+
+/*!
+ * \brief Infer the bounds of all iteration variables related to the schedule.
+ *
+ * \param sch The root schedule to infer all the bounds.
+ * \return the resulting bounds of the iteration variables
+ */
+Map<IterVar, Range> InferBound(const Schedule& sch);
+
+/*!
+ * \brief Schedule s's dependent operations.
+ *
+ * \param s The schedule to be realized
+ * \param dom_map The domain of each iter vars.
+ * \param debug_keep_trivial_loop Whether to keep trivial loops with extent of 1 during lowering.
+ * This is a debug feature for dataflow/axis analysis.
+ * Note: if this is true, the lowered IR may be incorrect,
+ * because we will also delete the init part of the reduction.
+ * \return the result Stmt
+ */
+Stmt ScheduleOps(Schedule s, Map<IterVar, Range> dom_map, bool debug_keep_trivial_loop);
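+
+// Typical use (sketch), mirroring the lowering flow in build_module.cc:
+//   Map<IterVar, Range> bounds = InferBound(sch);
+//   Stmt stmt = ScheduleOps(sch, bounds, /*debug_keep_trivial_loop=*/false);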
+
+/*!
+ * \brief Automatically inline element-wise operations.
+ *
+ * \param sch The schedule to be inlined.
+ */
+void AutoInlineElemWise(Schedule sch);
+
+/*!
+ * \brief Automatically inline operations with injective writes
+ * (i.e. writes without reduction or sequential loops). Note
+ * that in this case, guarantees about contiguity, transpose, stride,
+ * alignment and memory footprint in general do not hold.
+ *
+ * \param sch The schedule to be inlined.
+ */
+TVM_DLL void AutoInlineInjective(Schedule sch);
+
+} // namespace top
+} // namespace tvm
+#endif // TVM_TOP_SCHEDULE_PASS_H_
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/top/tensor.h
+ * \brief Dataflow tensor object
+ */
+#ifndef TVM_TOP_TENSOR_H_
+#define TVM_TOP_TENSOR_H_
+
+#include <tvm/node/container.h>
+#include <tvm/arith/bound.h>
+#include <tvm/expr.h>
+#include <tvm/expr_operator.h>
+
+#include <string>
+#include <vector>
+#include <utility>
+#include <type_traits>
+
+namespace tvm {
+namespace top {
+
+// Internal node container of Tensor
+class TensorNode;
+// internal node container for Operation
+class OperationNode;
+
+/*!
+ * \brief Tensor structure representing a possible input,
+ * or intermediate computation result.
+ */
+class Tensor : public ObjectRef {
+ public:
+ /*! \brief default constructor, used internally */
+ Tensor() {}
+ explicit Tensor(ObjectPtr<Object> n) : ObjectRef(n) {}
+ /*!
+ * \brief access the internal node container
+ * \return the pointer to the internal node container
+ */
+ inline const TensorNode* operator->() const;
+ /*!
+ * \brief check if two tensors equal each other.
+ * \param other tensor to be checked.
+ * \return whether the two tensors equal each other.
+ */
+ inline bool operator==(const Tensor& other) const;
+ /*!
+ * \brief check if two tensors are different.
+ * \param other tensor to be checked.
+ * \return whether the two tensors are different.
+ */
+ inline bool operator!=(const Tensor& other) const;
+ /*! \return The dimension of the tensor */
+ inline size_t ndim() const;
+ /*!
+ * \brief Take elements from the tensor
+ * \param args The indices
+ * \return the result expression representing tensor read.
+ */
+ template<typename... Args>
+ inline PrimExpr operator()(Args&& ...args) const {
+ Array<PrimExpr> indices{std::forward<Args>(args)...};
+ return operator()(indices);
+ }
+ /*!
+ * \brief Take elements from the tensor
+ * \param indices the indices.
+ * \return the result expression representing tensor read.
+ */
+ TVM_DLL PrimExpr operator()(Array<PrimExpr> indices) const;
+ /*!
+ * \brief Take elements from the tensor
+ * \param indices the indices.
+ * \return the result expression representing tensor read.
+ */
+ TVM_DLL PrimExpr operator()(Array<Var> indices) const;
+ /*!
+ * \brief data structure to represent a slice that fixes the first k coordinates.
+ * This is used to enable the syntax sugar Tensor[x][y][z] to get an element.
+ */
+ class Slice {
+ public:
+ // construct via tensor and indices
+ Slice(const Tensor& tensor, std::vector<PrimExpr> indices)
+ : tensor_(tensor), indices_(indices) {}
+ /*!
+ * \brief get i-th slice from the current slice.
+ * \param i the index of the coordinate
+ * \return the subsequent slice.
+ */
+ inline Slice operator[](PrimExpr i) {
+ std::vector<PrimExpr> other = indices_;
+ other.emplace_back(i);
+ return Slice(tensor_, other);
+ }
+ /*!
+ * \brief Convert slice to expression.
+ * This is only valid when all the coordinates are fully specified.
+ * \return the corresponding expression of this slice.
+ */
+ inline operator PrimExpr() const {
+ return tensor_(indices_);
+ }
+
+ private:
+ const Tensor& tensor_;
+ std::vector<PrimExpr> indices_;
+ };
+ /*!
+ * \brief get i-th slice from the current Tensor.
+ * \param i the index of the coordinate
+ * \return the subsequent slice.
+ */
+ inline Slice operator[](PrimExpr i) const {
+ return Slice(*this, {i});
+ }
+ /*! \brief specify container node */
+ using ContainerType = TensorNode;
+};
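+
+// Usage sketch (illustrative only; `A` is a 2-D Tensor, `i` and `j` are Vars):
+//   PrimExpr a = A(i, j);   // variadic element read
+//   PrimExpr b = A[i][j];   // Slice sugar, implicitly converted to PrimExpr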
+
+/*! \brief Operation that produces tensors */
+class Operation : public ir::FunctionRef {
+ public:
+ /*! \brief default constructor */
+ Operation() {}
+ explicit Operation(ObjectPtr<Object> n) : FunctionRef(n) {}
+ /*!
+ * \brief access the internal node container
+ * \return the pointer to the internal node container
+ */
+ inline const OperationNode* operator->() const;
+ /*!
+ * \brief get the i-th output of the operation.
+ * \param i the output index.
+ * \return The i-th output.
+ */
+ TVM_DLL Tensor output(size_t i) const;
+ /*! \brief specify container node */
+ using ContainerType = OperationNode;
+};
+
+/*! \brief Node to represent a tensor */
+class TensorNode : public Object {
+ public:
+ /*! \brief The shape of the tensor */
+ Array<PrimExpr> shape;
+ /*! \brief data type in the content of the tensor */
+ DataType dtype;
+ /*! \brief the source operation, can be None */
+ Operation op;
+ /*! \brief the output index from source operation */
+ int value_index{0};
+ /*! \brief constructor */
+ TensorNode() {}
+
+ void VisitAttrs(AttrVisitor* v) {
+ v->Visit("shape", &shape);
+ v->Visit("dtype", &dtype);
+ v->Visit("op", &op);
+ v->Visit("value_index", &value_index);
+ }
+ TVM_DLL static Tensor make(Array<PrimExpr> shape,
+ DataType dtype,
+ Operation op,
+ int value_index);
+
+ static constexpr const char* _type_key = "Tensor";
+ TVM_DECLARE_FINAL_OBJECT_INFO(TensorNode, Object);
+};
+
+
+// Implementations of inline functions
+inline const TensorNode* Tensor::operator->() const {
+ return static_cast<const TensorNode*>(get());
+}
+
+inline size_t Tensor::ndim() const {
+ return (*this)->shape.size();
+}
+
+inline bool Tensor::operator==(const Tensor& other) const {
+ if (get() == other.get()) return true;
+ if (get() == nullptr || other.get() == nullptr) return false;
+ if ((*this)->op.defined() || other->op.defined()) {
+ return (*this)->op == other->op &&
+ (*this)->value_index == other->value_index;
+ } else {
+ return false;
+ }
+}
+
+inline bool Tensor::operator!=(const Tensor& other) const {
+ return !(*this == other);
+}
+
+// macro to turn every operation of slice to expression
+#define DEFINE_OVERLOAD_SLICE_UNARY_OP(Op) \
+ inline PrimExpr operator Op (const Tensor::Slice& a) { \
+ return Op a.operator PrimExpr() ; \
+ } \
+
+#define DEFINE_OVERLOAD_SLICE_BINARY_OP(Op) \
+ template<typename T> \
+ inline PrimExpr operator Op (const Tensor::Slice& a, const T& b) { \
+ return a.operator PrimExpr() Op b; \
+ } \
+ template<typename T> \
+ inline PrimExpr operator Op (const T& a, const Tensor::Slice& b) { \
+ return a Op b.operator PrimExpr(); \
+ } \
+ inline PrimExpr operator Op (const Tensor::Slice& a, const Tensor::Slice& b) { \
+ return a.operator PrimExpr() Op b.operator PrimExpr(); \
+ }
+
+DEFINE_OVERLOAD_SLICE_UNARY_OP(!);
+DEFINE_OVERLOAD_SLICE_UNARY_OP(-);
+DEFINE_OVERLOAD_SLICE_BINARY_OP(+);
+DEFINE_OVERLOAD_SLICE_BINARY_OP(-);
+DEFINE_OVERLOAD_SLICE_BINARY_OP(*);
+DEFINE_OVERLOAD_SLICE_BINARY_OP(==);
+DEFINE_OVERLOAD_SLICE_BINARY_OP(<=);
+DEFINE_OVERLOAD_SLICE_BINARY_OP(>=);
+DEFINE_OVERLOAD_SLICE_BINARY_OP(!=);
+DEFINE_OVERLOAD_SLICE_BINARY_OP(&&);
+DEFINE_OVERLOAD_SLICE_BINARY_OP(||);
+DEFINE_OVERLOAD_SLICE_BINARY_OP(>>);
+DEFINE_OVERLOAD_SLICE_BINARY_OP(<<);
+DEFINE_OVERLOAD_SLICE_BINARY_OP(>); // NOLINT(*)
+DEFINE_OVERLOAD_SLICE_BINARY_OP(<); // NOLINT(*)
+
+} // namespace top
+} // namespace tvm
+
+namespace std {
+template <>
+struct hash<::tvm::top::Operation> : public ::tvm::ObjectHash {
+};
+
+template <>
+struct hash<::tvm::top::Tensor> {
+ std::size_t operator()(const ::tvm::top::Tensor& k) const {
+ ::tvm::ObjectHash hasher;
+ if (k.defined() && k->op.defined()) {
+ return hasher(k->op);
+ } else {
+ return hasher(k);
+ }
+ }
+};
+} // namespace std
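+
+// The specializations above let top::Tensor and top::Operation serve as keys
+// in std::unordered_map / std::unordered_set, e.g. (sketch):
+//   std::unordered_map<tvm::top::Tensor, tvm::Buffer> binds;  // as used by lower()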
+#endif // TVM_TOP_TENSOR_H_
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/top/tensor_intrin.h
+ * \brief Tensor intrinsic operations.
+ */
+#ifndef TVM_TOP_TENSOR_INTRIN_H_
+#define TVM_TOP_TENSOR_INTRIN_H_
+
+#include <tvm/top/tensor.h>
+#include <tvm/buffer.h>
+
+#include <string>
+
+
+namespace tvm {
+namespace top {
+
+// Internal node container of tensor intrinsics.
+class TensorIntrinNode;
+
+/*! \brief Tensor intrinsic, a reference to TensorIntrinNode. */
+class TensorIntrin : public ObjectRef {
+ public:
+ TensorIntrin() {}
+ explicit TensorIntrin(ObjectPtr<Object> n) : ObjectRef(n) {}
+ /*!
+ * \brief access the internal node container
+ * \return the pointer to the internal node container
+ */
+ inline const TensorIntrinNode* operator->() const;
+
+ /*! \brief specify container node */
+ using ContainerType = TensorIntrinNode;
+};
+
+/*! \brief Node to represent a Tensor intrinsic operator */
+class TensorIntrinNode : public Object {
+ public:
+ /*! \brief The name of the intrinsic */
+ std::string name;
+ /*! \brief The operation this intrinsic carries out */
+ Operation op;
+ /*! \brief List of inputs of the operator, placeholders in post-DFS order */
+ Array<Tensor> inputs;
+ /*!
+ * \brief Symbolic buffers of each output/input tensor.
+ * buffers[0:len(inputs)] are buffers of the inputs;
+ * buffers[len(inputs):] are buffers of each output.
+ *
+ * \note When a field in a Buffer is a Var, it means we can be flexible
+ * with respect to that field, and the Var can occur in the body.
+ * When it is a constant, it means we can only take data in that shape.
+ */
+ Array<Buffer> buffers;
+ /*! \brief List of scalar variables, used in body. These placeholders
+ * will be bound to expressions passed in when the TensorIntrin is called
+ * from a TensorComputeOp.
+ */
+ Array<Var> scalar_params;
+ /*! \brief The normal statement to execute the intrinsic */
+ Stmt body;
+ /*!
+ * \brief Special statement for a reduction op; can be None.
+ * Resets the value of the output buffer to the identity value.
+ */
+ Stmt reduce_init;
+ /*!
+ * \brief Special statement for a reduction op; can be None.
+ * Reduce: combine the current output buffer with the result.
+ */
+ Stmt reduce_update;
+ /*! \brief constructor */
+ TensorIntrinNode() {}
+
+ void VisitAttrs(AttrVisitor* v) {
+ v->Visit("name", &name);
+ v->Visit("op", &op);
+ v->Visit("inputs", &inputs);
+ v->Visit("buffers", &buffers);
+ v->Visit("scalar_params", &scalar_params);
+ v->Visit("body", &body);
+ v->Visit("reduce_init", &reduce_init);
+ v->Visit("reduce_update", &reduce_update);
+ }
+
+ TVM_DLL static TensorIntrin make(std::string name,
+ Operation op,
+ Array<Tensor> inputs,
+ Array<Buffer> buffers,
+ Array<Var> scalar_params,
+ Stmt body,
+ Stmt reduce_init,
+ Stmt reduce_update);
+
+ static constexpr const char* _type_key = "TensorIntrin";
+ TVM_DECLARE_FINAL_OBJECT_INFO(TensorIntrinNode, Object);
+};
+
+inline const TensorIntrinNode* TensorIntrin::operator->() const {
+ return static_cast<const TensorIntrinNode*>(get());
+}
+
+// Internal node container of tensor intrinsic calling.
+class TensorIntrinCallNode;
+
+/*! \brief Tensor intrinsic calling node. */
+class TensorIntrinCall : public ObjectRef {
+ public:
+ TensorIntrinCall() {}
+ explicit TensorIntrinCall(ObjectPtr<Object> n) : ObjectRef(n) {}
+ /*!
+ * \brief access the internal node container
+ * \return the pointer to the internal node container
+ */
+ inline const TensorIntrinCallNode* operator->() const;
+
+ /*! \brief specify container node */
+ using ContainerType = TensorIntrinCallNode;
+};
+
+class TensorIntrinCallNode : public Object {
+ public:
+ /*! \brief the tensor intrinsic */
+ TensorIntrin intrin;
+ /*! \brief input tensors of the intrinsic */
+ Array<Tensor> tensors;
+ /*! \brief regions of input tensors */
+ Array<Region> regions;
+
+ /*!
+ * \brief IterVars on each reduction axis, if the
+ * intrinsic uses the reduce axis.
+ */
+ Array<IterVar> reduce_axis;
+
+ /*! \brief scalar expression inputs */
+ Array<PrimExpr> scalar_inputs;
+
+ void VisitAttrs(AttrVisitor* v) {
+ v->Visit("intrin", &intrin);
+ v->Visit("tensors", &tensors);
+ v->Visit("regions", ®ions);
+ v->Visit("reduce_axis", &reduce_axis);
+ v->Visit("scalar_inputs", &scalar_inputs);
+ }
+ static TensorIntrinCall make(TensorIntrin intrin,
+ Array<Tensor> tensors,
+ Array<Region> regions,
+ Array<IterVar> reduce_axis,
+ Array<PrimExpr> scalar_inputs);
+
+ static constexpr const char* _type_key = "TensorIntrinCall";
+ TVM_DECLARE_FINAL_OBJECT_INFO(TensorIntrinCallNode, Object);
+};
+
+inline const TensorIntrinCallNode* TensorIntrinCall::operator->() const {
+ return static_cast<const TensorIntrinCallNode*>(get());
+}
+
+} // namespace top
+} // namespace tvm
+#endif // TVM_TOP_TENSOR_INTRIN_H_
- support: Internal support utilities.
- runtime: Minimum runtime related codes.
- node: base infra for IR/AST nodes that is dialect independent.
-- api: API function registration.
-- lang: The definition of DSL related data structure.
- arith: Arithmetic expression and set simplification.
-- op: The detail implementations about each operation(compute, scan, placeholder).
-- schedule: The operations on the schedule graph before converting to IR.
+- top: Tensor operation DSL for compute and schedule.
+- relay: Implementation of Relay, the second generation of NNVM and a new IR for deep learning frameworks.
- pass: The optimization pass on the IR structure.
- codegen: The code generator.
- autotvm: The auto-tuning module.
-- relay: Implementation of Relay. The second generation of NNVM, a new IR for deep learning frameworks.
- contrib: Contrib extension libraries.
+- api: API function registration.
+- lang: The definition of DSL related data structure.
#include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
-#include <tvm/tensor.h>
+#include <tvm/top/tensor.h>
namespace tvm {
namespace arith {
*/
#include <dmlc/memory_io.h>
#include <tvm/expr.h>
-#include <tvm/tensor.h>
+#include <tvm/top/tensor.h>
#include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
*/
#include <tvm/expr.h>
#include <tvm/ir.h>
-#include <tvm/tensor.h>
-#include <tvm/operation.h>
+#include <tvm/top/tensor.h>
+#include <tvm/top/operation.h>
#include <tvm/buffer.h>
-#include <tvm/schedule.h>
+#include <tvm/top/schedule.h>
#include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
TVM_REGISTER_GLOBAL("_BijectiveLayoutBackwardShape")
.set_body_method(&BijectiveLayout::BackwardShape);
+namespace top {
TVM_REGISTER_GLOBAL("_Tensor")
.set_body_typed(TensorNode::make);
TVM_REGISTER_GLOBAL("_ScheduleRFactor")
.set_body_method(&Schedule::rfactor);
+} // namespace top
TVM_REGISTER_GLOBAL("_CommReducerCombine")
.set_body_method<ir::CommReducer>(&ir::CommReducerNode::operator());
TVM_REGISTER_GLOBAL("ir_pass.RewriteForTensorCore")
.set_body_typed
- ([](const Stmt& stmt, const Schedule& schedule, const Map<Tensor, Buffer>& extern_buffer) {
+ ([](const Stmt& stmt,
+ const top::Schedule& schedule,
+ const Map<top::Tensor, Buffer>& extern_buffer) {
return RewriteForTensorCore(stmt, schedule, extern_buffer);
});
* \file api_schedule.cc
*/
#include <tvm/expr.h>
-#include <tvm/tensor.h>
-#include <tvm/schedule.h>
-#include <tvm/schedule_pass.h>
+#include <tvm/top/tensor.h>
+#include <tvm/top/schedule.h>
+#include <tvm/top/schedule_pass.h>
#include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
-#include "../schedule/graph.h"
+#include "../top/schedule/graph.h"
namespace tvm {
-namespace schedule {
+namespace top {
TVM_REGISTER_GLOBAL("schedule.AutoInlineElemWise")
.set_body_typed(AutoInlineElemWise);
REGISTER_SCHEDULE_PASS(ScanGetBody);
REGISTER_SCHEDULE_PASS(ScanFixPointAnalysis);
-} // namespace schedule
+} // namespace top
} // namespace tvm
* \file api_test.cc
*/
#include <tvm/expr.h>
-#include <tvm/tensor.h>
+#include <tvm/top/tensor.h>
#include <tvm/ir/attrs.h>
#include <tvm/runtime/registry.h>
#include <tvm/ir/env_func.h>
#include <tvm/expr.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_functor_ext.h>
-#include <tvm/tensor.h>
+#include <tvm/top/tensor.h>
#include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
// Find Read region of the tensor in the stmt.
class FuncTouchedDomain final : public StmtExprVisitor {
public:
- FuncTouchedDomain(const Tensor &tensor, bool consider_calls, bool consider_provides)
+ FuncTouchedDomain(const top::Tensor &tensor, bool consider_calls, bool consider_provides)
: tensor_(tensor), consider_calls_(consider_calls), consider_provides_(consider_provides) {}
Domain Find(const Stmt& stmt) {
}
}
- const Tensor &tensor_;
+ const top::Tensor &tensor_;
bool consider_calls_, consider_provides_;
std::vector<std::vector<IntSet> > bounds_;
std::unordered_map<const VarNode*, IntSet> dom_map_;
};
-Domain DomainTouched(Stmt stmt, const Tensor &tensor, bool consider_calls, bool consider_provides) {
+Domain DomainTouched(Stmt stmt,
+ const top::Tensor &tensor,
+ bool consider_calls,
+ bool consider_provides) {
return FuncTouchedDomain(tensor, consider_calls, consider_provides).Find(stmt);
}
*/
#include <dmlc/thread_local.h>
#include <tvm/build_module.h>
-#include <tvm/operation.h>
+#include <tvm/top/operation.h>
#include <tvm/ir_pass.h>
#include <tvm/codegen.h>
#include <tvm/runtime/registry.h>
data_alignment, offset_factor, buffer_type);
}
-void GetBinds(const Array<Tensor>& args,
+void GetBinds(const Array<top::Tensor>& args,
bool compact,
- const std::unordered_map<Tensor, Buffer>& binds,
- Map<Tensor, Buffer>* out_binds,
+ const std::unordered_map<top::Tensor, Buffer>& binds,
+ Map<top::Tensor, Buffer>* out_binds,
Array<ObjectRef>* out_arg_list,
const BuildConfig& config) {
*out_binds = binds;
* \param config The build configuration.
* \return The built Stmt.
*/
-Stmt BuildStmt(Schedule sch,
- const Array<Tensor>& args,
- const std::unordered_map<Tensor, Buffer>& binds,
+Stmt BuildStmt(top::Schedule sch,
+ const Array<top::Tensor>& args,
+ const std::unordered_map<top::Tensor, Buffer>& binds,
bool loop_partition,
Array<ObjectRef> *out_arg_list,
const BuildConfig& config) {
sch = sch.normalize();
// Phase 0
- auto bounds = schedule::InferBound(sch);
- auto stmt = schedule::ScheduleOps(sch, bounds, false);
+ auto bounds = top::InferBound(sch);
+ auto stmt = top::ScheduleOps(sch, bounds, false);
stmt = ir::InjectPrefetch(stmt);
bool compact = ir::VerifyCompactBuffer(stmt);
- Map<Tensor, Buffer> out_binds;
+ Map<top::Tensor, Buffer> out_binds;
GetBinds(args, compact, binds, &out_binds, out_arg_list, config);
// Phase 1
return stmt;
}
-Array<LoweredFunc> lower(Schedule sch,
- const Array<Tensor>& args,
+Array<LoweredFunc> lower(top::Schedule sch,
+ const Array<top::Tensor>& args,
const std::string& name,
- const std::unordered_map<Tensor, Buffer>& binds,
+ const std::unordered_map<top::Tensor, Buffer>& binds,
const BuildConfig& config) {
Array<ObjectRef> out_arg_list;
auto stmt = BuildStmt(sch, args, binds, true, &out_arg_list, config);
#include <tvm/ir_functor_ext.h>
#include <tvm/codegen.h>
#include <tvm/lowered_func.h>
-#include <tvm/schedule.h>
+#include <tvm/top/schedule.h>
#include <map>
#include <string>
#include <unordered_map>
namespace tvm {
namespace contrib {
+using namespace top;
using namespace ir;
/*!
* \brief A base class to generate Hybrid Script.
+++ /dev/null
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file tensor.cc
- */
-#include <tvm/tensor.h>
-#include <tvm/operation.h>
-#include <tvm/tensor_intrin.h>
-#include <memory>
-
-namespace tvm {
-
-// Tensor
-PrimExpr Tensor::operator()(Array<Var> indices) const {
- Array<PrimExpr> arr(indices.begin(), indices.end());
- return operator()(arr);
-}
-
-PrimExpr Tensor::operator()(Array<PrimExpr> indices) const {
- using ir::CallNode;
- if (ndim() != 0) {
- CHECK_EQ(ndim(), indices.size())
- << "Tensor dimension mismatch in read"
- << "ndim = " << ndim() << ", indices.size=" << indices.size();
- }
- auto n = CallNode::make(
- (*this)->dtype, (*this)->op->name, indices, CallNode::Halide,
- (*this)->op, (*this)->value_index);
- return n;
-}
-
-Tensor Operation::output(size_t i) const {
- auto node = make_object<TensorNode>();
- node->op = *this;
- node->value_index = i;
- node->dtype = (*this)->output_dtype(i);
- node->shape = (*this)->output_shape(i);
- return Tensor(node);
-}
-
-Tensor TensorNode::make(Array<PrimExpr> shape,
- DataType dtype,
- Operation op,
- int value_index) {
- auto n = make_object<TensorNode>();
- n->shape = std::move(shape);
- n->dtype = dtype;
- n->op = op;
- n->value_index = value_index;
- return Tensor(n);
-}
-
-TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<TensorNode>([](const ObjectRef& node, NodePrinter* p) {
- auto* t = static_cast<const TensorNode*>(node.get());
- p->stream << "Tensor(shape=" << t->shape
- << ", op.name=" << t->op->name << ')';
- });
-
-TVM_REGISTER_NODE_TYPE(TensorNode);
-
-
-// TensorIntrin
-
-TensorIntrin TensorIntrinNode::make(std::string name,
- Operation op,
- Array<Tensor> inputs,
- Array<Buffer> buffers,
- Array<Var> scalar_params,
- Stmt body,
- Stmt reduce_init,
- Stmt reduce_update) {
- auto n = make_object<TensorIntrinNode>();
- n->name = std::move(name);
- n->op = std::move(op);
- n->inputs = std::move(inputs);
- n->buffers = std::move(buffers);
- n->scalar_params = std::move(scalar_params);
- n->body = std::move(body);
- n->reduce_init = std::move(reduce_init);
- n->reduce_update = std::move(reduce_update);
- return TensorIntrin(n);
-}
-
-TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<TensorIntrinNode>([](const ObjectRef& node, NodePrinter* p) {
- auto* op = static_cast<const TensorIntrinNode*>(node.get());
- p->stream << "TensorIntrin(name=" << op->name << ", " << op << ")";
- });
-
-TVM_REGISTER_NODE_TYPE(TensorIntrinNode);
-
-
-// TensorIntrinCall
-
-TensorIntrinCall TensorIntrinCallNode::make(TensorIntrin intrin,
- Array<Tensor> tensors,
- Array<Region> regions,
- Array<IterVar> reduce_axis,
- Array<PrimExpr> scalar_inputs) {
- auto n = make_object<TensorIntrinCallNode>();
- n->intrin = std::move(intrin);
- n->tensors = std::move(tensors);
- n->regions = std::move(regions);
- n->reduce_axis = std::move(reduce_axis);
- n->scalar_inputs = std::move(scalar_inputs);
- return TensorIntrinCall(n);
-}
-
-TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<TensorIntrinCallNode>([](const ObjectRef& node, NodePrinter* p) {
- auto* n = static_cast<const TensorIntrinCallNode*>(node.get());
- p->stream << "TensorIntrinCall(intrin=" << n->intrin << ", " << n << ")";
- });
-
-TVM_REGISTER_NODE_TYPE(TensorIntrinCallNode);
-
-} // namespace tvm
+++ /dev/null
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \brief Compute Op.
- * \file compute_op.cc
- */
-#include <tvm/operation.h>
-#include <tvm/arith/analyzer.h>
-#include <tvm/ir.h>
-#include <tvm/ir_pass.h>
-#include <tvm/ir_functor_ext.h>
-#include <unordered_set>
-#include <string>
-#include <utility>
-#include "compute_op.h"
-#include "op_util.h"
-#include "../schedule/message_passing.h"
-#include "../arith/compute_expr.h"
-#include "../arith/interval_set.h"
-
-namespace tvm {
-
-using namespace ir;
-
-TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<ComputeOpNode>([](const ObjectRef& node, NodePrinter* p) {
- auto* op = static_cast<const ComputeOpNode*>(node.get());
- p->stream << "compute(" << op->name << ", " << op << ")";
-});
-
-TVM_REGISTER_NODE_TYPE(ComputeOpNode);
-
-/// Verify if ComputeOp is valid with respect to Reduce operations.
-static void VerifyComputeOp(const ComputeOpNode *op);
-
-inline bool ReduceEqual(const ir::ReduceNode* a, const ir::ReduceNode* b) {
- return (a->combiner.same_as(b->combiner)) &&
- (a->source.same_as(b->source)) &&
- (a->axis.same_as(b->axis)) &&
- (a->condition.same_as(b->condition));
-}
-
-int ComputeOpNode::num_outputs() const {
- return body.size();
-}
-
-Array<IterVar> BaseComputeOpNode::root_iter_vars() const {
- if (reduce_axis.size() == 0) return axis;
- Array<IterVar> ret = axis;
- for (IterVar iv : reduce_axis) {
- ret.push_back(iv);
- }
- return ret;
-}
-
-DataType ComputeOpNode::output_dtype(size_t idx) const {
- CHECK_LT(idx, num_outputs());
- return body[idx].dtype();
-}
-
-Array<PrimExpr> BaseComputeOpNode::output_shape(size_t idx) const {
- CHECK_LT(idx, num_outputs());
- // for now, all outputs of a BaseComputeOp have the same shape
- Array<PrimExpr> shape;
- for (const auto& ivar : this->axis) {
- const Range& r = ivar->dom;
- shape.push_back(r->extent);
- }
- return shape;
-}
-
-Tensor compute(Array<PrimExpr> shape,
- FCompute fcompute,
- std::string name,
- std::string tag,
- Map<std::string, ObjectRef> attrs) {
- auto op_node = make_object<ComputeOpNode>();
- // compute dimension.
- size_t ndim = shape.size();
- std::vector<IterVar> axis;
- std::vector<Var> args;
- for (size_t i = 0; i < ndim; ++i) {
- std::ostringstream os;
- os << "ax" << i;
- axis.emplace_back(IterVarNode::make(
- Range(0, shape[i]), Var(os.str(), shape[i].dtype()), kDataPar));
- args.push_back(axis.back()->var);
- }
-
- return ComputeOpNode::make(
- name, tag, attrs, axis, {fcompute(args)}).output(0);
-}
-
-Array<Tensor> compute(Array<PrimExpr> shape,
- FBatchCompute fcompute,
- std::string name,
- std::string tag,
- Map<std::string, ObjectRef> attrs) {
- auto op_node = make_object<ComputeOpNode>();
- // compute dimension.
- size_t ndim = shape.size();
- std::vector<IterVar> axis;
- std::vector<Var> args;
- for (size_t i = 0; i < ndim; ++i) {
- std::ostringstream os;
- os << "ax" << i;
- axis.emplace_back(IterVarNode::make(
- Range(0, shape[i]), Var(os.str(), shape[i].dtype()), kDataPar));
- args.push_back(axis.back()->var);
- }
-
- Operation op = ComputeOpNode::make(name, tag, attrs, axis, fcompute(args));
- Array<Tensor> outputs;
- for (int idx = 0; idx < op->num_outputs(); ++idx) {
- outputs.push_back(op.output(idx));
- }
- return outputs;
-}
-
-Operation ComputeOpNode::make(std::string name,
- std::string tag,
- Map<std::string, ObjectRef> attrs,
- Array<IterVar> axis,
- Array<PrimExpr> body) {
- if (!attrs.defined()) {
- attrs = Map<std::string, ObjectRef>();
- }
- auto n = make_object<ComputeOpNode>();
- n->name = std::move(name);
- n->tag = std::move(tag);
- n->attrs = std::move(attrs);
- n->axis = std::move(axis);
- n->body = std::move(body);
- if (n->body[0]->IsInstance<ir::ReduceNode>()) {
- const ir::ReduceNode* reduce = n->body[0].as<ir::ReduceNode>();
- n->reduce_axis = reduce->axis;
- }
- VerifyComputeOp(n.get());
- return Operation(n);
-}
-
-// The schedule related logics
-Array<Tensor> ComputeOpNode::InputTensors() const {
- Array<Tensor> ret;
- std::unordered_set<Tensor> visited;
- for (auto& e : body) {
- ir::PostOrderVisit(e, [&ret, &visited](const ObjectRef& n) {
- const ir::CallNode *call = n.as<ir::CallNode>();
- if (call != nullptr && call->func.defined()) {
- Tensor t = Downcast<Operation>(call->func).output(call->value_index);
- if (!visited.count(t)) {
- ret.push_back(t);
- visited.insert(t);
- }
- }
- });
- }
- return ret;
-}
-
-Operation ComputeOpNode::ReplaceInputs(
- const Operation& self,
- const std::unordered_map<Tensor, Tensor>& rmap) const {
- CHECK_EQ(self.operator->(), this);
- VerifyComputeOp(this);
- Array<PrimExpr> arr;
- if (this->body[0]->IsInstance<ir::ReduceNode>()) {
- // Specially handle reduce so the replaced op
- // still share all the components
- PrimExpr new_reduce = op::ReplaceTensor(this->body[0], rmap);
- if (!new_reduce.same_as(this->body[0])) {
- const ir::ReduceNode* r = new_reduce.as<ir::ReduceNode>();
- for (size_t k = 0; k < this->body.size(); ++k) {
- auto n = make_object<ir::ReduceNode>(*r);
- n->value_index = static_cast<int>(k);
- n->dtype = r->source[k].dtype();
- arr.push_back(PrimExpr(n));
- }
- } else {
- arr = this->body;
- }
- } else {
- arr = UpdateArray(this->body, [&rmap] (const PrimExpr& e) {
- return op::ReplaceTensor(e, rmap);
- });
- }
- if (!arr.same_as(this->body)) {
- return ComputeOpNode::make(
- this->name, this->tag, this->attrs, this->axis, arr);
- } else {
- return self;
- }
-}
-
-void ComputeOpNode::PropBoundToInputs(
- const Operation& self,
- arith::Analyzer* analyzer,
- const std::unordered_map<const VarNode*, IntSet>& dom_map,
- std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
- CHECK_EQ(self.operator->(), this);
- auto fvisit = [&dom_map, out_dom_map, analyzer](const ObjectRef& n) {
- auto *call = n.as<ir::CallNode>();
- if (call != nullptr && call->func.defined()) {
- Tensor t = Downcast<Operation>(call->func).output(call->value_index);
- if (t->op.defined() && out_dom_map->count(t)) {
- TensorDom& dom = out_dom_map->at(t);
- for (size_t i = 0; i < t.ndim(); ++i) {
- // We assume that the value of the argument cannot be out of bounds (otherwise it is
- // undefined behaviour), so we can intersect the estimated set of the argument with the
- // range expected by the tensor. However, intersection may result in overly complex
- // expressions, so we perform a more relaxed form of intersection.
- IntSet arg_intset = EvalSet(call->args[i], dom_map);
- const arith::IntervalSetNode* arg_interval = arg_intset.as<arith::IntervalSetNode>();
- if (arg_interval) {
- PrimExpr shape_i_min_value = make_zero(t->shape[i].dtype());
- PrimExpr shape_i_max_value = t->shape[i] - 1;
- PrimExpr min_value = arg_interval->min_value;
- PrimExpr max_value = arg_interval->max_value;
- // Prefer the shape bounds only when we can prove they are tighter.
- if (arith::is_neg_inf(min_value) ||
- analyzer->CanProve(shape_i_min_value >= min_value)) {
- min_value = shape_i_min_value;
- }
- if (arith::is_pos_inf(max_value) ||
- analyzer->CanProve(shape_i_max_value <= max_value)) {
- max_value = shape_i_max_value;
- }
- dom.data[i].push_back(IntSet::interval(min_value, max_value));
- } else {
- dom.data[i].push_back(arg_intset);
- }
- }
- }
- }
- };
- for (auto& e : body) ir::PostOrderVisit(e, fvisit);
-}
-
-void BaseComputeOpNode::GatherBound(
- const Operation& self,
- const std::unordered_map<Tensor, TensorDom>& tensor_dom,
- std::unordered_map<IterVar, Range>* out_dom_map) const {
- CHECK_EQ(self.operator->(), this);
- const TensorDom& tdom = tensor_dom.at(self.output(0));
- for (size_t i = 0; i < this->axis.size(); ++i) {
- Range r = arith::Union(tdom.data.at(i)).cover_range(this->axis[i]->dom);
- CHECK(!out_dom_map->count(this->axis[i]));
- (*out_dom_map)[this->axis[i]] = r;
- }
- for (size_t i = 0; i < this->reduce_axis.size(); ++i) {
- CHECK(!out_dom_map->count(this->reduce_axis[i]));
- (*out_dom_map)[this->reduce_axis[i]] = this->reduce_axis[i]->dom;
- }
-}
-
-Stmt BaseComputeOpNode::BuildRealize(
- const Stage& stage,
- const std::unordered_map<IterVar, Range>& realize_map,
- const Stmt& body) const {
- CHECK_EQ(stage->op.get(), this);
- Region bounds;
- for (IterVar iv : this->axis) {
- bounds.push_back(realize_map.at(iv));
- }
- Stmt realize = body;
- for (int i = this->num_outputs(); i > 0; --i) {
- Tensor t = stage->op.output(i-1);
- realize = ir::RealizeNode::make(t->op, t->value_index,
- t->dtype, bounds, const_true(), realize);
- // alignment requirement, only useful for compute
- for (size_t i = 0; i < num_schedulable_dims(); ++i) {
- auto it = stage->iter_var_attrs.find(this->axis[i]);
- if (it != stage->iter_var_attrs.end()) {
- IterVarAttr attr = (*it).second;
- if (attr->dim_align_factor != 0) {
- Array<PrimExpr> tuple = {static_cast<int>(i),
- attr->dim_align_factor,
- attr->dim_align_offset};
- realize = ir::AttrStmtNode::make(
- t, ir::attr::buffer_dim_align,
- CallNode::make(DataType::Handle(),
- ir::intrinsic::tvm_tuple,
- tuple, CallNode::Intrinsic),
- realize);
- }
- }
- }
- }
- return realize;
-}
-
-size_t ComputeOpNode::num_schedulable_dims() const {
- return axis.size();
-}
-
-// Build a reduction body.
-void MakeReduction(const ComputeOpNode* op,
- const Array<Tensor>& tensors,
- Stmt* init,
- Stmt* provide) {
- Array<PrimExpr> args;
- for (IterVar iv : op->axis) {
- args.push_back(iv->var);
- }
- std::vector<Stmt> inits, provides;
-
- size_t size = op->body.size();
- const ReduceNode* reduce = op->body[0].as<ReduceNode>();
- CHECK(reduce);
- const CommReducerNode* combiner = reduce->combiner.as<CommReducerNode>();
- CHECK(combiner);
- Array<PrimExpr> lhs;
- for (size_t i = 0; i < size; ++i) {
- lhs.push_back(tensors[i](args));
- }
- Array<PrimExpr> init_value = combiner->identity_element;
- Array<PrimExpr> update_value = (*combiner)(lhs, reduce->source);
- for (size_t i = 0; i < size; ++i) {
- Tensor t = tensors[i];
- inits.emplace_back(ProvideNode::make(
- t->op, t->value_index, init_value[i], args));
- provides.emplace_back(ProvideNode::make(
- t->op, t->value_index, update_value[i], args));
- }
- *init = SeqStmt::Flatten(inits);
- *provide = SeqStmt::Flatten(provides);
- if (!is_one(reduce->condition)) {
- *provide = IfThenElseNode::make(reduce->condition, *provide);
- }
-}
-
-// Normal computation.
-Stmt MakeProvide(const ComputeOpNode* op,
- const Tensor& t) {
- Array<PrimExpr> args;
- for (IterVar iv : op->axis) {
- args.push_back(iv->var);
- }
- return ProvideNode::make(t->op, t->value_index, op->body[t->value_index], args);
-}
-
-Stmt MakeComputeStmt(const ComputeOpNode* self,
- const Stage& stage,
- const std::unordered_map<IterVar, Range>& dom_map,
- bool debug_keep_trivial_loop) {
- // grab the nest structure
- ComputeLoopNest n = ComputeLoopNest::make(self, stage, dom_map, debug_keep_trivial_loop);
- // Normal loop structure
- n.init_nest.emplace_back(op::MakeIfNest(n.init_predicates));
- n.main_nest.emplace_back(op::MakeIfNest(n.main_predicates));
- if (self->reduce_axis.size() != 0) {
- // make reduction.
- Stmt init, provide;
- Array<Tensor> source;
- for (size_t i = 0; i < self->body.size(); ++i) {
- source.push_back(stage->op.output(i));
- }
- MakeReduction(self, source, &init, &provide);
- init = MergeNest(n.init_nest, init);
- init = op::Substitute(init, n.init_vmap);
- // common nest
- std::vector<std::vector<Stmt> > common(
- n.main_nest.begin(), n.main_nest.begin() + n.num_common_loop + 1);
- std::vector<std::vector<Stmt> > reduce(
- n.main_nest.begin() + n.num_common_loop + 1, n.main_nest.end());
- provide = MergeNest(reduce, provide);
- if (debug_keep_trivial_loop) {
- provide = MergeNest(common, provide);
- } else {
- provide = MergeNest(common, SeqStmt::Flatten(init, provide));
- }
- // run substitution in the on the full nest, because loop condition
- // could depend on outer loops.
- return op::Substitute(provide, n.main_vmap);
- } else {
- std::vector<Stmt> provides;
- for (size_t i = 0; i < self->body.size(); ++i) {
- provides.emplace_back(MakeProvide(self, stage->op.output(i)));
- }
- Stmt provide = SeqStmt::Flatten(provides);
- provide = MergeNest(n.main_nest, provide);
- // run substitution in the on the full nest, because loop condition
- // could depend on outer loops.
- return op::Substitute(provide, n.main_vmap);
- }
-}
-
-enum class ComputeType {
- kNormal,
- kCrossThreadReduction,
- kTensorize
-};
-
-ComputeType DetectComputeType(const ComputeOpNode* self,
- const Stage& stage) {
- // Verify correctness of leaf nest.
- int normal_red = 0, thread_red = 0, tensorize = 0;
-
- for (IterVar iv : stage->leaf_iter_vars) {
- IterVarAttr attr;
- auto it = stage->iter_var_attrs.find(iv);
- if (it != stage->iter_var_attrs.end()) {
- attr = (*it).second;
- }
- if (attr.defined() && attr->iter_type == kTensorized) {
- ++tensorize;
- }
- if (iv->iter_type == kCommReduce) {
- if (attr.defined() && attr->bind_thread.defined()) {
- ++thread_red;
- } else {
- ++normal_red;
- }
- } else {
- CHECK_EQ(thread_red, 0)
- << "Cross thread reduce cannot swap with normal data axis";
- }
- }
- if (tensorize != 0) {
- CHECK(thread_red == 0)
- << "Cannot mix cross thread reduction with Tensorize";
- return ComputeType::kTensorize;
- }
- CHECK(normal_red == 0 || thread_red == 0)
- << "Cannot mix normal reduction with thread reduce";
- if (thread_red != 0) {
- return ComputeType::kCrossThreadReduction;
- } else {
- return ComputeType::kNormal;
- }
-}
-
-// implement the provide utility.
-Stmt ComputeOpNode::BuildProvide(
- const Stage& stage,
- const std::unordered_map<IterVar, Range>& dom_map,
- bool debug_keep_trivial_loop) const {
- CHECK_EQ(stage->op.operator->(), this);
- ComputeType ctype = DetectComputeType(this, stage);
- if (ctype == ComputeType::kCrossThreadReduction) {
- // specially handle cross thread reduction.
- return MakeCrossThreadReduction(this, stage, dom_map, debug_keep_trivial_loop);
- } else if (ctype == ComputeType::kTensorize) {
- return MakeTensorize(this, stage, dom_map, debug_keep_trivial_loop);
- } else {
- return MakeComputeStmt(this, stage, dom_map, debug_keep_trivial_loop);
- }
-}
-
-ComputeLoopNest ComputeLoopNest::make(
- const BaseComputeOpNode* self,
- const Stage& stage,
- const std::unordered_map<IterVar, Range>& dom_map,
- bool debug_keep_trivial_loop) {
- CHECK_EQ(stage->op.operator->(), self);
- ComputeLoopNest ret;
- // make main loop nest
- ret.main_nest = op::MakeLoopNest(
- stage, dom_map, 0, false, std::unordered_set<IterVar>(), &ret.main_vmap,
- debug_keep_trivial_loop);
- ret.main_predicates = schedule::MakeBoundCheck(
- stage, dom_map, ret.main_vmap, false,
- std::unordered_set<IterVar>());
- for (auto& e : ret.main_predicates) {
- e = likely(e);
- }
- if (stage->store_predicate.defined()) {
- ret.main_predicates.push_back(stage->store_predicate);
- }
- if (self->reduce_axis.size() != 0) {
- // try to find the location to insert the initialization.
- // Fuse the initialization and provide loop when possible.
- std::unordered_map<IterVar, int> update_state;
- for (IterVar iv : self->reduce_axis) {
- update_state[iv] = 2;
- }
- for (size_t i = 0; i < self->num_schedulable_dims(); ++i) {
- update_state[self->axis[i]] = 1;
- }
- // find which iter var is related to reduction and which is related to axis.
- schedule::PassDownBitMaskOr(stage, &update_state);
- auto leaf_iter_vars = stage->leaf_iter_vars;
- // first first loop that is related to reduction.
- size_t begin_loop = leaf_iter_vars.size();
- for (size_t i = 0; i < leaf_iter_vars.size(); ++i) {
- auto iv = leaf_iter_vars[i];
- int flag = update_state.at(iv);
- if ((flag & 2) != 0) {
- begin_loop = i; break;
- }
- ret.init_vmap[iv] = ret.main_vmap.at(iv);
- }
- ret.num_common_loop = begin_loop;
- // skip loops that are related to reduction and are unrelated to axis.
- std::unordered_set<IterVar> skip_iter;
- for (auto kv : update_state) {
- int flag = kv.second;
- if (flag == 2) skip_iter.insert(kv.first);
- }
- ret.init_nest = op::MakeLoopNest(
- stage, dom_map, begin_loop, true,
- skip_iter, &(ret.init_vmap), debug_keep_trivial_loop);
- ret.init_predicates = schedule::MakeBoundCheck(
- stage, dom_map, ret.init_vmap, true, skip_iter);
- for (auto& e : ret.init_predicates) {
- e = likely(e);
- }
- } else {
- CHECK_EQ(ret.main_nest.size(), stage->leaf_iter_vars.size() + 1);
- ret.num_common_loop = stage->leaf_iter_vars.size();
- }
- // copy elison here.
- return ret;
-}
-
-namespace {
-/*!
- * \brief Verify if ComputeOp is valid with respect to Reduce operations.
- *
- * The following two properties are verified:
- * (1) All Reduce operations must exist at top level.
- * (2) For a list of operations, if one is Reduce, then the others
- * must be Reduce as well; and their inputs should have the
- * same attribute except value_index.
- */
-class ComputeVerifier final : protected ir::ExprVisitor {
- public:
- /// Special member functions
- //@{
- explicit ComputeVerifier(const ComputeOpNode* compute)
- : compute_(compute), reduce_(compute->body[0].as<ir::ReduceNode>()) {}
- virtual ~ComputeVerifier() = default;
- ComputeVerifier(const ComputeVerifier&) = delete;
- ComputeVerifier(ComputeVerifier&&) = delete;
- ComputeVerifier& operator=(const ComputeVerifier&) = delete;
- ComputeVerifier& operator=(ComputeVerifier&&) = delete;
- //@}
-
- /// Interface to perform compute verification
- void Run() {
- for (const PrimExpr e : compute_->body) {
- // Check for consistency of top level reductions
- const ir::ReduceNode* reduce = e.as<ir::ReduceNode>();
- CHECK((reduce && reduce_) || (!reduce && !reduce_))
- << "All ComputeOp should be consistent "
- << "with being Reduce operation or not.";
-
- if (reduce && reduce_) {
- CHECK(ReduceEqual(reduce, reduce_))
- << "The Reduce inputs of ComputeOp should "
- << "have the same attribute except value_index";
- }
-
- level_ = 0;
- ExprVisitor::VisitExpr(e);
- }
- }
-
- protected:
- /// Visitor implementation
- //@{
- void VisitExpr(const PrimExpr& n) final {
- ++level_;
- ExprVisitor::VisitExpr(n);
- --level_;
- }
-
- void VisitExpr_(const ir::ReduceNode* op) final {
- // Check for non top level reductions
- CHECK(0 == level_)
- << "Reductions are only allowed at the top level of compute. "
- << "Please create another tensor for further composition.";
- }
- //@}
-
- private:
- const ComputeOpNode* compute_{nullptr}; ///< ComputeOpNode to verify
- const ir::ReduceNode* reduce_{nullptr}; ///< Top level Reduce operation
- int level_{0}; ///< Level of op being processed
-};
-} // namespace
-
-/// Verify if ComputeOp is valid with respect to Reduce operations.
-static void VerifyComputeOp(const ComputeOpNode* op) {
- ComputeVerifier v(op);
- v.Run();
-}
-
-Stmt TransformUpdate(const Stage& stage,
- const std::unordered_map<IterVar, Range>& dom_map,
- const ComputeLoopNest& n,
- Stmt body,
- Stmt update) {
- Array<PrimExpr> conds;
- std::unordered_set<const VarNode*> banned;
- for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) {
- IterVar iv = stage->leaf_iter_vars[i];
- auto iit = stage->iter_var_attrs.find(iv);
- if (iit != stage->iter_var_attrs.end()) {
- const IterVarAttr& attr = (*iit).second;
- if (attr->iter_type == kTensorized) {
- break;
- }
- }
- if (iv->iter_type == kCommReduce) {
- auto vit = dom_map.find(iv);
- CHECK(vit != dom_map.end());
- const Range& vrange = vit->second;
- conds.push_back(likely(iv->var > vrange->min));
- banned.insert(iv->var.get());
- }
- }
- for (const PrimExpr& pred : n.main_predicates) {
- if (ir::ExprUseVar(pred, banned)) {
- LOG(FATAL) << "Tensorize update transform failed, the condition "
- << pred << " has a conflict with the reset condition";
- }
- }
-
- return IfThenElseNode::make(arith::ComputeReduce<ir::OrNode>(conds, const_true(1)),
- update, body);
-}
-} // namespace tvm
+++ /dev/null
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \brief Helper utilities to implement compute_op.
- * \file compute_op.h
- */
-#ifndef TVM_OP_COMPUTE_OP_H_
-#define TVM_OP_COMPUTE_OP_H_
-
-#include <tvm/ir.h>
-#include <tvm/expr.h>
-#include <tvm/operation.h>
-#include <vector>
-#include <unordered_map>
-
-namespace tvm {
-// loop nest structure for general compute
-// This the loop nest structured used in compute.
-// Does not include the loop body.
-struct ComputeLoopNest {
- // The common number of loops between init and main
- size_t num_common_loop;
- // predicates for the initialize loop
- std::vector<PrimExpr> init_predicates;
- // Initialization nest involved.
- std::vector<std::vector<Stmt> > init_nest;
- // Value map for the init code
- std::unordered_map<IterVar, PrimExpr> init_vmap;
- // Predicates for the main update loop
- std::vector<PrimExpr> main_predicates;
- // The general loop nest
- std::vector<std::vector<Stmt> > main_nest;
- // Value map for the IterVar.
- std::unordered_map<IterVar, PrimExpr> main_vmap;
-
- /*!
- * \brief constructor to build ComputeOpNest
- * \param self The pointer to compute op.
- * \param stage The scxhedule stage.
- * \param dom_map The domain map.
- * \param debug_keep_trivial_loop Whether keep trivial loops with extent of 1
- * \return The constructed loop nest
- */
- static ComputeLoopNest make(
- const BaseComputeOpNode* self,
- const Stage& stage,
- const std::unordered_map<IterVar, Range>& dom_map,
- bool debug_keep_trivial_loop);
-};
-
-/*!
- * \brief Build body of compute for cross thread reduction pattern.
- * \param self The pointer to ComputeOpNode
- * \param stage The schedule stage.
- * \param dom_map The domain map.
- * \param debug_keep_trivial_loop Whether keep trivial loops with extent of 1
- * \return The created statement.
- */
-Stmt MakeCrossThreadReduction(
- const ComputeOpNode* self,
- const Stage& stage,
- const std::unordered_map<IterVar, Range>& dom_map,
- bool debug_keep_trivial_loop);
-
-/*!
- * \brief Build body of compute for tensorization.
- * \param self The pointer to ComputeOpNode
- * \param stage The schedule stage.
- * \param dom_map The domain map.
- * \param debug_keep_trivial_loop Whether keep trivial loops with extent of 1
- * \return The created statement.
- */
-Stmt MakeTensorize(const ComputeOpNode* self,
- const Stage& stage,
- const std::unordered_map<IterVar, Range>& dom_map,
- bool debug_keep_trivial_loop);
-
-/*!
- * \brief Transform the update part when there is no init func in tensorizing
- * \param stage The stage for tensorizing.
- * \param dom_map The range of each iter var.
- * \param n The loop nest structured used in compute.
- * \param body The body func in tensorize intrin
- * \param update The update func in tensorize intrin
- * \return Transformed result.
- */
-Stmt TransformUpdate(const Stage& stage,
- const std::unordered_map<IterVar, Range>& dom_map,
- const ComputeLoopNest& n,
- Stmt body,
- Stmt update);
-} // namespace tvm
-
-#endif // TVM_OP_COMPUTE_OP_H_
+++ /dev/null
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \brief Logics related to cross thread reduction, used by ComputeOpNode.
- * \file cross_thread_reduction.cc
- */
-#include <tvm/ir_pass.h>
-#include "compute_op.h"
-#include "op_util.h"
-
-namespace tvm {
-using namespace ir;
-
-Stmt MakeCrossThreadReduction(
- const ComputeOpNode* self,
- const Stage& stage,
- const std::unordered_map<IterVar, Range>& dom_map,
- bool debug_keep_trivial_loop) {
- Array<PrimExpr> args;
- for (IterVar iv : self->axis) {
- args.push_back(iv->var);
- }
- std::unordered_map<IterVar, PrimExpr> value_map;
- auto nest = op::MakeLoopNest(
- stage, dom_map, 0, false, std::unordered_set<IterVar>(), &value_map, debug_keep_trivial_loop);
- auto conds = schedule::MakeBoundCheck(
- stage, dom_map, value_map, false,
- std::unordered_set<IterVar>());
-
- size_t size = self->body.size();
- CHECK_GT(size, 0);
- std::vector<const ReduceNode*> reduces(size);
- for (size_t i = 0; i < size; ++i) {
- const ReduceNode* reduce = self->body[i].as<ReduceNode>();
- CHECK(reduce);
- reduces[i] = reduce;
- }
- PrimExpr cond = reduces[0]->condition;
- for (PrimExpr v : conds) {
- cond = cond && v;
- }
- Array<PrimExpr> freduce_args;
- freduce_args.push_back(make_const(DataType::UInt(32), static_cast<uint32_t>(size)));
- for (size_t i = 0; i < size; ++i) {
- freduce_args.push_back(reduces[0]->source[i]);
- }
- freduce_args.push_back(cond);
- std::vector<Var> res_handles(size);
- for (size_t idx = 0; idx < size; ++idx) {
- res_handles[idx] = Var("reduce_temp" + std::to_string(idx), DataType::Handle());
- freduce_args.push_back(res_handles[idx]);
- }
-
- for (IterVar iv : stage->leaf_iter_vars) {
- if (iv->iter_type == kCommReduce) {
- auto it = stage->iter_var_attrs.find(iv);
- if (it != stage->iter_var_attrs.end() &&
- (*it).second->bind_thread.defined()) {
- IterVar tv = (*it).second->bind_thread;
- freduce_args.push_back(tv->var);
- }
- }
- }
- // Checks for the thread.
- std::vector<PrimExpr> thread_head_check;
- if (stage->store_predicate.defined()) {
- thread_head_check.emplace_back(stage->store_predicate);
- }
-
- Stmt reduce_body = EvaluateNode::make(CallNode::make(
- DataType::Handle(),
- ir::intrinsic::tvm_thread_allreduce,
- freduce_args, CallNode::Intrinsic));
- reduce_body = AttrStmtNode::make(
- reduces[0]->combiner,
- attr::reduce_scope,
- make_zero(DataType::Handle()),
- reduce_body);
- std::vector<Stmt> assigns(size);
- for (size_t idx = 0; idx < size; ++idx) {
- DataType t = reduces[idx]->dtype;
- assigns[idx] = ProvideNode::make(
- stage->op, idx,
- LoadNode::make(t, res_handles[idx], 0, const_true(t.lanes())), args);
- }
- Stmt assign_body = SeqStmt::Flatten(assigns);
- assign_body = MergeNest(op::MakeIfNest(thread_head_check), assign_body);
- assign_body = MergeNest(op::MakeIfNest(conds), assign_body);
- Stmt body = SeqStmt::Flatten(reduce_body, assign_body);
- for (size_t idx = size; idx != 0; --idx) {
- body = AllocateNode::make(
- res_handles[idx - 1], reduces[idx - 1]->dtype, {1}, const_true(), body);
- body = AttrStmtNode::make(
- res_handles[idx - 1], attr::storage_scope, StringImmNode::make("local"), body);
- }
- body = op::Substitute(body, value_map);
- return MergeNest(nest, body);
-}
-} // namespace tvm
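
For context on what this file lowers: when a reduction axis is bound to a GPU thread, MakeCrossThreadReduction replaces the serial reduction loop with a tvm_thread_allreduce intrinsic. A minimal sketch from the Python frontend of the same era (pre-refactor tvm.* API; the schedule below is illustrative and not taken from this diff):

    import tvm

    n = tvm.var("n")
    A = tvm.placeholder((n,), name="A")
    k = tvm.reduce_axis((0, n), name="k")
    B = tvm.compute((1,), lambda i: tvm.sum(A[k], axis=k), name="B")

    s = tvm.create_schedule(B.op)
    ko, ki = s[B].split(B.op.reduce_axis[0], factor=32)
    # Binding the inner reduce axis to a thread is what triggers
    # MakeCrossThreadReduction during lowering.
    s[B].bind(ki, tvm.thread_axis("threadIdx.x"))
    print(tvm.lower(s, [A, B], simple_mode=True))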
+++ /dev/null
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \brief External computation rule.
- * \file extern_op.cc
- */
-#include <tvm/operation.h>
-#include <tvm/arith/analyzer.h>
-#include <tvm/ir.h>
-#include <unordered_set>
-#include "op_util.h"
-
-namespace tvm {
-using namespace ir;
-// ExternOpNode
-TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<ExternOpNode>([](const ObjectRef& node, NodePrinter* p) {
- auto* op = static_cast<const ExternOpNode*>(node.get());
- p->stream << "extern(" << op->name << ", " << op << ")";
- });
-
-TVM_REGISTER_NODE_TYPE(ExternOpNode);
-
-int ExternOpNode::num_outputs() const {
- return static_cast<int>(output_placeholders.size());
-}
-
-Array<IterVar> ExternOpNode::root_iter_vars() const {
- return {};
-}
-
-DataType ExternOpNode::output_dtype(size_t i) const {
- return output_placeholders[i]->dtype;
-}
-
-Array<PrimExpr> ExternOpNode::output_shape(size_t i) const {
- return output_placeholders[i]->shape;
-}
-
-
-Operation ExternOpNode::make(std::string name,
- std::string tag,
- Map<std::string, ObjectRef> attrs,
- Array<Tensor> inputs,
- Array<Buffer> input_placeholders,
- Array<Buffer> output_placeholders,
- Stmt body) {
- if (!attrs.defined()) {
- attrs = Map<std::string, ObjectRef>();
- }
- auto n = make_object<ExternOpNode>();
- n->name = std::move(name);
- n->tag = std::move(tag);
- n->attrs = std::move(attrs);
- CHECK_EQ(inputs.size(), input_placeholders.size());
- for (size_t i = 0; i < inputs.size(); ++i) {
- CHECK_EQ(inputs[i]->dtype, input_placeholders[i]->dtype);
- CHECK_EQ(inputs[i]->shape.size(), input_placeholders[i]->shape.size());
- for (size_t dim = 0; dim < inputs[i]->shape.size(); ++dim) {
- CHECK(inputs[i]->shape[dim].same_as(input_placeholders[i]->shape[dim]));
- }
- CHECK_EQ(input_placeholders[i]->strides.size(), 0U);
- }
- n->inputs = std::move(inputs);
- n->input_placeholders = std::move(input_placeholders);
- n->output_placeholders = std::move(output_placeholders);
- n->body = std::move(body);
- return Operation(n);
-}
-
-Array<Tensor> ExternOpNode::InputTensors() const {
- return inputs;
-}
-
-Operation ExternOpNode::ReplaceInputs(
- const Operation& self,
- const std::unordered_map<Tensor, Tensor>& rmap) const {
- CHECK_EQ(self.operator->(), this);
- auto n = make_object<ExternOpNode>(*this);
- n->body = op::ReplaceTensor(this->body, rmap);
- for (size_t i = 0; i < n->inputs.size(); ++i) {
- Tensor t = n->inputs[i];
- if (rmap.count(t)) {
- n->inputs.Set(i, rmap.at(t));
- }
- }
-
- if (body.same_as(n->body) &&
- inputs.same_as(n->inputs)) {
- return self;
- } else {
- return Operation(n);
- }
-}
-
-void ExternOpNode::PropBoundToInputs(
- const Operation& self,
- arith::Analyzer* analyzer,
- const std::unordered_map<const VarNode*, IntSet>& dom_map,
- std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
- for (Tensor t : this->inputs) {
- auto it = out_dom_map->find(t);
- if (it == out_dom_map->end()) continue;
- TensorDom& dom = it->second;
- for (size_t i = 0; i < t->shape.size(); ++i) {
- dom.data[i].emplace_back(IntSet::range(
- Range::make_by_min_extent(
- make_const(t->shape[i].dtype(), 0), t->shape[i])));
- }
- }
-}
-
-void ExternOpNode::GatherBound(
- const Operation& self,
- const std::unordered_map<Tensor, TensorDom>& tensor_dom,
- std::unordered_map<IterVar, Range>* out_dom_map) const {
-}
-
-Stmt ExternOpNode::BuildRealize(
- const Stage& stage,
- const std::unordered_map<IterVar, Range>& realize_map,
- const Stmt& body) const {
- CHECK_EQ(stage->op.get(), this);
- Stmt realize_body = body;
- for (int k = 0; k < num_outputs(); ++k) {
- Tensor t = stage->op.output(k);
- Region bounds;
- for (size_t i = 0; i < t->shape.size(); ++i) {
- bounds.push_back(
- Range::make_by_min_extent(
- make_const(t->shape[i].dtype(), 0), t->shape[i]));
- }
- realize_body = ir::RealizeNode::make(
- t->op, t->value_index, t->dtype,
- bounds, const_true(), realize_body);
- }
- return realize_body;
-}
-
-Stmt ExternOpNode::BuildProvide(
- const Stage& stage,
- const std::unordered_map<IterVar, Range>& dom_map,
- bool debug_keep_trivial_loop) const {
- CHECK_EQ(stage->op.operator->(), this);
- Stmt ret = AttrStmtNode::make(make_zero(DataType::Int(32)), attr::extern_scope, 0, this->body);
- auto f_push_bind = [&ret](Buffer buffer, Tensor tensor) {
- Array<ObjectRef> bind_spec;
- Array<PrimExpr> tuple;
- bind_spec.push_back(buffer);
- bind_spec.push_back(tensor);
- for (size_t k = 0; k < buffer->shape.size(); ++k) {
- tuple.push_back(make_const(buffer->shape[k].dtype(), 0));
- tuple.push_back(buffer->shape[k]);
- }
- ret = AttrStmtNode::make(
- bind_spec, attr::buffer_bind_scope,
- CallNode::make(DataType::Handle(), intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), ret);
- };
- for (size_t i = output_placeholders.size(); i != 0; --i) {
- f_push_bind(output_placeholders[i - 1], stage->op.output(i - 1));
- }
- for (size_t i = inputs.size(); i != 0; --i) {
- f_push_bind(input_placeholders[i - 1], inputs[i - 1]);
- }
- return ret;
-}
-} // namespace tvm
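
ExternOpNode is what backs tvm.extern in the Python frontend; the buffer_bind_scope attributes emitted by BuildProvide above bind the input and output buffers that the user callback receives. A rough sketch, where the packed function name "my_extern_fn" is purely an illustrative placeholder:

    import tvm

    n = tvm.var("n")
    A = tvm.placeholder((n,), name="A")
    # ins/outs are the already-bound buffers; the returned call becomes the
    # ExternOpNode body wrapped in buffer_bind_scope attributes.
    B = tvm.extern((n,), [A],
                   lambda ins, outs: tvm.call_packed("my_extern_fn", ins[0], outs[0]),
                   name="B", dtype="float32")

    s = tvm.create_schedule(B.op)
    print(tvm.lower(s, [A, B], simple_mode=True))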
+++ /dev/null
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \brief Hybrid computation rule.
- * \file hybrid_op.cc
- */
-#include <tvm/operation.h>
-#include <tvm/arith/analyzer.h>
-#include <tvm/ir.h>
-#include <tvm/ir_functor_ext.h>
-#include <tvm/ir_pass.h>
-#include <tvm/expr_operator.h>
-#include <unordered_set>
-#include <string>
-#include <utility>
-#include "op_util.h"
-#include "hybrid_op.h"
-
-namespace tvm {
-using namespace ir;
-// HybridOpNode
-TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<HybridOpNode>([](const ObjectRef& node, NodePrinter* p) {
- auto* op = static_cast<const HybridOpNode*>(node.get());
- p->stream << "hybrid(" << op->name << ", " << op << ")";
- });
-
-TVM_REGISTER_NODE_TYPE(HybridOpNode);
-
-int HybridOpNode::num_outputs() const {
- return static_cast<int>(outputs.size());
-}
-
-Array<IterVar> HybridOpNode::root_iter_vars() const {
- return this->axis;
-}
-
-DataType HybridOpNode::output_dtype(size_t i) const {
- return outputs[i]->dtype;
-}
-
-Array<PrimExpr> HybridOpNode::output_shape(size_t i) const {
- return outputs[i]->shape;
-}
-
-
-Operation HybridOpNode::make(std::string name,
- std::string tag,
- Map<std::string, ObjectRef> attrs,
- Array<Tensor> inputs,
- Array<Tensor> outputs,
- Stmt body) {
- if (!attrs.defined()) {
- attrs = Map<std::string, ObjectRef>();
- }
- auto n = make_object<HybridOpNode>();
- n->name = std::move(name);
- n->tag = std::move(tag);
- n->attrs = std::move(attrs);
- n->inputs = std::move(inputs);
- n->outputs = std::move(outputs);
- n->axis = op::GatherLoopVars(body);
- n->body = std::move(body);
- Operation res = Operation(n);
- return res;
-}
-
-Array<Tensor> HybridOpNode::InputTensors() const {
- // Because input tensors could potentially be inlined into hybrid scripts,
- // we need to check which of the input tensors are actually used in the body.
- std::unordered_set<Tensor> orig_inputs;
- for (auto t : inputs) {
- orig_inputs.insert(t);
- }
- std::unordered_set<Tensor> visited;
- Array<Tensor> curr_inputs;
- ir::PostOrderVisit(body, [&curr_inputs, &orig_inputs, &visited](const ObjectRef& n) {
- const ir::CallNode *call = n.as<ir::CallNode>();
- if (call != nullptr && call->func.defined()) {
- Tensor t = Downcast<Operation>(call->func).output(call->value_index);
- if (orig_inputs.count(t) && !visited.count(t)) {
- curr_inputs.push_back(t);
- visited.insert(t);
- }
- }
- });
- return curr_inputs;
-}
-
-Operation HybridOpNode::ReplaceInputs(
- const Operation &self,
- const std::unordered_map<Tensor, Tensor> &rmap) const {
- CHECK_EQ(self.operator->(), this);
- auto n = make_object<HybridOpNode>(*this);
- n->body = op::ReplaceTensor(this->body, rmap);
- for (size_t i = 0; i < n->inputs.size(); ++i) {
- Tensor t = n->inputs[i];
- if (rmap.count(t)) {
- n->inputs.Set(i, rmap.at(t));
- }
- }
-
- if (body.same_as(n->body) &&
- inputs.same_as(n->inputs)) {
- return self;
- } else {
- return Operation(n);
- }
-}
-
-void HybridOpNode::PropBoundToInputs(
- const Operation &self,
- arith::Analyzer* analyzer,
- const std::unordered_map<const VarNode*, IntSet> &dom_map,
- std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
- auto curr_inputs = InputTensors();
- for (Tensor t : curr_inputs) {
- auto it = out_dom_map->find(t);
- if (it == out_dom_map->end()) continue;
- TensorDom &dom = it->second;
- for (size_t i = 0; i < t->shape.size(); ++i) {
- dom.data[i].emplace_back(IntSet::range(
- Range::make_by_min_extent(
- make_const(t->shape[i].dtype(), 0), t->shape[i])));
- }
- }
-}
-
-void HybridOpNode::GatherBound(
- const Operation &self,
- const std::unordered_map<Tensor, TensorDom> &tensor_dom,
- std::unordered_map<IterVar, Range>* out_dom_map) const {
- for (auto iter_var : axis) {
- CHECK(!out_dom_map->count(iter_var));
- out_dom_map->operator[](iter_var) = iter_var->dom;
- }
-}
-
-Stmt HybridOpNode::BuildRealize(
- const Stage &stage,
- const std::unordered_map<IterVar, Range> &realize_map,
- const Stmt &body) const {
- // TODO(@were): Add attribute inject here and remove it from hybrid parser.
- CHECK_EQ(stage->op.get(), this);
- Stmt realize_body = body;
- for (int k = 0; k < num_outputs(); ++k) {
- Tensor t = stage->op.output(k);
- Region bounds;
- for (size_t i = 0; i < t->shape.size(); ++i) {
- bounds.push_back(
- Range::make_by_min_extent(
- make_const(t->shape[i].dtype(), 0), t->shape[i]));
- }
- realize_body = ir::RealizeNode::make(
- t->op, t->value_index, t->dtype,
- bounds, const_true(), realize_body);
- }
- return realize_body;
-}
-
-Stmt HybridOpNode::BuildProvide(
- const Stage &stage,
- const std::unordered_map<IterVar, Range> &dom_map,
- bool debug_keep_trivial_loop) const {
- CHECK_EQ(stage->op.operator->(), this);
- Stmt ret = AttrStmtNode::make(make_zero(DataType::Int(32)), attr::extern_scope, 0, this->body);
- std::unordered_map<Tensor, Tensor> rmap;
- for (int i = 0; i < this->num_outputs(); ++i) {
- rmap[outputs[i]] = stage->op.output(i);
- }
- auto n = make_object<HybridOpNode>(*this);
- /* This story is a little bit complicated.
- * The following two lines of code replace the output tensors' usage.
- * This is the simplest way I (@were) can come up with to glue
- * the hybrid operation node to the TVM op system.
- * In a hybrid script all the tensors, especially the output tensors,
- * have their own names defined by the users. However, in TVM's
- * conventional ops:
- * 1. Output tensors refer to the corresponding op node, so the output
- *    tensors carry the names of the operation that produces them.
- * 2. Once an OpNode is wrapped up by an Operation node, it is finalized.
- *    Later access will be through a const OpNode*.
- * This is a chicken-and-egg paradox: it is impossible to put the output
- * tensors into the function body without forming the op node, and the
- * function body is immutable once the node is formed.
- *
- * Finally, I decided to resolve this issue "lazily". In the compilation
- * pipeline, this stage is very preliminary; technically, it is before
- * Phase 0. The actual tensors are replaced here.
- * Thus, the operation body is slightly different from the Phase 0 body.
- * This is a major way in which HybridOpNode is NOT the same as
- * ExternOpNode.
- * */
- ret = op::ReplaceTensor(ret, rmap);
- ret = op::ReplaceProvideTensor(ret, rmap);
-
- ret = op::ApplySchedule(stage, dom_map, ret);
- return ret;
-}
-
-namespace op {
-
-
-Stmt ApplyLoopShapes(const Stage &stage,
- const std::unordered_map<IterVar, Range> &dom_map, Stmt stmt) {
- class LoopSpliter : public StmtExprMutator {
- PrimExpr factor;
- const VarNode *parent;
- IterVar inner, outer;
-
- public:
- bool splitted;
- LoopSpliter(const SplitNode *split,
- const std::unordered_map<IterVar, Range> &dom_map) :
- factor(split->factor), splitted(false) {
- parent = split->parent->var.get();
-
- auto &inner_ = split->inner;
- CHECK(dom_map.count(inner_));
- auto &inner_dom = dom_map.find(inner_)->second;
- CHECK(is_const_int(inner_dom->min, 0));
-
- auto &outer_ = split->outer;
- CHECK(dom_map.count(outer_));
- auto &outer_dom = dom_map.find(outer_)->second;
- CHECK(is_const_int(outer_dom->min, 0));
-
- inner = IterVarNode::make(inner_dom, inner_->var, inner_->iter_type);
- outer = IterVarNode::make(outer_dom, outer_->var, outer_->iter_type);
- }
-
- Stmt VisitStmt_(const ForNode *op) final {
- if (op->loop_var.get() == parent) {
- std::unordered_map<const VarNode *, PrimExpr> rmap;
- rmap[op->loop_var.get()] = inner + outer * factor;
- Stmt ret = ir::Substitute(op->body, rmap);
- PrimExpr cond = likely(outer * factor < (op->extent - inner));
- ret = IfThenElseNode::make(cond, ret);
- ret = ForNode::make(inner->var, PrimExpr(0), inner->dom->extent,
- IterVarTypeToForType(inner->iter_type), op->device_api, ret);
- ret = ForNode::make(outer->var, PrimExpr(0), outer->dom->extent,
- IterVarTypeToForType(outer->iter_type), op->device_api, ret);
- splitted = true;
- return ret;
- }
- return StmtExprMutator::VisitStmt_(op);
- }
- };
-
- class LoopFuser : public StmtExprMutator {
- const IterVar &parent;
- const VarNode *inner;
- const VarNode *outer;
- bool under_outer;
- PrimExpr extent;
-
- public:
- bool fused;
- explicit LoopFuser(const FuseNode *fuse_)
- : parent(fuse_->fused), inner(fuse_->inner->var.get()),
- outer(fuse_->outer->var.get()), under_outer(false),
- extent(0), fused(false) {}
-
- // TODO(@were): Handle imperfect loops
- Stmt VisitStmt_(const ForNode* op) final {
- if (op->loop_var.get() == inner) {
- CHECK(under_outer);
- std::unordered_map<const VarNode *, PrimExpr> rmap;
- rmap[op->loop_var.get()] = indexmod(parent, op->extent);
- extent = op->extent;
- fused = true;
- return ir::Substitute(op->body, rmap);
- } else if (op->loop_var.get() == outer) {
- under_outer = true;
- Stmt body = this->VisitStmt(op->body);
- std::unordered_map<const VarNode *, PrimExpr> rmap;
- rmap[op->loop_var.get()] = indexdiv(parent, extent);
- body = ir::Substitute(body, rmap);
- under_outer = false;
- return ForNode::make(parent->var, PrimExpr(0), extent * op->extent,
- op->for_type, op->device_api, body);
- } else if (under_outer) {
- Stmt body = this->VisitStmt(op->body);
- std::unordered_map<const VarNode *, PrimExpr> rmap;
- rmap[op->loop_var.get()] = indexmod(indexdiv(parent, extent), op->extent);
- body = ir::Substitute(body, rmap);
- extent = extent * op->extent;
- return body;
- }
- return StmtExprMutator::VisitStmt_(op);
- }
- };
-
- for (auto &rel : stage->relations) {
- if (const SplitNode *split = rel.as<SplitNode>()) {
- LoopSpliter Spliter(split, dom_map);
- stmt = Spliter(stmt);
- CHECK(Spliter.splitted);
- } else if (const FuseNode *fuse = rel.as<FuseNode>()) {
- LoopFuser Fuser(fuse);
- stmt = Fuser(stmt);
- CHECK(Fuser.fused);
- }
- }
-
- return stmt;
-}
-
-Stmt ApplyLoopAnnotations(const Stage &stage,
- const std::unordered_map<IterVar, IterVar> &rebased, Stmt stmt) {
- class LoopAnnotator : public StmtMutator {
- const VarNode *var;
- const IterVarAttr &attr;
-
- public:
- LoopAnnotator(const VarNode *var_, const IterVarAttr &attr_) : var(var_), attr(attr_) {}
-
- Stmt VisitStmt_(const ForNode *op) final {
- if (op->loop_var.get() == var) {
- if (attr->bind_thread.defined()) {
- const auto &iter_var = attr->bind_thread;
- if (iter_var->dom.defined()) {
- CHECK(is_const_int(iter_var->dom->min, 0));
- CHECK(Equal(iter_var->dom->extent, op->extent))
- << "Thread extent and loop extent mismatch!\n";
- }
- std::unordered_map<const VarNode *, PrimExpr> rmap;
- rmap[op->loop_var.get()] = iter_var;
- Stmt body = ir::Substitute(op->body, rmap);
- return AttrStmtNode::make(iter_var, "thread_extent", op->extent, body);
- } else {
- return ForNode::make(op->loop_var, op->min, op->extent,
- IterVarTypeToForType(attr->iter_type), op->device_api, op->body);
- }
- }
- return StmtMutator::VisitStmt_(op);
- }
- };
-
- for (auto &iter_var : stage->leaf_iter_vars) {
- bool need_change = false;
- int found = 0;
-
- const IterVar &actual = rebased.count(iter_var) ? rebased.find(iter_var)->second : iter_var;
- const VarNode *var = actual->var.get();
- ForType expected = IterVarTypeToForType(iter_var->iter_type);
- IterVarAttr attr;
- if (stage->iter_var_attrs.count(iter_var)) {
- attr = stage->iter_var_attrs[iter_var];
- expected = IterVarTypeToForType(attr->iter_type);
- }
-
- PostOrderVisit(stmt,
- [&found, &var, &attr, &expected, &need_change](const ObjectRef& node) {
- if (const ForNode *op = node.as<ForNode>()) {
- if (op->loop_var.get() == var) {
- ++found;
- need_change = expected != op->for_type || (attr.defined() && attr->bind_thread.defined());
- }
- }
- });
-
- CHECK_EQ(found, 1) << " iter var should be found exactly once!";
- if (need_change) {
- stmt = LoopAnnotator(var, attr)(std::move(stmt));
- }
- }
- return stmt;
-}
-
-Stmt ApplyLoopOrder(const Stage &stage,
- const std::unordered_map<IterVar, Range> &dom_map,
- const std::unordered_map<IterVar, IterVar> &rebased, Stmt stmt) {
- std::vector<const VarNode*> current_order;
- PostOrderVisit(stmt, [&current_order](const ObjectRef& node) {
- if (const ForNode *op = node.as<ForNode>())
- current_order.push_back(op->loop_var.get());
- });
- std::reverse(current_order.begin(), current_order.end());
- auto &required_ord = stage->leaf_iter_vars;
- CHECK_EQ(current_order.size(), required_ord.size()) << "Cannot reorder the loops!";
- std::unordered_map<const VarNode *, IterVar> reorder;
- bool need_reorder = false;
- for (size_t i = 0; i < current_order.size(); ++i) {
- auto &current = current_order[i];
- const IterVar &iter_var = required_ord[i];
- const IterVar &required = rebased.count(iter_var) ? rebased.find(iter_var)->second : iter_var;
- CHECK(required->dom.defined() || dom_map.count(required)) << required << "\n";
- reorder[current] = required;
- if (current != required->var.get()) {
- need_reorder = true;
- }
- }
-
- class LoopReorder : public StmtMutator {
- const Stage &stage;
- const std::unordered_map<IterVar, Range> &dom_map;
- const std::unordered_map<const VarNode *, IterVar> &reorder;
-
- public:
- LoopReorder(const Stage &stage,
- const std::unordered_map<IterVar, Range> &dom_map,
- const std::unordered_map<const VarNode*, IterVar> &reorder)
- : stage(stage), dom_map(dom_map), reorder(reorder) {}
-
- Stmt VisitStmt_(const ForNode* op) final {
- // Reorder from inner to outer
- Stmt body_ = this->VisitStmt(op->body);
- CHECK(reorder.count(op->loop_var.get()));
- auto target = reorder.find(op->loop_var.get())->second;
- if (body_.same_as(op->body) && op->loop_var.get() == target->var.get())
- return GetRef<Stmt>(op);
- const Stmt &body = op->body.same_as(body_) ? op->body : body_;
- ForType for_type = IterVarTypeToForType(target->iter_type);
- if (stage->iter_var_attrs.count(target)) {
- for_type = IterVarTypeToForType(stage->iter_var_attrs[target]->iter_type);
- }
- const Range &range = target->dom.defined() ? target->dom : dom_map.find(target)->second;
- return ForNode::make(target->var, range->min, range->extent,
- for_type, DeviceAPI::None, body);
- }
- };
-
- if (need_reorder)
- return LoopReorder(stage, dom_map, reorder)(stmt);
-
- return stmt;
-}
-
-Stmt ApplySchedule(const Stage &stage,
- const std::unordered_map<IterVar, Range> &dom_map, Stmt stmt) {
- // TODO(@were): Eliminate loop rebase in script parser and move the burden here
- // Gather rebased variables
- std::unordered_map<IterVar, IterVar> rebased;
- for (auto rel : stage->relations) {
- if (const auto* rebase = rel.as<RebaseNode>()) {
- rebased[rebase->rebased] = rebase->parent;
- CHECK(rebase->parent->dom.defined());
- CHECK(dom_map.count(rebase->rebased));
- }
- }
- stmt = ApplyLoopShapes(stage, dom_map, stmt);
- stmt = ApplyLoopOrder(stage, dom_map, rebased, stmt);
- stmt = ApplyLoopAnnotations(stage, rebased, stmt);
- return stmt;
-}
-
-std::vector<IterVar> GatherLoopVars(Stmt stmt) {
- // TODO(@were): Write a comprehensive pass to analyze iter var types
- std::vector<IterVar> res_;
- PostOrderVisit(stmt, [&res_](const ObjectRef& node) {
- if (const ForNode *op = node.as<ForNode>()) {
- Var loop_var(op->loop_var);
- Range dom = Range::make_by_min_extent(op->min, op->extent);
- res_.push_back(IterVarNode::make(dom, loop_var, ForTypeToIterVarType(op->for_type)));
- }
- });
- std::reverse(res_.begin(), res_.end());
- return res_;
-}
-
-// replacer to replace tensors' usage in Provide
-class ProviderReplacer : public ir::StmtMutator {
- public:
- explicit ProviderReplacer(const std::unordered_map<Tensor, Tensor> &vmap)
- : vmap_(vmap) {}
-
- Stmt VisitStmt_(const ir::ProvideNode* op) final {
- Tensor t = Downcast<Operation>(op->func).output(op->value_index);
- auto it = vmap_.find(t);
- if (it != vmap_.end()) {
- Stmt ret = ir::ProvideNode::make(
- it->second->op, it->second->value_index, op->value, op->args);
- found = true;
- return this->VisitStmt(ret);
- }
- return StmtMutator::VisitStmt_(op);
- }
-
- // whether it is found.
- bool found{false};
-
- private:
- const std::unordered_map<Tensor, Tensor> &vmap_;
-};
-
-Stmt ReplaceProvideTensor(Stmt stmt,
- const std::unordered_map<Tensor, Tensor> &replace) {
- ProviderReplacer repl(replace);
- Stmt ret = repl(stmt);
- return repl.found ? ret : stmt;
-}
-} // namespace op
-} // namespace tvm
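
The lazy output-tensor replacement described in BuildProvide above is only observable from the frontend: a hybrid-script op is scheduled like any other, and ApplyLoopShapes/ApplyLoopOrder/ApplyLoopAnnotations replay those schedule primitives on the script body. A sketch of that flow, assuming the tvm.hybrid.script decorator of this era (the exact scheduling calls on hybrid ops are an assumption here):

    import tvm

    @tvm.hybrid.script
    def outer_product(a, b):
        c = output_tensor((64, 64), "float32")
        for i in range(64):
            for j in range(64):
                c[i, j] = a[i] * b[j]
        return c

    a = tvm.placeholder((64,), name="a", dtype="float32")
    b = tvm.placeholder((64,), name="b", dtype="float32")
    c = outer_product(a, b)

    s = tvm.create_schedule(c.op)
    i, j = c.op.axis
    # The split below is what ApplyLoopShapes later materializes on the body.
    io, ii = s[c].split(i, factor=4)
    print(tvm.lower(s, [a, b, c], simple_mode=True))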
+++ /dev/null
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \brief Helper utilities to implement hybrid_op.
- * \file hybrid_op.h
- */
-#ifndef TVM_OP_HYBRID_OP_H_
-#define TVM_OP_HYBRID_OP_H_
-
-#include <tvm/expr.h>
-#include <tvm/schedule.h>
-#include <unordered_map>
-#include <unordered_set>
-#include <vector>
-#include "../pass/ir_util.h"
-#include "../pass/arg_binder.h"
-#include "../schedule/message_passing.h"
-
-
-namespace tvm {
-namespace op {
-
-/*!
- * \brief Find all the iteration variables in the given statement body.
- * \param stmt The body to be inspected.
- */
-std::vector<IterVar> GatherLoopVars(Stmt stmt);
-
-/*!
- * \brief Replace the tensor reference (especially in Provide's) in stmt by the replace map.
- * \param stmt The statement to be processed.
- * \param replace The replacement rule.
- */
-Stmt ReplaceProvideTensor(Stmt stmt,
- const std::unordered_map<Tensor, Tensor>& replace);
-
-/*!
- * \brief Apply the schedule manipulation on the function body.
- * \param stmt The statement to be processed.
- * \param dom_map The extents of the iteration variables that may be used.
- * \param stage The schedule information to be applied.
- */
-Stmt ApplySchedule(const Stage& stage,
- const std::unordered_map<IterVar, Range>& dom_map, Stmt stmt);
-
-/*!
- * \brief Apply loop splits and fuses in the schedule on the function body.
- * \param stage The schedule information to be applied.
- * \param dom_map The extents of the iteration variables that may be used.
- * \param stmt The statement to be processed.
- */
-Stmt ApplyLoopShapes(const Stage &stage,
- const std::unordered_map<IterVar, Range>& dom_map, Stmt stmt);
-
-
-/*!
- * \brief Apply loop annotation in the schedule on the function body.
- * \param stage The schedule information to be applied.
- * \param rebased The map specifying the rebase (a.k.a. rename) relationship of these variables.
- * \param stmt The statement to be processed.
- */
-Stmt ApplyLoopAnnotations(const Stage &stage,
- const std::unordered_map<IterVar, IterVar>& rebased, Stmt stmt);
-
-/*!
- * \brief Apply loop order in the schedule on the function body.
- * \param stage The schedule information to be applied.
- * \param dom_map The extents of the iteration variables that may be used.
- * \param rebased The map specifying the rebase (a.k.a. rename) relationship of these variables.
- * \param stmt The statement to be processed.
- */
-Stmt ApplyLoopOrder(const Stage &stage,
- const std::unordered_map<IterVar, Range> &dom_map,
- const std::unordered_map<IterVar, IterVar> &rebased, Stmt stmt);
-
-} // namespace op
-} // namespace tvm
-
-#endif // TVM_OP_HYBRID_OP_H_
+++ /dev/null
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \brief Utility to make loop nest.
- * \file op_util.cc
- */
-#include <tvm/ir.h>
-#include <tvm/ir_pass.h>
-#include <tvm/ir_functor_ext.h>
-#include <tvm/operation.h>
-#include <string>
-#include "op_util.h"
-#include "../schedule/message_passing.h"
-#include "../arith/compute_expr.h"
-
-namespace tvm {
-namespace op {
-
-using namespace arith;
-using namespace ir;
-
-std::vector<std::vector<Stmt> >
-MakeLoopNest(const Stage& stage,
- const std::unordered_map<IterVar, Range>& dom_map,
- size_t begin_iter_pos,
- bool new_loop_var,
- const std::unordered_set<IterVar>& skip_iter,
- std::unordered_map<IterVar, PrimExpr>* p_value_map,
- bool debug_keep_trivial_loop) {
- auto leaf_iter_vars = stage->leaf_iter_vars;
- Stmt no_op = EvaluateNode::make(0);
- // create the loop nest
- std::vector<std::vector<Stmt> > nest;
- nest.resize(leaf_iter_vars.size() + 1);
- std::unordered_map<IterVar, PrimExpr>& value_map = *p_value_map;
-
- for (size_t i = begin_iter_pos; i < leaf_iter_vars.size(); ++i) {
- auto iv = leaf_iter_vars[i];
- if (skip_iter.count(iv) || iv->iter_type == kOpaque) {
- // skip this iteration.
- value_map[iv] = iv->var;
- continue;
- }
- // Bind iv could be another thread.
- IterVar bind_iv = iv;
- if (stage->iter_var_attrs.count(iv)) {
- IterVar bind_thread = stage->iter_var_attrs[iv]->bind_thread;
- if (bind_thread.defined()) bind_iv = bind_thread;
- }
-
- Range dom = dom_map.at(iv);
-
- // initialize the offset and loop_level
- Var var = bind_iv->var;
-
- // Mark the iter var in the IR, to remember the point
- if (bind_iv->thread_tag.length() == 0) {
- // Only generate new loop if we're not bound to a thread.
- if (new_loop_var) {
- var = Var(iv->var->name_hint + ".init", bind_iv->var.dtype());
- }
-
- ForType for_type = ForType::Serial;
- IterVarAttr it_attr;
- if (stage->iter_var_attrs.count(iv)) {
- it_attr = stage->iter_var_attrs[iv];
- }
- if (it_attr.defined()) {
- switch (it_attr->iter_type) {
- case kUnrolled: for_type = ForType::Unrolled; break;
- case kVectorized: for_type = ForType::Vectorized; break;
- case kParallelized: for_type = ForType::Parallel; break;
- case kDataPar: break;
- case kTensorized: break;
- default: LOG(FATAL) << "Unknown iter type"
- << it_attr->iter_type
- << " in the iter_var_attrs";
- }
- CHECK_EQ(it_attr->pragma_keys.size(), it_attr->pragma_values.size());
- for (size_t k = 0; k < it_attr->pragma_keys.size(); ++k) {
- const std::string& pkey = it_attr->pragma_keys[k].as<StringImmNode>()->value;
- PrimExpr pvalue = it_attr->pragma_values[k];
- if (!pvalue.defined()) {
- pvalue = make_const(DataType::Int(32), 1);
- }
- nest[i + 1].emplace_back(
- AttrStmtNode::make(iv, ir::attr::pragma_scope_prefix + pkey, pvalue, no_op));
- }
- }
- if (!debug_keep_trivial_loop && is_one(dom->extent)) {
- nest[i + 1].emplace_back(
- LetStmtNode::make(var, dom->min, no_op));
- value_map[iv] = dom->min;
- } else if (is_zero(dom->min)) {
- nest[i + 1].emplace_back(
- ForNode::make(var, 0, dom->extent,
- for_type, DeviceAPI::None, no_op));
- value_map[iv] = var;
- } else {
- Var idx(bind_iv->var->name_hint + ".idx", bind_iv->var.dtype());
- nest[i + 1].emplace_back(
- ForNode::make(idx, 0, dom->extent,
- for_type, DeviceAPI::None, no_op));
- PrimExpr new_value = dom->min + idx;
- value_map[iv] = new_value;
- nest[i + 1].emplace_back(
- LetStmtNode::make(var, new_value, no_op));
- }
- if (it_attr.defined() && it_attr->prefetch_data.size() != 0) {
- CHECK(!is_one(dom->extent))
- << "Cannot prefetch on trivial loop with extent=1";
- CHECK_EQ(it_attr->prefetch_data.size(),
- it_attr->prefetch_offset.size());
- for (size_t j = 0; j < it_attr->prefetch_data.size(); ++j) {
- nest[i + 1].emplace_back(
- AttrStmtNode::make(it_attr->prefetch_data[j],
- ir::attr::prefetch_scope,
- it_attr->prefetch_offset[j], no_op));
- }
- }
- } else if (bind_iv->thread_tag == "vthread" ||
- bind_iv->thread_tag == "cthread") {
- // virtual thread
- // Always restrict threaded IterVar to start from 0.
- CHECK(is_zero(dom->min));
- CHECK(is_positive_const(dom->extent));
- // annotate the extent of the IterVar
- nest[i + 1].emplace_back(
- AttrStmtNode::make(bind_iv, ir::attr::virtual_thread, dom->extent, no_op));
- value_map[iv] = var;
- } else if (bind_iv->thread_tag == "pipeline") {
- // pipeline marker.
- CHECK(is_zero(dom->min));
- CHECK(is_one(dom->extent));
- // annotate the extent of the IterVar
- nest[i + 1].emplace_back(
- AttrStmtNode::make(bind_iv, ir::attr::pipeline_exec_scope, dom->extent, no_op));
- value_map[iv] = dom->min;
- } else {
- // Always restrict threaded IterVar to start from 0.
- CHECK(is_zero(dom->min));
- // annotate the extent of the IterVar
- nest[i + 1].emplace_back(
- AttrStmtNode::make(bind_iv, ir::attr::thread_extent, dom->extent, no_op));
- if (!debug_keep_trivial_loop && is_one(dom->extent)) {
- value_map[iv] = dom->min;
- } else {
- value_map[iv] = var;
- }
- }
- // annotate the extent of the IterVar
- if (!new_loop_var) {
- nest[i + 1].emplace_back(
- AttrStmtNode::make(iv, attr::loop_scope, iv->var, no_op));
- }
- }
- // message passing to get offset of root iter vars.
- schedule::PassUpIndex(stage, dom_map, &value_map);
- return nest;
-}
-
-std::vector<Stmt> MakeIfNest(const std::vector<PrimExpr>& predicates) {
- Stmt no_op = EvaluateNode::make(0);
- std::vector<Stmt> nest;
- for (const PrimExpr& cond : predicates) {
- nest.emplace_back(IfThenElseNode::make(cond, no_op));
- }
- return nest;
-}
-
-// replacer to replace tensors
-class TensorReplacer : public ir::StmtExprMutator {
- public:
- explicit TensorReplacer(const std::unordered_map<Tensor, Tensor>& vmap)
- : vmap_(vmap) {}
-
- PrimExpr VisitExpr_(const ir::CallNode* op) final {
- if (op->call_type == ir::CallNode::Halide) {
- Tensor t = Downcast<Operation>(op->func).output(op->value_index);
- auto it = vmap_.find(t);
- if (it != vmap_.end()) {
- PrimExpr ret = ir::CallNode::make(
- op->dtype, it->second->op->name, op->args,
- op->call_type, it->second->op, it->second->value_index);
- found = true;
- return this->VisitExpr(ret);
- }
- }
- return StmtExprMutator::VisitExpr_(op);
- }
-
- // whether it is found.
- bool found{false};
-
- private:
- const std::unordered_map<Tensor, Tensor>& vmap_;
-};
-
-Stmt ReplaceTensor(Stmt stmt,
- const std::unordered_map<Tensor, Tensor>& replace) {
- TensorReplacer repl(replace);
- Stmt ret = repl(stmt);
- return repl.found ? ret : stmt;
-}
-PrimExpr ReplaceTensor(PrimExpr expr,
- const std::unordered_map<Tensor, Tensor>& replace) {
- TensorReplacer repl(replace);
- PrimExpr ret = repl(expr);
- return repl.found ? ret : expr;
-}
-
-
-Stmt Substitute(Stmt s,
- const std::unordered_map<IterVar, PrimExpr>& value_map) {
- std::unordered_map<const VarNode*, PrimExpr> init;
- for (const auto& kv : value_map) {
- init[kv.first->var.get()] = kv.second;
- }
- return ir::Substitute(s, init);
-}
-
-IterVarType ForTypeToIterVarType(ir::ForType for_type) {
- switch (for_type) {
- case ForType::Serial:
- return kDataPar;
- case ForType::Parallel:
- return kParallelized;
- case ForType::Vectorized:
- return kVectorized;
- case ForType::Unrolled:
- return kUnrolled;
- default:
- return kDataPar;
- }
-}
-
-ir::ForType IterVarTypeToForType(IterVarType iter_type) {
- switch (iter_type) {
- case kDataPar:
- return ForType::Serial;
- case kParallelized:
- return ForType::Parallel;
- case kVectorized:
- return ForType::Vectorized;
- case kUnrolled:
- return ForType::Unrolled;
- default:
- return ForType::Serial;
- }
-}
-
-} // namespace op
-} // namespace tvm
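
MakeLoopNest above is the workhorse that turns a stage's leaf_iter_vars into nested For/LetStmt/AttrStmt statements, honoring unroll/vectorize/parallel annotations and thread bindings. A small example of the kind of nest it produces, using the plain frontend API (illustrative only, nothing specific to this diff):

    import tvm

    n = tvm.var("n")
    A = tvm.placeholder((n,), name="A")
    B = tvm.compute((n,), lambda i: A[i] * 2.0, name="B")

    s = tvm.create_schedule(B.op)
    xo, xi = s[B].split(B.op.axis[0], factor=32)
    s[B].unroll(xi)  # shows up as a ForType::Unrolled loop in the nest
    # The printed IR is the loop nest generated from leaf_iter_vars.
    print(tvm.lower(s, [A, B], simple_mode=True))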
+++ /dev/null
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file op_util.h
- * \brief Common utility used in operator construction.
- */
-#ifndef TVM_OP_OP_UTIL_H_
-#define TVM_OP_OP_UTIL_H_
-
-#include <tvm/expr.h>
-#include <tvm/schedule.h>
-#include <unordered_map>
-#include <unordered_set>
-#include <vector>
-#include "../pass/ir_util.h"
-#include "../pass/arg_binder.h"
-#include "../schedule/message_passing.h"
-
-namespace tvm {
-namespace op {
-
-using ir::MergeNest;
-
-/*!
- * \brief Build loop nest for stage.
- *
- * \param stage The stage to create a loop nest.
- * \param dom_map The range of each iter var.
- * \param begin_iter_pos The beginning position of leaf_iter_vars from which to generate loops.
- * \param new_loop_var Whether to create a new loop variable.
- * \param skip_iter Whether to skip certain iterations.
- * \param p_value_map The result value of each IterVar.
- * \param debug_keep_trivial_loop Whether to keep trivial loops with extent 1.
- */
-std::vector<std::vector<Stmt> >
-MakeLoopNest(const Stage& stage,
- const std::unordered_map<IterVar, Range>& dom_map,
- size_t begin_iter_pos,
- bool new_loop_var,
- const std::unordered_set<IterVar>& skip_iter,
- std::unordered_map<IterVar, PrimExpr>* p_value_map,
- bool debug_keep_trivial_loop);
-
-/*!
- * \brief Create a nest of if checking the predicates.
- *
- * \param predicates The predicates to be checked.
- * \return List of If nest that checks the predicates.
- */
-std::vector<Stmt> MakeIfNest(const std::vector<PrimExpr>& predicates);
-
-/*!
- * \brief Replace the tensor reference (especially in Call's) in stmt by the replace map.
- * \param stmt The statement to be processed.
- * \param replace The replacement rule.
- */
-Stmt ReplaceTensor(Stmt stmt,
- const std::unordered_map<Tensor, Tensor>& replace);
-/*!
- * \brief Replace the tensor reference (especially in Call's) in expr by the replace map.
- * \param expr The expression to be processed.
- * \param replace The replacement rule.
- */
-PrimExpr ReplaceTensor(PrimExpr expr,
- const std::unordered_map<Tensor, Tensor>& replace);
-
-/*!
- * \brief Substitute the variables of stmt by value map.
- * \param stmt The statement.
- * \param value_map The value map.
- * \return Substituted result.
- */
-Stmt Substitute(Stmt stmt,
- const std::unordered_map<IterVar, PrimExpr>& value_map);
-
-/*!
- * \brief Converts Halide ForType to its corresponding IterVarType
- * \param for_type The ForType to be converted
- */
-IterVarType ForTypeToIterVarType(ir::ForType for_type);
-
-/*!
- * \brief Converts IterVarType to its corresponding Halide ForType
- * \param iter_type The IterVarType to be converted
- */
-ir::ForType IterVarTypeToForType(IterVarType iter_type);
-
-} // namespace op
-} // namespace tvm
-#endif // TVM_OP_OP_UTIL_H_
+++ /dev/null
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \brief Placeholder op.
- * \file placeholder_op.cc
- */
-#include <tvm/operation.h>
-
-namespace tvm {
-
-// PlaceholderOpNode
-TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<PlaceholderOpNode>([](const ObjectRef& node, NodePrinter* p) {
- auto* op = static_cast<const PlaceholderOpNode*>(node.get());
- p->stream << "placeholder(" << op->name << ", " << op << ")";
-});
-
-TVM_REGISTER_NODE_TYPE(PlaceholderOpNode);
-
-int PlaceholderOpNode::num_outputs() const {
- return 1;
-}
-
-Array<IterVar> PlaceholderOpNode::root_iter_vars() const {
- return {};
-}
-
-DataType PlaceholderOpNode::output_dtype(size_t i) const {
- CHECK_EQ(i, 0U);
- return dtype;
-}
-
-Array<PrimExpr> PlaceholderOpNode::output_shape(size_t i) const {
- CHECK_EQ(i, 0U);
- return shape;
-}
-
-Operation PlaceholderOpNode::make(std::string name,
- Array<PrimExpr> shape,
- DataType dtype) {
- auto n = make_object<PlaceholderOpNode>();
- n->name = name;
- n->shape = shape;
- n->dtype = dtype;
- return Operation(n);
-}
-
-Tensor placeholder(Array<PrimExpr> shape, DataType dtype, std::string name) {
- return PlaceholderOpNode::make(name, shape, dtype).output(0);
-}
-
-Array<Tensor> PlaceholderOpNode::InputTensors() const {
- return {};
-}
-
-Operation PlaceholderOpNode::ReplaceInputs(
- const Operation& self,
- const std::unordered_map<Tensor, Tensor>& rmap) const {
- return self;
-}
-
-void PlaceholderOpNode::PropBoundToInputs(
- const Operation& self,
- arith::Analyzer* analyzer,
- const std::unordered_map<const VarNode*, IntSet>& dom_map,
- std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
-}
-
-void PlaceholderOpNode::GatherBound(
- const Operation& self,
- const std::unordered_map<Tensor, TensorDom>& tensor_dom,
- std::unordered_map<IterVar, Range>* out_dom_map) const {
-}
-
-Stmt PlaceholderOpNode::BuildRealize(
- const Stage& stage,
- const std::unordered_map<IterVar, Range>& realize_map,
- const Stmt& body) const {
- return body;
-}
-
-Stmt PlaceholderOpNode::BuildProvide(
- const Stage& stage,
- const std::unordered_map<IterVar, Range>& dom_map,
- bool debug_keep_trivial_loop) const {
- return Stmt();
-}
-} // namespace tvm
+++ /dev/null
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \brief Scan Operator.
- * \file scan_op.cc
- */
-#include <tvm/operation.h>
-#include <tvm/ir.h>
-#include <tvm/ir_pass.h>
-#include "op_util.h"
-#include "../schedule/graph.h"
-
-namespace tvm {
-
-using namespace ir;
-
-TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<ScanOpNode>([](const ObjectRef& node, NodePrinter* p) {
- auto* op = static_cast<const ScanOpNode*>(node.get());
- p->stream << "scan(" << op->name << ", " << op << ")";
-});
-TVM_REGISTER_NODE_TYPE(ScanOpNode);
-
-inline bool prove_equal(PrimExpr lhs, PrimExpr rhs) {
- return is_zero(ir::Simplify(lhs - rhs));
-}
-
-int ScanOpNode::num_outputs() const {
- return static_cast<int>(update.size());
-}
-Array<IterVar> ScanOpNode::root_iter_vars() const {
- Array<IterVar> ret{scan_axis};
- for (IterVar iv : spatial_axis_) {
- ret.push_back(iv);
- }
- return ret;
-}
-
-DataType ScanOpNode::output_dtype(size_t i) const {
- return update[i]->dtype;
-}
-
-Array<PrimExpr> ScanOpNode::output_shape(size_t i) const {
- CHECK_LT(i, state_placeholder.size());
- return state_placeholder[i]->shape;
-}
-
-Operation ScanOpNode::make(std::string name,
- std::string tag,
- Map<std::string, ObjectRef> attrs,
- IterVar axis,
- Array<Tensor> init,
- Array<Tensor> update,
- Array<Tensor> state_placeholder,
- Array<Tensor> inputs) {
- if (!attrs.defined()) {
- attrs = Map<std::string, ObjectRef>();
- }
- auto n = make_object<ScanOpNode>();
- CHECK_EQ(init.size(), update.size());
- CHECK_EQ(init.size(), state_placeholder.size());
-
- for (size_t i = 0; i < init.size(); ++i) {
- CHECK_EQ(init[i]->dtype, state_placeholder[i]->dtype);
- CHECK_EQ(init[i]->dtype, update[i]->dtype);
- CHECK(prove_equal(init[i]->shape[0], axis->dom->min))
- << "init.shape[0] need to match scan_axis.dom.min";
- CHECK(prove_equal(
- state_placeholder[i]->shape[0], axis->dom->min + axis->dom->extent))
- << "state_placeholder.shape[0] need to match"
- << " scan_axis.dom.min + scan_axis.dom.extent";
- CHECK_EQ(state_placeholder[i].ndim(), init[i].ndim())
- << "The dimension of init need to match state_placeholder";
- CHECK_EQ(update[i].ndim(), state_placeholder[i].ndim())
- << "The update.ndim need to be state_placeholder.ndim - 1";
- for (size_t k = 0; k < update[i].ndim(); ++k) {
- CHECK(prove_equal(
- update[i]->shape[k], state_placeholder[i]->shape[k]));
- if (k != 0) {
- // setup spatial axis
- std::ostringstream spatial_name;
- spatial_name << name << ".out" << i << ".i" << k;
- n->spatial_axis_.push_back(
- IterVarNode::make(
- Range::make_by_min_extent(0, update[i]->shape[k]),
- Var(spatial_name.str()), kOpaque));
- }
- }
-
- for (size_t k = 1; k < init[i].ndim(); ++k) {
- CHECK(prove_equal(
- init[i]->shape[k], state_placeholder[i]->shape[k]));
- }
- }
- n->name = std::move(name);
- n->tag = std::move(tag);
- n->attrs = std::move(attrs);
- n->scan_axis = std::move(axis);
- n->init = std::move(init);
- n->update = std::move(update);
- n->state_placeholder = std::move(state_placeholder);
- n->inputs = std::move(inputs);
- return Operation(n);
-}
-
-Array<Tensor> scan(Array<Tensor> init,
- Array<Tensor> update,
- Array<Tensor> state_placeholder,
- Array<Tensor> inputs,
- std::string name,
- std::string tag,
- Map<std::string, ObjectRef> attrs) {
- IterVar scan_axis =
- IterVarNode::make(
- Range::make_by_min_extent(
- init[0]->shape[0], update[0]->shape[0] - init[0]->shape[0]),
- Var(name + ".idx"), kOrdered);
- Operation op = ScanOpNode::make(
- name, tag, attrs, scan_axis,
- init, update, state_placeholder, inputs);
- Array<Tensor> res;
- for (int i = 0; i < op->num_outputs(); ++i) {
- res.push_back(op.output(i));
- }
- return res;
-}
-
-Array<Tensor> ScanOpNode::InputTensors() const {
- Array<Tensor> ret;
- for (Tensor t : init) {
- ret.push_back(t);
- }
- for (Tensor t : update) {
- ret.push_back(t);
- }
- return ret;
-}
-
-Operation ScanOpNode::ReplaceInputs(
- const Operation& self,
- const std::unordered_map<Tensor, Tensor>& rmap) const {
- CHECK_EQ(self.operator->(), this);
- auto n = make_object<ScanOpNode>(*this);
- for (size_t i = 0; i < n->init.size(); ++i) {
- if (rmap.count(n->init[i])) {
- n->init.Set(i, rmap.at(n->init[i]));
- }
- if (rmap.count(n->update[i])) {
- n->update.Set(i, rmap.at(n->update[i]));
- }
- }
- if (!n->init.same_as(init) ||
- !n->update.same_as(update)) {
- return Operation(n);
- } else {
- return self;
- }
-}
-
-void ScanOpNode::PropBoundToInputs(
- const Operation& self,
- arith::Analyzer* analyzer,
- const std::unordered_map<const VarNode*, IntSet>& dom_map,
- std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
- CHECK_EQ(self.operator->(), this);
- for (size_t i = 0, sp_idx = 0; i < this->init.size(); ++i) {
- TensorDom* init_dom = nullptr;
- TensorDom* update_dom = nullptr;
- if (out_dom_map->count(this->init[i])) {
- init_dom = &out_dom_map->at(this->init[i]);
- }
- if (out_dom_map->count(this->update[i])) {
- update_dom = &out_dom_map->at(this->update[i]);
- }
- // first dimension, always needed.
- if (init_dom) {
- init_dom->data[0].push_back(IntSet::range(
- Range::make_by_min_extent(0, this->init[i]->shape[0])));
- }
- if (update_dom) {
- update_dom->data[0].push_back(dom_map.at(this->scan_axis->var.get()));
- }
- // The update dimensions
- for (size_t k = 1; k < this->update[i]->shape.size(); ++k, ++sp_idx) {
- IterVar sp_ax = this->spatial_axis_[sp_idx];
- if (init_dom) {
- init_dom->data[k].push_back(dom_map.at(sp_ax->var.get()));
- }
- if (update_dom) {
- update_dom->data[k].push_back(dom_map.at(sp_ax->var.get()));
- }
- }
- }
-}
-
-void ScanOpNode::GatherBound(
- const Operation& self,
- const std::unordered_map<Tensor, TensorDom>& tensor_dom,
- std::unordered_map<IterVar, Range>* out_dom_map) const {
- CHECK_EQ(self.operator->(), this);
- using namespace schedule;
- CHECK(!out_dom_map->count(this->scan_axis));
- std::vector<Tensor> output(this->num_outputs());
- for (size_t i = 0; i < output.size(); ++i) {
- output[i] = self.output(i);
- }
- // Update for time axis.
- std::vector<IntSet> time_dom;
- for (size_t i = 0; i < output.size(); ++i) {
- const TensorDom& d = tensor_dom.at(output[i]);
- time_dom.insert(time_dom.end(), d.data[0].begin(), d.data[0].end());
- }
- CHECK(!out_dom_map->count(this->scan_axis));
- Range sdom = this->scan_axis->dom;
- Range r = arith::Union(time_dom).cover_range(sdom);
- (*out_dom_map)[this->scan_axis] = Range::make_by_min_extent(
- sdom->min, ir::Simplify(r->extent + r->min - sdom->min));
- Map<IterVar, PrimExpr> fix_pt = ScanFixPointAnalysis(self);
- // Update for spatial axis.
- size_t sp_idx = 0;
- for (size_t i = 0; i < output.size(); ++i) {
- const TensorDom& d = tensor_dom.at(output[i]);
- for (size_t k = 1; k < this->update[i]->shape.size(); ++k, ++sp_idx) {
- IterVar sp_ax = this->spatial_axis_[sp_idx];
- CHECK(!out_dom_map->count(sp_ax));
- CHECK(fix_pt.count(sp_ax));
- if (fix_pt[sp_ax].as<ir::IntImmNode>()->value) {
- // fix point, we can slice it.
- (*out_dom_map)[sp_ax] = arith::Union(d.data[k]).cover_range(sp_ax->dom);
- } else {
- // not a fix point, need to include everything.
- (*out_dom_map)[sp_ax] = sp_ax->dom;
- }
- }
- }
-}
-
-Stmt ScanOpNode::BuildRealize(
- const Stage& stage,
- const std::unordered_map<IterVar, Range>& dom_map,
- const Stmt& body) const {
- CHECK_EQ(stage->op.get(), this);
- Range sdom = dom_map.at(this->scan_axis);
- Range tdom = Range::make_by_min_extent(
- 0, ir::Simplify(sdom->extent + sdom->min));
- Stmt ret = body;
- size_t sp_idx = 0;
- for (size_t i = 0; i < update.size(); ++i) {
- Tensor t = stage->op.output(i);
- CHECK_EQ(static_cast<size_t>(t->value_index), i);
- Region bounds;
- bounds.push_back(tdom);
- for (size_t k = 1; k < this->update[i]->shape.size(); ++k, ++sp_idx) {
- IterVar sp_ax = this->spatial_axis_[sp_idx];
- bounds.push_back(dom_map.at(sp_ax));
- }
- ret = ir::RealizeNode::make(t->op, t->value_index, t->dtype,
- bounds, const_true(), ret);
- }
- return ret;
-}
-
-Stmt ScanOpNode::BuildProvide(
- const Stage& stage,
- const std::unordered_map<IterVar, Range>& dom_map,
- bool debug_keep_trivial_loop) const {
- CHECK_EQ(stage->op.operator->(), this);
- Stmt provide = AttrStmtNode::make(
- stage->op, attr::scan_update_scope, this->scan_axis->var,
- EvaluateNode::make(0));
- Stmt init = AttrStmtNode::make(
- stage->op, attr::scan_init_scope, 0,
- EvaluateNode::make(0));
- size_t begin_scan = 0;
- for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) {
- if (stage->leaf_iter_vars[i]->iter_type == kThreadIndex) {
- CHECK_EQ(begin_scan, i);
- begin_scan = i + 1;
- }
- }
- std::unordered_map<IterVar, PrimExpr> vmap;
- std::unordered_set<IterVar> empty;
- auto nest = op::MakeLoopNest(
- stage, dom_map, 0, false, empty, &vmap, debug_keep_trivial_loop);
- nest[begin_scan].push_back(init);
- nest.push_back(
- op::MakeIfNest(
- schedule::MakeBoundCheck(stage, dom_map, vmap, false, empty)));
- return MergeNest(nest, provide);
-}
-} // namespace tvm
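
The shape checks in ScanOpNode::make (init.shape[0] equals the scan axis minimum; state_placeholder.shape[0] covers min + extent) mirror how tvm.scan is used from the frontend. The classic cumulative-sum sketch, for reference only:

    import tvm

    m = tvm.var("m")
    n = tvm.var("n")
    X = tvm.placeholder((m, n), name="X")
    s_state = tvm.placeholder((m, n), name="state")
    s_init = tvm.compute((1, n), lambda _, i: X[0, i])  # init.shape[0] == scan_axis.dom.min
    s_update = tvm.compute((m, n), lambda t, i: s_state[t - 1, i] + X[t, i])
    res = tvm.scan(s_init, s_update, s_state, inputs=[X])

    s = tvm.create_schedule(res.op)
    print(tvm.lower(s, [X, res], simple_mode=True))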
+++ /dev/null
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \brief Tensor Compute Op.
- * \file tensor_compute_op.cc
- */
-#include <tvm/operation.h>
-#include <tvm/arith/analyzer.h>
-#include <tvm/ir.h>
-#include <tvm/ir_pass.h>
-#include <unordered_set>
-#include "./op_util.h"
-#include "./compute_op.h"
-#include "../arith/compute_expr.h"
-
-namespace tvm {
-using namespace ir;
-// TensorComputeOpNode
-TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<TensorComputeOpNode>([](const ObjectRef& node, NodePrinter* p) {
- auto* op = static_cast<const TensorComputeOpNode*>(node.get());
- p->stream << "tensor_compute_op(" << op->name << ", " << op << ")";
- });
-
-TVM_REGISTER_NODE_TYPE(TensorComputeOpNode);
-
-int TensorComputeOpNode::num_outputs() const {
- return static_cast<int>(this->intrin->buffers.size() - this->inputs.size());
-}
-
-DataType TensorComputeOpNode::output_dtype(size_t i) const {
- return this->intrin->buffers[this->inputs.size() + i]->dtype;
-}
-
-Operation TensorComputeOpNode::make(std::string name,
- std::string tag,
- Array<IterVar> axis,
- Array<IterVar> reduce_axis,
- int schedulable_ndim,
- TensorIntrin intrin,
- Array<Tensor> tensors,
- Array<Region> regions,
- Array<PrimExpr> scalar_inputs) {
- auto n = make_object<TensorComputeOpNode>();
- n->name = std::move(name);
- n->tag = std::move(tag);
- n->axis = std::move(axis);
- n->reduce_axis = std::move(reduce_axis);
- n->schedulable_ndim = std::move(schedulable_ndim);
- n->intrin = std::move(intrin);
- n->inputs = std::move(tensors);
- n->input_regions = std::move(regions);
- n->scalar_inputs = std::move(scalar_inputs);
- return Operation(n);
-}
-
-Array<Tensor> TensorComputeOpNode::InputTensors() const {
- return inputs;
-}
-
-Operation TensorComputeOpNode::ReplaceInputs(
- const Operation& self,
- const std::unordered_map<Tensor, Tensor>& rmap) const {
- CHECK_EQ(self.operator->(), this);
- auto n = make_object<TensorComputeOpNode>(*this);
- auto intrin = make_object<TensorIntrinNode>(*(this->intrin.operator->()));
- intrin->body = op::ReplaceTensor(this->intrin->body, rmap);
- if (intrin->reduce_init.defined()) {
- intrin->reduce_init = op::ReplaceTensor(this->intrin->reduce_init, rmap);
- }
- if (intrin->reduce_update.defined()) {
- intrin->reduce_update = op::ReplaceTensor(this->intrin->reduce_update, rmap);
- }
- for (size_t i = 0; i < n->inputs.size(); ++i) {
- Tensor t = n->inputs[i];
- if (rmap.count(t)) {
- n->inputs.Set(i, rmap.at(t));
- }
- }
-
- if (intrin->body.same_as(n->intrin->body) &&
- intrin->reduce_init.same_as(n->intrin->reduce_init) &&
- intrin->reduce_update.same_as(n->intrin->reduce_update) &&
- inputs.same_as(n->inputs)) {
- return self;
- } else {
- n->intrin = TensorIntrin(intrin);
- return Operation(n);
- }
-}
-
-void TensorComputeOpNode::PropBoundToInputs(
- const Operation& self,
- arith::Analyzer* analyzer,
- const std::unordered_map<const VarNode*, IntSet>& dom_map,
- std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
- for (size_t i = 0; i < this->inputs.size(); ++i) {
- Tensor t = this->inputs[i];
- Region region = input_regions[i];
-
- auto it = out_dom_map->find(t);
- if (it == out_dom_map->end()) continue;
- TensorDom& dom = it->second;
- for (size_t j = 0; j < t.ndim(); ++j) {
- dom.data[j].emplace_back(EvalSet(region[j], dom_map));
- }
- }
-}
-
-size_t TensorComputeOpNode::num_schedulable_dims() const {
- return schedulable_ndim;
-}
-
-Stmt TensorComputeOpNode::BuildProvide(
- const Stage& stage,
- const std::unordered_map<IterVar, Range>& dom_map,
- bool debug_keep_trivial_loop) const {
- CHECK_EQ(stage->op.operator->(), this);
-
- // Start bind data.
- Stmt nop = EvaluateNode::make(0);
- std::vector<Stmt> input_bind_nest, output_bind_nest;
- Array<Tensor> inputs = this->InputTensors();
-
- // input binding
- size_t num_inputs = inputs.size();
- for (size_t i = 0; i < num_inputs; ++i) {
- Tensor tensor = inputs[i];
- Region region = this->input_regions[i];
- Buffer buffer = this->intrin->buffers[i];
- Array<ObjectRef> bind_spec{buffer, tensor};
-
- Array<PrimExpr> tuple;
- for (size_t i = 0; i < region.size(); ++i) {
- tuple.push_back(region[i]->min);
- tuple.push_back(region[i]->extent);
- }
- input_bind_nest.emplace_back(AttrStmtNode::make(
- bind_spec, ir::attr::buffer_bind_scope,
- CallNode::make(DataType::Handle(),
- ir::intrinsic::tvm_tuple,
- tuple, CallNode::Intrinsic), nop));
- }
-
- // output binding
- for (int i = 0; i < this->num_outputs(); ++i) {
- Tensor tensor = stage->op.output(i);
- Buffer buffer = this->intrin->buffers[num_inputs + i];
- Array<ObjectRef> bind_spec{buffer, tensor};
-
- Array<PrimExpr> tuple;
- for (size_t i = 0; i < this->axis.size(); ++i) {
- auto ivar = this->axis[i];
- if (i < static_cast<size_t>(this->schedulable_ndim)) {
- tuple.push_back(ivar->var);
- tuple.push_back(1);
- } else {
- Range dom = ivar->dom;
- tuple.push_back(dom->min);
- tuple.push_back(dom->extent);
- }
- }
-
- output_bind_nest.emplace_back(AttrStmtNode::make(
- bind_spec, ir::attr::buffer_bind_scope,
- CallNode::make(DataType::Handle(),
- ir::intrinsic::tvm_tuple,
- tuple, CallNode::Intrinsic), nop));
- }
-
- // Check variable remap
- std::unordered_map<const VarNode*, PrimExpr> vmap;
- ir::ArgBinder binder(&vmap);
-
-  // Map the expressions passed in the call to the TensorIntrin to its
-  // placeholder variables.
- Array<PrimExpr> user_expr = this->scalar_inputs;
- Array<Var> scalar_params = this->intrin->scalar_params;
- Array<PrimExpr> sp_expr;
- for (auto sp : scalar_params) {
- PrimExpr esp = sp;
- sp_expr.push_back(esp);
- }
- CHECK_EQ(sp_expr.size(), user_expr.size());
- // TODO(jdavies-huawei): what name should be used here?
- binder.BindArray(sp_expr, user_expr, this->name);
-
- size_t tloc = stage->leaf_iter_vars.size();
- ComputeLoopNest n = ComputeLoopNest::make(this, stage, dom_map, debug_keep_trivial_loop);
-
- if (this->reduce_axis.size() == 0) {
- std::vector<std::vector<Stmt> > nest(
- n.main_nest.begin(), n.main_nest.begin() + tloc + 1);
- nest.emplace_back(op::MakeIfNest(n.main_predicates));
- CHECK_EQ(n.init_predicates.size(), 0U);
- CHECK(this->intrin->body.defined())
- << "Normal store op for intrin " << this << " is not defined";
- Stmt body = MergeNest(output_bind_nest, this->intrin->body);
- body = MergeNest(input_bind_nest, body);
- body = ir::Substitute(body, vmap);
- body = MergeNest(binder.asserts(), body);
- body = op::Substitute(body, n.main_vmap);
- Stmt ret = MergeNest(nest, body);
- return ret;
- } else {
- // Need to split reduction
- CHECK(this->intrin->reduce_update.defined())
- << "Reduction update op is not defined";
- // Need init and update steps
- CHECK_NE(this->reduce_axis.size(), 0U);
- std::vector<std::vector<Stmt> > common(
- n.main_nest.begin(), n.main_nest.begin() + n.num_common_loop + 1);
- std::vector<std::vector<Stmt> > update_nest(
- n.main_nest.begin() + n.num_common_loop + 1, n.main_nest.begin() + tloc + 1);
- update_nest.emplace_back(op::MakeIfNest(n.main_predicates));
-
- if (this->intrin->reduce_init.defined()) {
- // init nest
- std::vector<std::vector<Stmt> > init_nest(
- n.init_nest.begin(), n.init_nest.begin() + tloc + 1);
- init_nest.emplace_back(op::MakeIfNest(n.init_predicates));
- Stmt init = MergeNest(output_bind_nest, this->intrin->reduce_init);
- init = op::Substitute(init, n.init_vmap);
- init = MergeNest(init_nest, init);
- // The update
- Stmt update = MergeNest(output_bind_nest, this->intrin->reduce_update);
- update = MergeNest(input_bind_nest, update);
- update = ir::Substitute(update, vmap);
- update = MergeNest(binder.asserts(), update);
- update = op::Substitute(update, n.main_vmap);
- update = MergeNest(update_nest, update);
- return MergeNest(common, SeqStmt::Flatten(init, update));
- } else {
- // When init op is not available, use body op for reset in the first iter.
- CHECK(this->intrin->body.defined())
- << "Normal body op is not defined";
- Stmt update = TransformUpdate(stage, dom_map, n,
- this->intrin->body,
- this->intrin->reduce_update);
- update = MergeNest(output_bind_nest, update);
- update = MergeNest(input_bind_nest, update);
- update = ir::Substitute(update, vmap);
- update = MergeNest(binder.asserts(), update);
- update = op::Substitute(update, n.main_vmap);
- update = MergeNest(update_nest, update);
- return MergeNest(common, update);
- }
- }
-}
-
-} // namespace tvm
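The deleted BuildProvide above packs each input region into a flat (min, extent) tuple before attaching it to the intrinsic buffer via a buffer_bind_scope attribute carrying a tvm_tuple call. As a minimal standalone sketch of that flattening step (plain C++, not the TVM API; FlattenRegion is a made-up helper name):

#include <cstdio>
#include <vector>

struct Range { int min; int extent; };

// Hypothetical helper mirroring the region loop above: each dimension
// contributes its (min, extent) pair to the flat tuple handed to tvm_tuple.
std::vector<int> FlattenRegion(const std::vector<Range>& region) {
  std::vector<int> tuple;
  for (const Range& r : region) {
    tuple.push_back(r.min);
    tuple.push_back(r.extent);
  }
  return tuple;
}

int main() {
  // A 2-D region [4, 20) x [0, 8) flattens to {4, 16, 0, 8}.
  for (int v : FlattenRegion({{4, 16}, {0, 8}})) std::printf("%d ", v);
  std::printf("\n");
  return 0;
}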
+++ /dev/null
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \brief Logics related to tensorize, used by ComputeOpNode.
- * \file tensorize.cc
- */
-#include <tvm/ir.h>
-#include <tvm/ir_functor_ext.h>
-#include <tvm/ir_pass.h>
-#include <tvm/runtime/registry.h>
-#include <tvm/packed_func_ext.h>
-
-#include "op_util.h"
-#include "compute_op.h"
-#include "../schedule/message_passing.h"
-
-namespace tvm {
-
-using namespace ir;
-using namespace op;
-
-// Detect the region of input and output to be tensorized.
-// out_dom: the domain of root iter vars in output op
-// in_region: region of each input tensor.
-// return The location of the tensorized scope start.
-size_t InferTensorizeRegion(
- const ComputeOpNode* self,
- const Stage& stage,
- const std::unordered_map<IterVar, Range>& dom_map,
- std::unordered_map<IterVar, Range>* out_dom,
- std::unordered_map<Tensor, Array<Range> >* in_region) {
- // Get the bound of the tensorized scope.
- bool found_point = false;
- size_t loc_scope = 0;
- std::unordered_map<IterVar, IntSet> up_state;
-  // Loop over the leaves
- for (size_t i = stage->leaf_iter_vars.size(); i != 0; --i) {
- IterVar iv = stage->leaf_iter_vars[i - 1];
- CHECK(iv->iter_type == kDataPar ||
- iv->iter_type == kCommReduce);
- auto vit = dom_map.find(iv);
- CHECK(vit != dom_map.end());
- const Range& vrange = vit->second;
- if (is_one(vrange->extent)) {
- up_state[iv] = IntSet::single_point(vrange->min);
- } else if (found_point) {
- CHECK(is_zero(vrange->min));
- up_state[iv] = IntSet::single_point(iv->var);
- } else {
- up_state[iv] = IntSet::range(vrange);
- }
- auto iit = stage->iter_var_attrs.find(iv);
- if (iit != stage->iter_var_attrs.end()) {
- const IterVarAttr& attr = (*iit).second;
- if (!found_point) {
- CHECK(!attr->bind_thread.defined())
- << "Do not allow thread in tensorize scope";
- }
- if (attr->iter_type == kTensorized) {
- CHECK(!found_point) << "Do not allow two tensorized point";
- found_point = true;
- loc_scope = i - 1;
- }
- }
- }
- CHECK(found_point);
- // Get domain of the tensorized scope.
- schedule::PassUpDomain(stage, dom_map, &up_state);
-  // Get domains of inputs
- std::unordered_map<Tensor, TensorDom> in_dom;
- std::unordered_map<const VarNode*, IntSet> temp_dmap;
- arith::Analyzer analyzer;
- Array<Tensor> inputs = self->InputTensors();
- for (Tensor t : inputs) {
- in_dom.emplace(t, TensorDom(t.ndim()));
- }
- for (IterVar iv : self->root_iter_vars()) {
- IntSet iset = up_state.at(iv);
- Range iv_range = iset.cover_range(dom_map.at(iv));
- (*out_dom)[iv] = iv_range;
- analyzer.Bind(iv->var, iv_range);
- temp_dmap[iv->var.get()] = iset;
- }
- // Input domains
- self->PropBoundToInputs(stage->op, &analyzer, temp_dmap, &in_dom);
- Range none;
- for (const auto& kv : in_dom) {
- Array<Range> vec;
- const Tensor& t = kv.first;
- for (size_t i = 0; i < t.ndim(); ++i) {
- Range r = arith::Union(kv.second.data.at(i)).cover_range(none);
- CHECK(r.defined()) << "cannot deduce region of tensorized scope for input " << t;
- vec.push_back(std::move(r));
- }
- (*in_region)[t] = std::move(vec);
- }
- return loc_scope;
-}
-
-void VerifyTensorizeLoopNest(const ComputeOpNode* self,
- const Stage& stage,
- const ComputeLoopNest& n,
- size_t tloc) {
-  // Verification step.
- std::unordered_set<const VarNode*> banned;
- CHECK_EQ(n.main_nest.size(), stage->leaf_iter_vars.size() + 1);
- CHECK(n.init_nest.size() == stage->leaf_iter_vars.size() + 1 ||
- n.init_nest.size() == 0);
- auto f_push_banned = [&banned](const Stmt& s) {
- if (const ForNode* op = s.as<ForNode>()) {
- banned.insert(op->loop_var.get());
- } else if (const AttrStmtNode* op = s.as<AttrStmtNode>()) {
- if (const IterVarNode* iv = op->node.as<IterVarNode>()) {
- banned.insert(iv->var.get());
- }
- } else if (const LetStmtNode* op = s.as<LetStmtNode>()) {
- banned.insert(op->var.get());
- }
- };
- for (size_t i = tloc; i < stage->leaf_iter_vars.size(); ++i) {
- for (const Stmt& s : n.main_nest[i + 1]) {
- f_push_banned(s);
- }
- if (n.init_nest.size() != 0) {
- for (const Stmt& s : n.init_nest[i + 1]) {
- f_push_banned(s);
- }
- }
- }
- for (const PrimExpr& pred : n.main_predicates) {
- if (ir::ExprUseVar(pred, banned)) {
- LOG(FATAL) << "Tensorize failed, split condition "
- << pred << " relies on var defined inside tensorize scope";
- }
- }
- for (const PrimExpr& pred : n.init_predicates) {
- if (ir::ExprUseVar(pred, banned)) {
- LOG(FATAL) << "Tensorize failed, split condition "
- << pred << " relies on var defined inside tensorize scope";
- }
- }
-}
-
-// Remap the tensor placeholder, index and inline things.
-class TensorIntrinMatcher final : public StmtExprMutator {
- public:
- PrimExpr VisitExpr_(const CallNode* op) final {
- PrimExpr expr = StmtExprMutator::VisitExpr_(op);
- op = expr.as<CallNode>();
- if (op->call_type == CallNode::Halide) {
- Tensor t = Downcast<Operation>(op->func).output(op->value_index);
- auto it = in_remap_.find(t);
- if (it != in_remap_.end()) {
- const InputEntry& e = it->second;
- CHECK_EQ(op->args.size(), e.region.size());
- Array<PrimExpr> args;
- for (size_t i = e.start; i < e.region.size(); ++i) {
- args.push_back(op->args[i] - e.region[i]->min);
- }
- return CallNode::make(
- op->dtype, e.tensor->op->name, args,
- op->call_type, e.tensor->op, e.tensor->value_index);
- }
- }
- return expr;
- }
-
- PrimExpr VisitExpr_(const VarNode* op) final {
- auto it = var_remap_.find(op);
- if (it != var_remap_.end()) {
- return it->second;
- } else {
- return GetRef<PrimExpr>(op);
- }
- }
-
- PrimExpr VisitExpr_(const ReduceNode* op) final {
- PrimExpr expr = StmtExprMutator::VisitExpr_(op);
- op = expr.as<ReduceNode>();
- Array<IterVar> axis;
- for (size_t i = 0; i < op->axis.size(); ++i) {
- auto it = axis_remap_.find(op->axis[i]);
- if (it != axis_remap_.end()) {
- axis.push_back(it->second);
- }
- }
- return ReduceNode::make(
- op->combiner, op->source, axis, op->condition, op->value_index);
- }
-
- void Init(const ComputeOpNode* self,
- const Stage& stage,
- const std::unordered_map<IterVar, Range>& dom_map,
- const std::unordered_map<IterVar, Range>& out_dom,
- const std::unordered_map<Tensor, Array<Range> >& in_region,
- const TensorIntrin& intrin,
- Map<Var, Range>* compute_intrin_iter_space) {
- CHECK(self == stage->op.get());
-
- for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) {
- IterVar iv = stage->leaf_iter_vars[i];
- auto vit = dom_map.find(iv);
- if (vit != dom_map.end()) {
- const Range vrange = vit->second;
- compute_intrin_iter_space->Set(iv->var, vrange);
- }
- }
-
- // input remap.
- Array<Tensor> inputs = self->InputTensors();
- CHECK_EQ(inputs.size(), intrin->inputs.size());
- for (size_t i = 0; i < inputs.size(); ++i) {
- InputEntry e;
- e.tensor = intrin->inputs[i];
- e.region = Array<Range>(in_region.at(inputs[i]));
- CHECK_GE(e.region.size(), e.tensor.ndim());
- // Enable fuzzy matching, to match [1, n, m] to [n, m]
- e.start = e.region.size() - e.tensor.ndim();
- for (size_t j = 0; j < e.start; ++j) {
- auto canonical_extent = Simplify(e.region[j]->extent, *compute_intrin_iter_space);
- CHECK(is_one(canonical_extent))
- << "Tensorize " << intrin->name << ":"
- << " Input dimension mismatch with tensor intrin "
- << " expected shape=" << e.tensor->shape
- << ", given region=" << e.region;
- }
- in_remap_[inputs[i]] = e;
- }
- // output remap
- const ComputeOpNode* intrin_compute = intrin->op.as<ComputeOpNode>();
- CHECK(intrin_compute) << "Only support compute intrinsic for now";
- CHECK_GE(self->axis.size(), intrin_compute->axis.size())
- << "Tensorize: Output mismatch with tensor intrin ";
- // Enable fuzzy matching, to match [1, n, m] to [n, m]
- size_t axis_start = self->axis.size() - intrin_compute->axis.size();
- for (size_t i = 0; i < axis_start; ++i) {
- Range r = out_dom.at(self->axis[i]);
- CHECK(is_one(r->extent))
- << "Tensorize: Output mismatch with tensor intrin "
- << " intrin-dim=" << intrin_compute->axis.size()
- << ", tensorize-dim=" << self->axis.size();
- var_remap_[self->axis[i]->var.get()] = r->min;
- }
-    // Assume we tensorize at region axis i [min, min + extent)
- // The corresponding intrinsic axis is j [0, extent)
- // Remap index i to j + min
- for (size_t i = axis_start; i < self->axis.size(); ++i) {
- IterVar iv = self->axis[i];
- IterVar target_iv = intrin_compute->axis[i - axis_start];
- Range r = out_dom.at(iv);
- var_remap_[iv->var.get()] = target_iv->var + r->min;
- axis_remap_[iv] = target_iv;
- compute_intrin_iter_space->Set(target_iv->var, target_iv->dom);
- }
- // Remap reduction axis
- CHECK_GE(self->reduce_axis.size(), intrin_compute->reduce_axis.size())
- << "Tensorize: Reduction dimension mismatch with tensor intrin";
- axis_start = self->reduce_axis.size() - intrin_compute->reduce_axis.size();
- for (size_t i = 0; i < axis_start; ++i) {
- Range r = out_dom.at(self->reduce_axis[i]);
- CHECK(is_one(r->extent))
- << "Tensorize: Reduction mismatch with tensor intrin "
- << " intrin-dim=" << intrin_compute->reduce_axis.size()
- << ", tensorize-dim=" << self->reduce_axis.size();
- var_remap_[self->reduce_axis[i]->var.get()] = r->min;
- }
- for (size_t i = axis_start; i < self->reduce_axis.size(); ++i) {
- IterVar iv = self->reduce_axis[i];
- IterVar target_iv = intrin_compute->reduce_axis[i - axis_start];
- Range r = out_dom.at(iv);
- var_remap_[iv->var.get()] = target_iv->var + r->min;
- axis_remap_[iv] = target_iv;
- compute_intrin_iter_space->Set(target_iv->var, target_iv->dom);
- }
- }
-
- private:
- // Input entry
- struct InputEntry {
- Tensor tensor;
- size_t start;
- Array<Range> region;
- };
- // input data remap
- std::unordered_map<Tensor, InputEntry> in_remap_;
- // variable remap.
- std::unordered_map<const VarNode*, PrimExpr> var_remap_;
- // IterVar remap.
- std::unordered_map<IterVar, IterVar> axis_remap_;
-};
-
-// Try to match tensor dataflow of the stage with the intrinsic
-Array<PrimExpr> MatchTensorizeBody(
- const ComputeOpNode* self,
- const Stage& stage,
- const std::unordered_map<IterVar, Range>& dom_map,
- const std::unordered_map<IterVar, Range>& out_dom,
- const std::unordered_map<Tensor, Array<Range> >& in_region,
- const TensorIntrin& intrin,
- Map<Var, Range>* compute_intrin_iter_space) {
- TensorIntrinMatcher matcher;
- matcher.Init(self, stage, dom_map, out_dom, in_region, intrin, compute_intrin_iter_space);
- Array<PrimExpr> ret;
- for (PrimExpr expr : self->body) {
- ret.push_back(matcher(expr));
- }
- return ret;
-}
-
-void VerifyTensorizeBody(
- const ComputeOpNode* self,
- const Stage& stage,
- const std::unordered_map<IterVar, Range>& dom_map,
- const std::unordered_map<IterVar, Range>& out_dom,
- const std::unordered_map<Tensor, Array<Range> >& in_region,
- const TensorIntrin& intrin) {
- Map<Var, Range> compute_intrin_iter_space;
- Array<PrimExpr> body = MatchTensorizeBody(self, stage, dom_map, out_dom, in_region, intrin,
- &compute_intrin_iter_space);
- const ComputeOpNode* intrin_compute = intrin->op.as<ComputeOpNode>();
- CHECK(intrin_compute) << "Only support compute intrinsic for now";
- CHECK_EQ(body.size(), intrin_compute->body.size())
- << "Tensorize failed: body size mismatch";
- for (size_t i = 0; i < body.size(); ++i) {
- PrimExpr lhs = Simplify(body[i], compute_intrin_iter_space);
- lhs = CanonicalSimplify(lhs, compute_intrin_iter_space);
- PrimExpr rhs = Simplify(intrin_compute->body[i], compute_intrin_iter_space);
- rhs = CanonicalSimplify(rhs, compute_intrin_iter_space);
- if (lhs.dtype() != rhs.dtype()) {
- LOG(FATAL)
- << "Failed to match the data type with TensorIntrin "
- << intrin->name << "'s declaration "
- << " provided=" << lhs.dtype()
- << ", intrin=" << rhs.dtype();
- }
- CHECK(Equal(lhs, rhs))
- << "Failed to match the compute with TensorIntrin "
- << intrin->name << "'s declaration "
- << " provided= " << lhs
- << ", intrin= " << rhs;
- }
-}
-
-Stmt MakeTensorize(const ComputeOpNode* self,
- const Stage& stage,
- const std::unordered_map<IterVar, Range>& dom_map,
- bool debug_keep_trivial_loop) {
- std::unordered_map<IterVar, Range> out_dom;
- std::unordered_map<Tensor, Array<Range> > in_region;
- size_t tloc = InferTensorizeRegion(self, stage, dom_map, &out_dom, &in_region);
- TensorIntrin intrin = stage->iter_var_attrs.at(
- stage->leaf_iter_vars[tloc])->tensor_intrin;
- CHECK(intrin.defined());
- ComputeLoopNest n = ComputeLoopNest::make(self, stage, dom_map, debug_keep_trivial_loop);
- VerifyTensorizeLoopNest(self, stage, n, tloc);
- VerifyTensorizeBody(self, stage, dom_map, out_dom, in_region, intrin);
- // Start bind data.
- Stmt nop = EvaluateNode::make(0);
- std::vector<Stmt> input_bind_nest, output_bind_nest;
- Array<Tensor> inputs = self->InputTensors();
- CHECK_EQ(inputs.size(), intrin->inputs.size())
- << "Tensorize failed: input size mismatch ";
- // input binding
- for (size_t i = 0; i < intrin->inputs.size(); ++i) {
- Tensor tensor = inputs[i];
- Buffer buffer = intrin->buffers[i];
- Array<ObjectRef> bind_spec{buffer, tensor};
- auto it = in_region.find(tensor);
- CHECK(it != in_region.end());
- const Array<Range>& region = it->second;
- Array<PrimExpr> tuple;
- for (const Range r : region) {
- tuple.push_back(r->min);
- tuple.push_back(r->extent);
- }
- input_bind_nest.emplace_back(AttrStmtNode::make(
- bind_spec, ir::attr::buffer_bind_scope,
- CallNode::make(DataType::Handle(),
- ir::intrinsic::tvm_tuple,
- tuple, CallNode::Intrinsic), nop));
- }
- // output binding
- const ComputeOpNode* intrin_compute = intrin->op.as<ComputeOpNode>();
- CHECK(intrin_compute) << "Only support compute intrinsic for now";
- CHECK_EQ(intrin->inputs.size() + intrin_compute->body.size(), intrin->buffers.size());
- CHECK_EQ(intrin_compute->body.size(), self->body.size());
- Array<PrimExpr> tuple;
- for (IterVar iv : self->axis) {
- auto it = out_dom.find(iv);
- CHECK(it != out_dom.end());
- tuple.push_back(it->second->min);
- tuple.push_back(it->second->extent);
- }
- for (size_t i = intrin->inputs.size(); i < intrin->buffers.size(); ++i) {
- Tensor tensor = stage->op.output(i - intrin->inputs.size());
- Buffer buffer = intrin->buffers[i];
- Array<ObjectRef> bind_spec{buffer, tensor};
- output_bind_nest.emplace_back(AttrStmtNode::make(
- bind_spec, ir::attr::buffer_bind_scope,
- CallNode::make(DataType::Handle(),
- ir::intrinsic::tvm_tuple,
- tuple, CallNode::Intrinsic), nop));
- }
- // Check variable remap
- std::unordered_map<const VarNode*, PrimExpr> vmap;
- ir::ArgBinder binder(&vmap);
- CHECK_GE(self->reduce_axis.size(), intrin_compute->reduce_axis.size())
- << "Tensorization fail: reduction axis size do not match";
- size_t start = self->reduce_axis.size() - intrin_compute->reduce_axis.size();
- for (size_t i = 0; i < start; ++i) {
- IterVar iv = self->reduce_axis[i];
- auto it = out_dom.find(iv);
- CHECK(it != out_dom.end());
- CHECK(is_one(it->second->extent))
- << "Tensorization fail: reduction axis size do not match";
- }
- for (size_t i = start; i < self->reduce_axis.size(); ++i) {
- IterVar iv = self->reduce_axis[i];
- IterVar target = intrin_compute->reduce_axis[i - start];
- auto it = out_dom.find(iv);
- CHECK(it != out_dom.end());
- binder.Bind(target->dom->min, make_const(iv->dom->min.dtype(), 0),
- "tensir_intrin.reduction.min");
- binder.Bind(target->dom->extent, it->second->extent,
- "tensir_intrin.reduction.extent");
- }
- if (tloc <= n.num_common_loop) {
-    // No need to split reduction
- std::vector<std::vector<Stmt> > nest(
- n.main_nest.begin(), n.main_nest.begin() + tloc + 1);
- nest.emplace_back(op::MakeIfNest(n.main_predicates));
- CHECK_EQ(n.init_predicates.size(), 0U);
- CHECK(intrin->body.defined())
- << "Normal store op for intrin " << intrin << " is not defined";
- Stmt body = MergeNest(output_bind_nest, intrin->body);
- body = MergeNest(input_bind_nest, body);
- body = Substitute(body, vmap);
- body = MergeNest(binder.asserts(), body);
- body = Substitute(body, n.main_vmap);
- return MergeNest(nest, body);
- } else {
- // Need to split reduction
- CHECK(intrin->reduce_update.defined())
- << "Reduction update op for intrin " << intrin << " is not defined";
- // Need init and update steps
- CHECK_NE(self->reduce_axis.size(), 0U);
- std::vector<std::vector<Stmt> > common(
- n.main_nest.begin(), n.main_nest.begin() + n.num_common_loop + 1);
- std::vector<std::vector<Stmt> > update_nest(
- n.main_nest.begin() + n.num_common_loop + 1, n.main_nest.begin() + tloc + 1);
- update_nest.emplace_back(op::MakeIfNest(n.main_predicates));
-
- if (intrin->reduce_init.defined()) {
- // init nest
- std::vector<std::vector<Stmt> > init_nest(
- n.init_nest.begin(), n.init_nest.begin() + tloc + 1);
- init_nest.emplace_back(op::MakeIfNest(n.init_predicates));
- Stmt init = MergeNest(output_bind_nest, intrin->reduce_init);
- init = Substitute(init, n.init_vmap);
- init = MergeNest(init_nest, init);
- // The update
- Stmt update = MergeNest(output_bind_nest, intrin->reduce_update);
- update = MergeNest(input_bind_nest, update);
- update = Substitute(update, vmap);
- update = MergeNest(binder.asserts(), update);
- update = Substitute(update, n.main_vmap);
- update = MergeNest(update_nest, update);
- return MergeNest(common, SeqStmt::Flatten(init, update));
- } else {
- // When init op is not available, use body op for reset in the first iter.
- CHECK(intrin->body.defined())
- << "Normal body op for intrin " << intrin << " is not defined";
- Stmt update = TransformUpdate(stage, dom_map, n,
- intrin->body,
- intrin->reduce_update);
- update = MergeNest(output_bind_nest, update);
- update = MergeNest(input_bind_nest, update);
- update = Substitute(update, vmap);
- update = MergeNest(binder.asserts(), update);
- update = Substitute(update, n.main_vmap);
- update = MergeNest(update_nest, update);
- return MergeNest(common, update);
- }
- }
-}
-
-// Register functions for unittests
-TVM_REGISTER_GLOBAL("test.op.InferTensorizeRegion")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- Stage stage = args[0];
- Map<IterVar, Range> dmap = args[1];
- std::unordered_map<IterVar, Range> out_dom;
- std::unordered_map<Tensor, Array<Range> > in_region;
- CHECK(stage->op.as<ComputeOpNode>());
- InferTensorizeRegion(stage->op.as<ComputeOpNode>(),
- stage,
- as_unordered_map(dmap),
- &out_dom, &in_region);
- *ret = Array<ObjectRef>{Map<IterVar, Range>(out_dom),
- Map<Tensor, Array<Range> >(in_region)};
- });
-
-TVM_REGISTER_GLOBAL("test.op.MatchTensorizeBody")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- Stage stage = args[0];
- Map<IterVar, Range> out_dom = args[1];
- Map<Tensor, Array<Range> > in_region = args[2];
- TensorIntrin intrin = args[3];
- Map<Var, Range> vrange;
- CHECK(stage->op.as<ComputeOpNode>());
- *ret = MatchTensorizeBody(stage->op.as<ComputeOpNode>(),
- stage,
- {{}},
- as_unordered_map(out_dom),
- as_unordered_map(in_region),
- intrin,
- &vrange);
- });
-} // namespace tvm
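The matcher in the deleted tensorize.cc above rewrites accesses in the stage body into the intrinsic's coordinate frame: a tensorized axis i covering [min, min + extent) corresponds to the intrinsic axis j covering [0, extent), so i maps to j + min and each access is rebased by -min (the op->args[i] - e.region[i]->min rewrite). A minimal standalone sketch of that rebasing, with made-up names:

#include <cassert>

struct Axis { int min; int extent; };

// Hypothetical helper: rebase a stage-coordinate index into intrinsic
// coordinates, i.e. j = i - min, the inverse of "remap index i to j + min".
int ToIntrinIndex(int stage_index, const Axis& axis) {
  assert(stage_index >= axis.min && stage_index < axis.min + axis.extent);
  return stage_index - axis.min;
}

int main() {
  Axis a{8, 16};                       // tensorized axis covers [8, 24)
  assert(ToIntrinIndex(8, a) == 0);    // start of the region maps to 0
  assert(ToIntrinIndex(23, a) == 15);  // last point maps to extent - 1
  return 0;
}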
Stmt ret = StmtMutator::VisitStmt_(op);
op = ret.as<AttrStmtNode>();
if (op && op->attr_key == attr::prefetch_scope) {
- Tensor ts = Downcast<Tensor>(op->node);
+ top::Tensor ts = Downcast<top::Tensor>(op->node);
CHECK_NE(loop_nest_.size(), 0U);
Domain domain = DomainTouched(op->body, ts, true, false);
Region region;
#include <tvm/arith/analyzer.h>
#include <tvm/ir.h>
#include <tvm/expr.h>
-#include <tvm/operation.h>
+#include <tvm/top/operation.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/expr_operator.h>
#include <tvm/ir_pass.h>
class StorageFlattener : public StmtExprMutator {
public:
- explicit StorageFlattener(Map<Tensor, Buffer> extern_buffer,
+ explicit StorageFlattener(Map<top::Tensor, Buffer> extern_buffer,
int cache_line_size, bool create_bound_attributes,
IRVisitorWithAnalyzer* bounded_analyzer)
: bounded_analyzer_(bounded_analyzer),
storage_scope_[op->node.get()] = op->value.as<StringImmNode>()->value;
return this->VisitStmt(op->body);
} else if (op->attr_key == attr::double_buffer_scope &&
- op->node->IsInstance<OperationNode>()) {
- Operation func = Downcast<Operation>(op->node);
+ op->node->IsInstance<top::OperationNode>()) {
+ auto func = Downcast<top::Operation>(op->node);
Stmt body = this->VisitStmt(op->body);
for (int i = 0; i < func->num_outputs(); ++i) {
TensorKey key{func, i};
} else if (op->attr_key == attr::buffer_bind_scope) {
return HandleBufferBindScope(op);
} else if (op->attr_key == attr::buffer_dim_align) {
- Tensor tensor = Downcast<Tensor>(op->node);
+ auto tensor = Downcast<top::Tensor>(op->node);
const CallNode* tuple = op->value.as<CallNode>();
CHECK(tuple && tuple->is_intrinsic(intrinsic::tvm_tuple));
TensorKey key{tensor->op, tensor->value_index};
Array<ObjectRef> arr = Downcast<Array<ObjectRef> > (op->node);
CHECK_EQ(arr.size(), 2U);
const BufferNode* buffer = arr[0].as<BufferNode>();
- const TensorNode* tensor = arr[1].as<TensorNode>();
+ const top::TensorNode* tensor = arr[1].as<top::TensorNode>();
const CallNode* tuple = op->value.as<CallNode>();
CHECK(buffer && tensor);
CHECK(tuple && tuple->is_intrinsic(intrinsic::tvm_tuple));
bool create_bound_attributes_{false};
};
-Stmt StorageFlatten(Stmt stmt, Map<Tensor, Buffer> extern_buffer,
+Stmt StorageFlatten(Stmt stmt, Map<top::Tensor, Buffer> extern_buffer,
int cache_line_size, bool create_bound_attributes) {
IRVisitorWithAnalyzer bounded_analyzer;
bounded_analyzer(stmt);
// IR Passes for TensorCore CodeGen
#include <tvm/ir.h>
#include <tvm/expr.h>
-#include <tvm/operation.h>
+#include <tvm/top/operation.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/expr_operator.h>
#include <tvm/ir_pass.h>
namespace tvm {
namespace ir {
+using namespace top;
using runtime::StorageRank;
using runtime::StorageScope;
using runtime::ThreadScope;
storage_scope_[op->node.get()] = op->value.as<StringImmNode>()->value;
this->VisitStmt(op->body);
} else if (op->attr_key == attr::buffer_dim_align) {
- Tensor tensor = Downcast<Tensor>(op->node);
+ top::Tensor tensor = Downcast<top::Tensor>(op->node);
const CallNode* tuple = op->value.as<CallNode>();
CHECK(tuple && tuple->is_intrinsic(intrinsic::tvm_tuple));
auto& vinfo = dim_align_[TensorKey{tensor->op, tensor->value_index}];
Stmt VisitStmt_(const AttrStmtNode* op) final {
Stmt stmt = StmtExprMutator::VisitStmt_(op);
if (op->attr_key == attr::realize_scope) {
- auto node = op->node.as<OperationNode>();
+ auto node = op->node.as<top::OperationNode>();
if (node != nullptr) {
if (!frag_reg_.count(node->name)) {
return stmt;
buffer_node->offset_factor = 1;
Buffer buffer(buffer_node);
- ObjectPtr<TensorNode> tensor_node = make_object<TensorNode>();
+ ObjectPtr<top::TensorNode> tensor_node = make_object<top::TensorNode>();
tensor_node->value_index = key.value_index;
- tensor_node->op = Downcast<Operation>(key.f);
+ tensor_node->op = Downcast<top::Operation>(key.f);
tensor_node->shape = shape;
tensor_node->dtype = datatype;
Tensor tensor(tensor_node);
#include <tvm/expr.h>
#include <tvm/ir.h>
#include <tvm/ir_functor_ext.h>
-#include <tvm/tensor.h>
+#include <tvm/top/tensor.h>
#include <unordered_map>
*/
#include "compile_engine.h"
-#include <tvm/schedule.h>
+#include <tvm/top/schedule.h>
#include <tvm/packed_func_ext.h>
-#include <tvm/operation.h>
+#include <tvm/top/operation.h>
#include <tvm/runtime/registry.h>
#include <tvm/relay/attrs/device_copy.h>
#include <tvm/relay/analysis.h>
// The getter to get schedule from compile engine.
// Get schedule from functor.
class ScheduleGetter :
- public ExprFunctor<Array<Tensor>(const Expr&)> {
+ public ExprFunctor<Array<top::Tensor>(const Expr&)> {
public:
explicit ScheduleGetter(Target target)
: target_(target), device_copy_op_(Op::Get("device_copy")) {}
- std::pair<Schedule, CachedFunc> Create(const Function& prim_func) {
+ std::pair<top::Schedule, CachedFunc> Create(const Function& prim_func) {
static auto fschedule =
Op::GetAttr<FTVMSchedule>("FTVMSchedule");
auto cache_node = make_object<CachedFuncNode>();
cache_node->target = target_;
for (Var param : prim_func->params) {
- Array<tvm::Tensor> inputs;
+ Array<tvm::top::Tensor> inputs;
if (const auto* ttype = param->checked_type().as<TensorTypeNode>()) {
- tvm::Tensor tensor = tvm::placeholder(
+ tvm::top::Tensor tensor = tvm::top::placeholder(
GetShape(ttype->shape), ttype->dtype);
cache_node->inputs.push_back(tensor);
inputs.push_back(tensor);
const auto* ttype = field.as<TensorTypeNode>();
// TODO(@icemelon): Allow recursive tuple
CHECK(ttype != nullptr);
- tvm::Tensor tensor = tvm::placeholder(
+ tvm::top::Tensor tensor = tvm::top::placeholder(
GetShape(ttype->shape), ttype->dtype);
cache_node->inputs.push_back(tensor);
inputs.push_back(tensor);
// Fusion over tupled results may leave identity relationships
// between inputs and outputs, and those should not be scheduled.
// Hence schedule only non PlaceholderOp outputs.
- tvm::Array<Tensor> tensor_outs;
+ tvm::Array<top::Tensor> tensor_outs;
for (const auto& tensor : cache_node->outputs) {
- if (!tensor->op.as<PlaceholderOpNode>()) {
+ if (!tensor->op.as<top::PlaceholderOpNode>()) {
tensor_outs.push_back(tensor);
}
}
- Schedule schedule;
+ top::Schedule schedule;
// No need to register schedule for device copy op.
if (master_attrs_.as<DeviceCopyAttrs>() == nullptr) {
schedule =
return std::make_pair(schedule, cfunc);
}
- Array<Tensor> VisitExpr(const Expr& expr) {
+ Array<top::Tensor> VisitExpr(const Expr& expr) {
auto it = memo_.find(expr);
if (it != memo_.end()) {
return it->second;
} else {
- Array<Tensor> res = ExprFunctor::VisitExpr(expr);
+ Array<top::Tensor> res = ExprFunctor::VisitExpr(expr);
memo_[expr] = res;
return res;
}
}
- Array<Tensor> VisitExpr_(const VarNode* op) final {
+ Array<top::Tensor> VisitExpr_(const VarNode* op) final {
LOG(FATAL) << "Free variable " << op->name_hint();
return {};
}
- Array<Tensor> VisitExpr_(const ConstantNode* op) final {
+ Array<top::Tensor> VisitExpr_(const ConstantNode* op) final {
CHECK(op->is_scalar());
void* data = op->data->data;
DataType dtype = DataType(op->data->dtype);
- Tensor value = tvm::compute({}, [&](const Array<tvm::Var>&) {
+ auto value = top::compute({}, [&](const Array<tvm::Var>&) {
if (dtype == DataType::Int(32)) {
return make_const(dtype, static_cast<const int32_t*>(data)[0]);
} else if (dtype == DataType::Int(64)) {
return {value};
}
- Array<Tensor> VisitExpr_(const CallNode* call_node) final {
+ Array<top::Tensor> VisitExpr_(const CallNode* call_node) final {
static auto fcompute =
Op::GetAttr<FTVMCompute>("FTVMCompute");
static auto fpattern =
Op::GetAttr<TOpPattern>("TOpPattern");
- Array<Tensor> inputs;
+ Array<top::Tensor> inputs;
int count_tuple = 0;
for (Expr arg : call_node->args) {
if (arg->checked_type().as<TupleTypeNode>()) {
++count_tuple;
}
- for (Tensor tensor : VisitExpr(arg)) {
+ for (top::Tensor tensor : VisitExpr(arg)) {
inputs.push_back(tensor);
}
}
CHECK(call_node->op.as<OpNode>())
<< "Primitive function only allows call into primitive ops";
Op op = Downcast<Op>(call_node->op);
- Array<Tensor> outputs;
+ Array<top::Tensor> outputs;
// Skip fcompute for device copy operators as it is not registered.
if (op == device_copy_op_) {
const auto* copy_input = inputs[0].operator->();
- outputs.push_back(TensorNode::make(copy_input->shape, copy_input->dtype,
- Operation(), 0));
+ outputs.push_back(top::TensorNode::make(copy_input->shape, copy_input->dtype,
+ top::Operation(), 0));
} else {
outputs = fcompute[op](call_node->attrs, inputs,
call_node_type, target_);
return outputs;
}
- Array<Tensor> VisitExpr_(const FunctionNode* op) final {
+ Array<top::Tensor> VisitExpr_(const FunctionNode* op) final {
LOG(FATAL) << "Do not support sub function";
- return Array<Tensor>();
+ return Array<top::Tensor>();
}
- Array<Tensor> VisitExpr_(const LetNode* op) final {
- Array<Tensor> val = VisitExpr(op->value);
+ Array<top::Tensor> VisitExpr_(const LetNode* op) final {
+ Array<top::Tensor> val = VisitExpr(op->value);
CHECK(!memo_.count(op->var));
memo_[op->var] = val;
return VisitExpr(op->body);
}
- Array<Tensor> VisitExpr_(const TupleNode* op) final {
- Array<Tensor> fields;
+ Array<top::Tensor> VisitExpr_(const TupleNode* op) final {
+ Array<top::Tensor> fields;
for (Expr field : op->fields) {
CHECK(field->checked_type().as<TensorTypeNode>())
<< "Only allow Tuple of Tensor";
- Array<Tensor> res = VisitExpr(field);
+ Array<top::Tensor> res = VisitExpr(field);
CHECK_EQ(res.size(), 1);
fields.push_back(res[0]);
}
return fields;
}
- Array<Tensor> VisitExpr_(const TupleGetItemNode* op) final {
+ Array<top::Tensor> VisitExpr_(const TupleGetItemNode* op) final {
const auto* tuple_type = op->tuple->type_as<TupleTypeNode>();
- Array<Tensor> tuple = VisitExpr(op->tuple);
+ Array<top::Tensor> tuple = VisitExpr(op->tuple);
CHECK_EQ(tuple_type->fields.size(), tuple.size());
CHECK_GE(op->index, 0);
CHECK_LT(static_cast<size_t>(op->index), tuple.size());
Attrs master_attrs_;
int master_op_pattern_{0};
std::ostringstream readable_name_stream_;
- std::unordered_map<Expr, Array<Tensor>, ObjectHash, ObjectEqual> memo_;
- Array<Operation> scalars_;
+ std::unordered_map<Expr, Array<top::Tensor>, ObjectHash, ObjectEqual> memo_;
+ Array<top::Operation> scalars_;
// Cache device copy op for equivalence checking to reduce registry lookup
// overhead for each invocation of call node when retrieving schedules.
const Op& device_copy_op_;
};
// Creates shape function from functor.
-class MakeShapeFunc : public ExprFunctor<Array<Tensor>(const Expr&)> {
+class MakeShapeFunc : public ExprFunctor<Array<top::Tensor>(const Expr&)> {
public:
MakeShapeFunc() {}
- std::pair<Schedule, CachedFunc> Create(const Function& prim_func) {
+ std::pair<top::Schedule, CachedFunc> Create(const Function& prim_func) {
for (auto param : prim_func->params) {
param_states_[param] = kNoNeed;
- Array<tvm::Tensor> data_inputs;
- Array<tvm::Tensor> shape_inputs;
+ Array<tvm::top::Tensor> data_inputs;
+ Array<tvm::top::Tensor> shape_inputs;
auto add_placeholder = [&data_inputs, &shape_inputs](const TensorTypeNode* ttype) {
// Add data placeholder
Shape shape = GetShape(ttype->shape);
- tvm::Tensor data_tensor = tvm::placeholder(shape, ttype->dtype);
+ tvm::top::Tensor data_tensor = tvm::top::placeholder(shape, ttype->dtype);
data_inputs.push_back(data_tensor);
// Add shape placeholder
int64_t ndim = shape.size();
if (ndim > 0) {
sshape.push_back(tvm::Integer(ndim));
}
- tvm::Tensor shape_tensor = tvm::placeholder(sshape, DataType::Int(64));
+ tvm::top::Tensor shape_tensor = tvm::top::placeholder(sshape, DataType::Int(64));
shape_inputs.push_back(shape_tensor);
};
CachedFunc cfunc(cache_node);
// generate schedule for shape func
- Array<Operation> out_ops;
+ Array<top::Operation> out_ops;
for (auto t : cache_node->outputs) {
out_ops.push_back(t->op);
}
- auto schedule = create_schedule(out_ops);
- tvm::schedule::AutoInlineInjective(schedule);
+ auto schedule = top::create_schedule(out_ops);
+ tvm::top::AutoInlineInjective(schedule);
for (const auto& scalar : scalars_) {
auto scalar_op = scalar->op;
if (schedule->Contain(scalar_op)) {
return std::make_pair(schedule, cfunc);
}
- Array<Tensor> VisitExpr(const Expr& expr) {
+ Array<top::Tensor> VisitExpr(const Expr& expr) {
auto it = memo_.find(expr);
if (it != memo_.end()) {
return it->second;
} else {
- Array<Tensor> res = ExprFunctor::VisitExpr(expr);
+ Array<top::Tensor> res = ExprFunctor::VisitExpr(expr);
if (expr.as<VarNode>() == nullptr) {
// Do not memoize vars because shape functions could use either the data
// or the shape of a var each time.
}
}
- Array<Tensor> VisitExpr_(const VarNode* var_node) final {
+ Array<top::Tensor> VisitExpr_(const VarNode* var_node) final {
auto var = GetRef<Var>(var_node);
auto it = param_states_.find(var);
if (it == param_states_.end()) {
}
}
- Array<Tensor> VisitExpr_(const ConstantNode* op) final {
+ Array<top::Tensor> VisitExpr_(const ConstantNode* op) final {
CHECK(data_dependants_.size());
CHECK(op->is_scalar());
bool data_dependant = data_dependants_.back();
if (data_dependant) {
void* data = op->data->data;
DataType dtype = DataType(op->data->dtype);
- Tensor value = tvm::compute({}, [&](const Array<tvm::Var>&) {
+ auto value = tvm::top::compute({}, [&](const Array<tvm::Var>&) {
if (dtype == DataType::Int(32)) {
return make_const(dtype, static_cast<const int32_t*>(data)[0]);
} else if (dtype == DataType::Int(64)) {
scalars_.push_back(value);
return {value};
} else {
- Tensor value = tvm::compute({}, [&](const Array<tvm::Var>&) {
+ auto value = tvm::top::compute({}, [&](const Array<tvm::Var>&) {
return make_const(DataType::Int(64), 0);
}, "shape_const", topi::kBroadcast);
scalars_.push_back(value);
}
}
- Array<Tensor> VisitExpr_(const CallNode* call_node) final {
+ Array<top::Tensor> VisitExpr_(const CallNode* call_node) final {
static auto fshape_func = Op::GetAttr<FShapeFunc>("FShapeFunc");
static auto tshape_data_dependant = Op::GetAttr<TShapeDataDependant>(
"TShapeDataDependant");
data_dependants_.push_back(tshape_data_dependant[op]);
// Visit all inputs
- Array<Tensor> inputs;
+ Array<top::Tensor> inputs;
int count_tuple = 0;
for (Expr arg : call_node->args) {
if (arg->checked_type().as<TupleTypeNode>()) {
++count_tuple;
}
- for (Tensor tensor : VisitExpr(arg)) {
+ for (top::Tensor tensor : VisitExpr(arg)) {
inputs.push_back(tensor);
}
}
return outputs;
}
- Array<Tensor> VisitExpr_(const FunctionNode* op) final {
+ Array<top::Tensor> VisitExpr_(const FunctionNode* op) final {
LOG(FATAL) << "Do not support sub function";
- return Array<Tensor>();
+ return Array<top::Tensor>();
}
- Array<Tensor> VisitExpr_(const LetNode* op) final {
- Array<Tensor> val = VisitExpr(op->value);
+ Array<top::Tensor> VisitExpr_(const LetNode* op) final {
+ Array<top::Tensor> val = VisitExpr(op->value);
CHECK(!memo_.count(op->var));
memo_[op->var] = val;
return VisitExpr(op->body);
}
- Array<Tensor> VisitExpr_(const TupleNode* op) final {
- Array<Tensor> fields;
+ Array<top::Tensor> VisitExpr_(const TupleNode* op) final {
+ Array<top::Tensor> fields;
for (Expr field : op->fields) {
CHECK(field->checked_type().as<TensorTypeNode>())
<< "Only allow Tuple of Tensor";
- Array<Tensor> res = VisitExpr(field);
+ Array<top::Tensor> res = VisitExpr(field);
CHECK_EQ(res.size(), 1);
fields.push_back(res[0]);
}
/*! \brief Map from parameter to its shape function usage state */
std::unordered_map<Expr, int, ObjectHash, ObjectEqual> param_states_;
/*! \brief Map from parameter to list of data placeholder */
- std::unordered_map<Expr, Array<Tensor>, ObjectHash, ObjectEqual> param_data_;
+ std::unordered_map<Expr, Array<top::Tensor>, ObjectHash, ObjectEqual> param_data_;
/*! \brief Map from parameter to list of shape placeholder */
- std::unordered_map<Expr, Array<Tensor>, ObjectHash, ObjectEqual> param_shapes_;
+ std::unordered_map<Expr, Array<top::Tensor>, ObjectHash, ObjectEqual> param_shapes_;
/*! \brief Memoized visit result */
- std::unordered_map<Expr, Array<Tensor>, ObjectHash, ObjectEqual> memo_;
+ std::unordered_map<Expr, Array<top::Tensor>, ObjectHash, ObjectEqual> memo_;
/*! \brief Stack of data dependencies for shape function */
std::vector<bool> data_dependants_;
/*! \brief Scalars used in the shape function */
- Array<Tensor> scalars_;
+ Array<top::Tensor> scalars_;
};
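The compile-engine changes above show the spelling downstream C++ code picks up from this refactor: tensor-DSL types and builders now live under tvm::top and are included from tvm/top/. A hedged sketch using only calls visible in this diff (top::compute, top::create_schedule, make_const); ScalarConstSchedule is a hypothetical helper name, not part of the PR:

#include <tvm/top/operation.h>   // was <tvm/operation.h>
#include <tvm/top/schedule.h>    // was <tvm/schedule.h>
#include <tvm/expr_operator.h>   // make_const, as used in the snippets above

// Builds a trivial schedule around a scalar constant, mirroring the
// compile_engine.cc call sites above.
tvm::top::Schedule ScalarConstSchedule() {
  tvm::top::Tensor value = tvm::top::compute(
      {}, [](const tvm::Array<tvm::Var>&) {
        return tvm::make_const(tvm::DataType::Int(64), 0);
      }, "shape_const");
  tvm::Array<tvm::top::Operation> out_ops{value->op};
  return tvm::top::create_schedule(out_ops);
}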
class CompileEngineImpl : public CompileEngineNode {
* \return Pair of schedule and cache.
* The funcs field in cache is not yet populated.
*/
- std::pair<Schedule, CachedFunc> CreateSchedule(
+ std::pair<top::Schedule, CachedFunc> CreateSchedule(
const Function& source_func, const Target& target) {
return ScheduleGetter(target).Create(source_func);
}
cache_node->func_name = GetUniqueName(cache_node->func_name);
// NOTE: array will copy on write.
- Array<Tensor> all_args = cache_node->inputs;
- for (Tensor arg : cache_node->outputs) {
+ Array<top::Tensor> all_args = cache_node->inputs;
+ for (top::Tensor arg : cache_node->outputs) {
all_args.push_back(arg);
}
// lower the function
spair.first, all_args, cache_node->func_name, key->source_func);
} else {
tvm::BuildConfig bcfg = BuildConfig::Create();
- std::unordered_map<Tensor, Buffer> binds;
+ std::unordered_map<top::Tensor, Buffer> binds;
cache_node->funcs = tvm::lower(spair.first, all_args, cache_node->func_name, binds, bcfg);
}
value->cached_func = CachedFunc(cache_node);
cache_node->func_name = GetUniqueName(cache_node->func_name);
cache_node->target = key->target;
- Array<Tensor> all_args = cache_node->inputs;
- for (Tensor arg : cache_node->outputs) {
+ Array<top::Tensor> all_args = cache_node->inputs;
+ for (top::Tensor arg : cache_node->outputs) {
all_args.push_back(arg);
}
tvm::BuildConfig bcfg = BuildConfig::Create();
- std::unordered_map<Tensor, Buffer> binds;
+ std::unordered_map<top::Tensor, Buffer> binds;
cache_node->funcs = tvm::lower(spair.first, all_args, cache_node->func_name, binds, bcfg);
value->cached_func = CachedFunc(cache_node);
return value;
/*! \brief Function name */
std::string func_name;
/* \brief The inputs to the function */
- tvm::Array<Tensor> inputs;
+ tvm::Array<top::Tensor> inputs;
/* \brief The outputs to the function */
- tvm::Array<Tensor> outputs;
+ tvm::Array<top::Tensor> outputs;
/*! \brief The lowered functions to support the function. */
tvm::Array<tvm::LoweredFunc> funcs;
/*! \brief Parameter usage states in the shape function. */
#include <tvm/build_module.h>
#include <tvm/codegen.h>
#include <tvm/ir_pass.h>
-#include <tvm/operation.h>
+#include <tvm/top/operation.h>
#include <typeinfo>
#include <string>
* \brief A compiler from relay::Module to the VM byte code.
*/
-#include <tvm/operation.h>
+#include <tvm/top/operation.h>
#include <tvm/ir/error.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/interpreter.h>
.set_attr<TOpIsStateful>("TOpIsStateful", false)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
.set_attr<FTVMCompute>("FTVMCompute",
- [](const Attrs& attrs, const Array<Tensor>& inputs,
- const Type& out_dtype, const Target& target) -> Array<Tensor> {
+ [](const Attrs& attrs, const Array<top::Tensor>& inputs,
+ const Type& out_dtype, const Target& target) -> Array<top::Tensor> {
return {topi::identity(inputs[0])};
});
.set_attr<TOpIsStateful>("TOpIsStateful", false)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
.set_attr<FTVMCompute>("FTVMCompute",
- [](const Attrs& attrs, const Array<Tensor>& inputs,
- const Type& out_dtype, const Target& target) -> Array<Tensor> {
+ [](const Attrs& attrs, const Array<top::Tensor>& inputs,
+ const Type& out_dtype, const Target& target) -> Array<top::Tensor> {
return {topi::identity(inputs[0])};
});
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
ElemwiseArbitraryLayout)
.set_attr<FTVMCompute>("FTVMCompute",
- [](const Attrs& attrs, const Array<Tensor>& inputs,
- const Type& out_dtype, const Target& target) -> Array<Tensor> {
+ [](const Attrs& attrs, const Array<top::Tensor>& inputs,
+ const Type& out_dtype, const Target& target) -> Array<top::Tensor> {
return {topi::identity(inputs[0])};
});
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
ElemwiseArbitraryLayout)
.set_attr<FTVMCompute>("FTVMCompute",
- [](const Attrs& attrs, const Array<Tensor>& inputs,
- const Type& out_dtype, const Target& target) -> Array<Tensor> {
+ [](const Attrs& attrs, const Array<top::Tensor>& inputs,
+ const Type& out_dtype, const Target& target) -> Array<top::Tensor> {
return {topi::identity(inputs[0])};
});
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
ElemwiseArbitraryLayout)
.set_attr<FTVMCompute>("FTVMCompute",
- [](const Attrs& attrs, const Array<Tensor>& inputs,
- const Type& out_dtype, const Target& target) -> Array<Tensor> {
- Array<Tensor> outputs;
+ [](const Attrs& attrs, const Array<top::Tensor>& inputs,
+ const Type& out_dtype, const Target& target) -> Array<top::Tensor> {
+ Array<top::Tensor> outputs;
for (size_t i = 0; i < inputs.size(); ++i) {
outputs.push_back(topi::identity(inputs[i]));
}
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
ElemwiseArbitraryLayout)
.set_attr<FTVMCompute>("FTVMCompute",
- [](const Attrs& attrs, const Array<Tensor>& inputs,
- const Type& out_dtype, const Target& target) -> Array<Tensor> {
+ [](const Attrs& attrs, const Array<top::Tensor>& inputs,
+ const Type& out_dtype, const Target& target) -> Array<top::Tensor> {
return {topi::identity(inputs[0])};
});
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
ElemwiseArbitraryLayout)
.set_attr<FTVMCompute>("FTVMCompute",
- [](const Attrs& attrs, const Array<Tensor>& inputs,
- const Type& out_dtype, const Target& target) -> Array<Tensor> {
+ [](const Attrs& attrs, const Array<top::Tensor>& inputs,
+ const Type& out_dtype, const Target& target) -> Array<top::Tensor> {
return {topi::identity(inputs[0])};
});
TVM_REGISTER_NODE_TYPE(DebugAttrs);
-Array<Tensor> DebugCompute(const Attrs& attrs,
- const Array<Tensor>& inputs,
+Array<top::Tensor> DebugCompute(const Attrs& attrs,
+ const Array<top::Tensor>& inputs,
const Type& out_type,
const Target& target) {
- return Array<Tensor>{ topi::identity(inputs[0]) };
+ return Array<top::Tensor>{ topi::identity(inputs[0]) };
}
RELAY_REGISTER_OP("debug")
.set_attr<TNonComputational>("TNonComputational", true)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
.set_attr<FTVMCompute>("FTVMCompute",
- [](const Attrs& attrs, const Array<Tensor>& inputs,
- const Type& out_dtype, const Target& target) -> Array<Tensor> {
+ [](const Attrs& attrs, const Array<top::Tensor>& inputs,
+ const Type& out_dtype, const Target& target) -> Array<top::Tensor> {
return {topi::identity(inputs[0])};
});
.set_attr<TNonComputational>("TNonComputational", true)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
.set_attr<FTVMCompute>("FTVMCompute",
- [](const Attrs& attrs, const Array<Tensor>& inputs,
- const Type& out_dtype, const Target& target) -> Array<Tensor> {
+ [](const Attrs& attrs, const Array<top::Tensor>& inputs,
+ const Type& out_dtype, const Target& target) -> Array<top::Tensor> {
return {topi::identity(inputs[0])};
});
.set_attr<TNonComputational>("TNonComputational", true)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
.set_attr<FTVMCompute>("FTVMCompute",
- [](const Attrs& attrs, const Array<Tensor>& inputs,
- const Type& out_dtype, const Target& target) -> Array<Tensor> {
+ [](const Attrs& attrs, const Array<top::Tensor>& inputs,
+ const Type& out_dtype, const Target& target) -> Array<top::Tensor> {
return {topi::identity(inputs[0])};
});
.set_attr<TNonComputational>("TNonComputational", true)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
.set_attr<FTVMCompute>("FTVMCompute",
- [](const Attrs& attrs, const Array<Tensor>& inputs,
- const Type& out_dtype, const Target& target) -> Array<Tensor> {
+ [](const Attrs& attrs, const Array<top::Tensor>& inputs,
+ const Type& out_dtype, const Target& target) -> Array<top::Tensor> {
return {topi::identity(inputs[0])};
});
.set_attr<TNonComputational>("TNonComputational", true)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
.set_attr<FTVMCompute>("FTVMCompute",
- [](const Attrs& attrs, const Array<Tensor>& inputs,
- const Type& out_dtype, const Target& target) -> Array<Tensor> {
+ [](const Attrs& attrs, const Array<top::Tensor>& inputs,
+ const Type& out_dtype, const Target& target) -> Array<top::Tensor> {
return {topi::identity(inputs[0])};
});
.add_argument("bias", "1D Tensor", "Bias.")
.set_support_level(1)
.add_type_rel("BiasAdd", BiasAddRel)
-.set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs, const Array<Tensor>& inputs,
+.set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs, const Array<top::Tensor>& inputs,
const Type& out_type, const Target& target) {
const auto* param = attrs.as<BiasAddAttrs>();
- return tvm::Array<tvm::Tensor>{topi::nn::bias_add(inputs[0], inputs[1], param->axis)};
+ return tvm::Array<tvm::top::Tensor>{topi::nn::bias_add(inputs[0], inputs[1], param->axis)};
});
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const Attrs& attrs,
- const Array<Tensor>& inputs,
+ const Array<top::Tensor>& inputs,
const Type& out_type,
const Target& target) {
const auto* param = attrs.as<LeakyReluAttrs>();
- return Array<Tensor>{ topi::leaky_relu(inputs[0], param->alpha) };
+ return Array<top::Tensor>{ topi::leaky_relu(inputs[0], param->alpha) };
});
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", PReluInferCorrectLayout<PReluAttrs>)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const Attrs& attrs,
- const Array<Tensor>& inputs,
+ const Array<top::Tensor>& inputs,
const Type& out_type,
const Target& target) {
const auto* param = attrs.as<PReluAttrs>();
- return Array<Tensor>{ topi::prelu(inputs[0], inputs[1], param->axis)};
+ return Array<top::Tensor>{ topi::prelu(inputs[0], inputs[1], param->axis)};
});
.set_support_level(1)
.add_type_rel("Identity", IdentityRel)
.set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs,
- const Array<Tensor>& inputs,
+ const Array<top::Tensor>& inputs,
const Type& out_type,
const Target& target) {
const auto* param = attrs.as<SoftmaxAttrs>();
CHECK(param != nullptr);
- return Array<Tensor>{ topi::nn::softmax(inputs[0], param->axis) };
+ return Array<top::Tensor>{ topi::nn::softmax(inputs[0], param->axis) };
});
.set_support_level(1)
.add_type_rel("Identity", IdentityRel)
.set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs,
- const Array<Tensor>& inputs,
+ const Array<top::Tensor>& inputs,
const Type& out_type,
const Target& target) {
const auto* param = attrs.as<SoftmaxAttrs>();
CHECK(param != nullptr);
CHECK(param->axis == -1 || param->axis == static_cast<int32_t>(inputs[0].ndim()) - 1)
<< "log_softmax currently only works on last dimension";
- return Array<Tensor>{ topi::nn::log_softmax(inputs[0]) };
+ return Array<top::Tensor>{ topi::nn::log_softmax(inputs[0]) };
});
.add_type_rel("BatchFlatten", BatchFlattenRel)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const Attrs& attrs,
- const Array<Tensor>& inputs,
+ const Array<top::Tensor>& inputs,
const Type& out_type,
const Target& target) {
- return Array<Tensor>{ topi::nn::flatten(inputs[0]) };
+ return Array<top::Tensor>{ topi::nn::flatten(inputs[0]) };
});
.add_type_rel("Identity", IdentityRel)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
.set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs,
- const Array<Tensor>& inputs,
+ const Array<top::Tensor>& inputs,
const Type& out_type,
const Target& target) {
- return Array<Tensor>{ topi::relu(inputs[0], 0.0f) };
+ return Array<top::Tensor>{ topi::relu(inputs[0], 0.0f) };
});
return true;
}
-Array<Tensor> PadCompute(const Attrs& attrs,
- const Array<Tensor>& inputs,
+Array<top::Tensor> PadCompute(const Attrs& attrs,
+ const Array<top::Tensor>& inputs,
const Type& out_type,
const Target& target) {
const auto* param = attrs.as<PadAttrs>();
pad_after.push_back(pad_width[i][1]);
}
const auto* out_ttype = out_type.as<TensorTypeNode>();
- return Array<Tensor>{ topi::pad(inputs[0], pad_before, pad_after,
+ return Array<top::Tensor>{ topi::pad(inputs[0], pad_before, pad_after,
tvm::make_const(out_ttype->dtype, param->pad_value),
"T_pad",
topi::kElementWise,
}
template<typename AttrType, topi::nn::PoolType mode>
-Array<Tensor> Pool2DCompute(const Attrs& attrs,
- const Array<Tensor>& inputs,
+Array<top::Tensor> Pool2DCompute(const Attrs& attrs,
+ const Array<top::Tensor>& inputs,
const Type& out_type,
const Target& target) {
static const Layout kNCHW("NCHW");
}
if (mode == topi::nn::kAvgPool) {
bool count_include_pad = reinterpret_cast<const AvgPool2DAttrs*>(param)->count_include_pad;
- return Array<Tensor>{
+ return Array<top::Tensor>{
topi::nn::pool(inputs[0], pool_size, strides, padding,
mode, ceil_mode, layout.name(), count_include_pad)};
} else {
- return Array<Tensor>{
+ return Array<top::Tensor>{
topi::nn::pool(inputs[0], pool_size, strides, padding,
mode, ceil_mode, layout.name())};
}
template<topi::nn::PoolType mode>
-Array<Tensor> GlobalPool2DCompute(const Attrs& attrs,
- const Array<Tensor>& inputs,
+Array<top::Tensor> GlobalPool2DCompute(const Attrs& attrs,
+ const Array<top::Tensor>& inputs,
const Type& out_type,
const Target& target) {
static const Layout kNCHW("NCHW");
CHECK(inputs[0].ndim() == 4U || inputs[0].ndim() == 5U)
<< "Pool2D only support 4-D input (e.g., NCHW)"
<< " or 5-D input (last dimension is a split of channel)";
- return Array<Tensor>{
+ return Array<top::Tensor>{
topi::nn::global_pool(inputs[0], mode, layout.name()) };
}
}
template<topi::nn::PoolType mode>
-Array<Tensor> AdaptivePool2DCompute(const Attrs& attrs,
- const Array<Tensor>& inputs,
+Array<top::Tensor> AdaptivePool2DCompute(const Attrs& attrs,
+ const Array<top::Tensor>& inputs,
const Type& out_type,
const Target& target) {
static const Layout kNCHW("NCHW");
output_height = output_size[0];
output_width = output_size[1];
}
- return Array<Tensor>{
+ return Array<top::Tensor>{
topi::nn::adaptive_pool(inputs[0], Array<IndexExpr>{ output_height, output_width },
mode, layout.name()) };
}
}
template <typename AttrType, topi::nn::PoolType mode>
-Array<Tensor> Pool2DGradCompute(const Attrs& attrs, const Array<Tensor>& inputs,
+Array<top::Tensor> Pool2DGradCompute(const Attrs& attrs, const Array<top::Tensor>& inputs,
const Type& out_type, const Target& target) {
static const Layout kNCHW("NCHW");
const auto* param = attrs.as<AttrType>();
}
if (mode == topi::nn::kAvgPool) {
bool count_include_pad = reinterpret_cast<const AvgPool2DAttrs*>(param)->count_include_pad;
- return Array<Tensor>{topi::nn::pool_grad(inputs[0], inputs[1], pool_size, strides, padding,
+ return Array<top::Tensor>{topi::nn::pool_grad(inputs[0], inputs[1], pool_size, strides, padding,
mode, ceil_mode, layout.name(), count_include_pad)};
} else {
- return Array<Tensor>{topi::nn::pool_grad(inputs[0], inputs[1], pool_size, strides, padding,
+ return Array<top::Tensor>{topi::nn::pool_grad(inputs[0], inputs[1], pool_size, strides, padding,
mode, ceil_mode, layout.name())};
}
}
template<typename AttrType, topi::nn::PoolType mode>
-Array<Tensor> Pool1DCompute(const Attrs& attrs,
- const Array<Tensor>& inputs,
+Array<top::Tensor> Pool1DCompute(const Attrs& attrs,
+ const Array<top::Tensor>& inputs,
const Type& out_type,
const Target& target) {
static const Layout kNCW("NCW");
if (mode == topi::nn::kAvgPool) {
bool count_include_pad = reinterpret_cast<const AvgPool1DAttrs*>(param)->count_include_pad;
- return Array<Tensor>{
+ return Array<top::Tensor>{
topi::nn::pool1d(inputs[0], pool_size, strides, padding,
mode, ceil_mode, layout.name(), count_include_pad)};
} else {
- return Array<Tensor>{
+ return Array<top::Tensor>{
topi::nn::pool1d(inputs[0], pool_size, strides, padding,
mode, ceil_mode, layout.name())};
}
template<typename AttrType, topi::nn::PoolType mode>
-Array<Tensor> Pool3DCompute(const Attrs& attrs,
- const Array<Tensor>& inputs,
+Array<top::Tensor> Pool3DCompute(const Attrs& attrs,
+ const Array<top::Tensor>& inputs,
const Type& out_type,
const Target& target) {
static const Layout kNCDHW("NCDHW");
}
if (mode == topi::nn::kAvgPool) {
bool count_include_pad = reinterpret_cast<const AvgPool3DAttrs*>(param)->count_include_pad;
- return Array<Tensor>{
+ return Array<top::Tensor>{
topi::nn::pool3d(inputs[0], pool_size, strides, padding,
mode, ceil_mode, layout.name(), count_include_pad)};
} else {
- return Array<Tensor>{
+ return Array<top::Tensor>{
topi::nn::pool3d(inputs[0], pool_size, strides, padding,
mode, ceil_mode, layout.name())};
}
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
#define RELAY_BINARY_COMPUTE(FTOPI) \
[] (const Attrs& attrs, \
- const Array<Tensor>& inputs, \
+ const Array<top::Tensor>& inputs, \
const Type& out_type, \
- const Target& target) -> Array<Tensor> { \
+ const Target& target) -> Array<top::Tensor> { \
CHECK_EQ(inputs.size(), 2U); \
return {FTOPI(inputs[0], inputs[1])}; \
} \
}
template<typename F>
-Array<Tensor> ReduceCompute(const Attrs& attrs,
- const Array<Tensor>& inputs,
+Array<top::Tensor> ReduceCompute(const Attrs& attrs,
+ const Array<top::Tensor>& inputs,
const Type& out_type,
const Target& target,
F f) {
.add_argument("data", "Tensor", "The input tensor.")
-Array<Tensor> ArgMaxCompute(const Attrs& attrs,
- const Array<Tensor>& inputs,
+Array<top::Tensor> ArgMaxCompute(const Attrs& attrs,
+ const Array<top::Tensor>& inputs,
const Type& out_type,
const Target& target) {
return ReduceCompute(attrs, inputs, out_type, target, topi::argmax);
.set_attr<TOpPattern>("TOpPattern", kCommReduce);
-Array<Tensor> ArgMinCompute(const Attrs& attrs,
- const Array<Tensor>& inputs,
+Array<top::Tensor> ArgMinCompute(const Attrs& attrs,
+ const Array<top::Tensor>& inputs,
const Type& out_type,
const Target& target) {
return ReduceCompute(attrs, inputs, out_type, target, topi::argmin);
.set_attr<FTVMCompute>("FTVMCompute", ArgMinCompute)
.set_attr<TOpPattern>("TOpPattern", kCommReduce);
-Array<Tensor> SumCompute(const Attrs& attrs,
- const Array<Tensor>& inputs,
+Array<top::Tensor> SumCompute(const Attrs& attrs,
+ const Array<top::Tensor>& inputs,
const Type& out_type,
const Target& target) {
return ReduceCompute(attrs, inputs, out_type, target, topi::sum);
.set_attr<TOpPattern>("TOpPattern", kCommReduce);
-Array<Tensor> AllCompute(const Attrs& attrs,
- const Array<Tensor>& inputs,
+Array<top::Tensor> AllCompute(const Attrs& attrs,
+ const Array<top::Tensor>& inputs,
const Type& out_type,
const Target& target) {
return ReduceCompute(attrs, inputs, out_type, target, topi::all);
.set_attr<TOpPattern>("TOpPattern", kCommReduce);
-Array<Tensor> AnyCompute(const Attrs& attrs,
- const Array<Tensor>& inputs,
+Array<top::Tensor> AnyCompute(const Attrs& attrs,
+ const Array<top::Tensor>& inputs,
const Type& out_type,
const Target& target) {
return ReduceCompute(attrs, inputs, out_type, target, topi::any);
.set_attr<TOpPattern>("TOpPattern", kCommReduce);
-Array<Tensor> MaxCompute(const Attrs& attrs,
- const Array<Tensor>& inputs,
+Array<top::Tensor> MaxCompute(const Attrs& attrs,
+ const Array<top::Tensor>& inputs,
const Type& out_type,
const Target& target) {
return ReduceCompute(attrs, inputs, out_type, target, topi::max);
.set_attr<TOpPattern>("TOpPattern", kCommReduce);
-Array<Tensor> MinCompute(const Attrs& attrs,
- const Array<Tensor>& inputs,
+Array<top::Tensor> MinCompute(const Attrs& attrs,
+ const Array<top::Tensor>& inputs,
const Type& out_type,
const Target& target) {
return ReduceCompute(attrs, inputs, out_type, target, topi::min);
.set_attr<TOpPattern>("TOpPattern", kCommReduce);
-Array<Tensor> ProdCompute(const Attrs& attrs,
- const Array<Tensor>& inputs,
+Array<top::Tensor> ProdCompute(const Attrs& attrs,
+ const Array<top::Tensor>& inputs,
const Type& out_type,
const Target& target) {
return ReduceCompute(attrs, inputs, out_type, target, topi::prod);
.set_attr<TOpPattern>("TOpPattern", kCommReduce);
-Array<Tensor> MeanCompute(const Attrs& attrs,
- const Array<Tensor>& inputs,
+Array<top::Tensor> MeanCompute(const Attrs& attrs,
+ const Array<top::Tensor>& inputs,
const Type& out_type,
const Target& target) {
IndexExpr count = make_const(inputs[0]->dtype, 1);
return true;
}
-Array<Tensor> VarianceCompute(const Attrs& attrs,
- const Array<Tensor>& inputs,
+Array<top::Tensor> VarianceCompute(const Attrs& attrs,
+ const Array<top::Tensor>& inputs,
const Type& out_type,
const Target& target) {
IndexExpr count = make_const(inputs[0]->dtype, 1);
return true;
}
-Array<Tensor> CastCompute(const Attrs& attrs,
- const Array<Tensor>& inputs,
+Array<top::Tensor> CastCompute(const Attrs& attrs,
+ const Array<top::Tensor>& inputs,
const Type& out_type,
const Target& target) {
const CastAttrs *param = attrs.as<CastAttrs>();
}
-Array<Tensor> CastLikeCompute(const Attrs& attrs,
- const Array<Tensor>& inputs,
+Array<top::Tensor> CastLikeCompute(const Attrs& attrs,
+ const Array<top::Tensor>& inputs,
const Type& out_type,
const Target& target) {
return { topi::cast(inputs[0], inputs[1]->dtype) };
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout);
-Array<Tensor> ReinterpretCompute(const Attrs& attrs, const Array<Tensor>& inputs,
+Array<top::Tensor> ReinterpretCompute(const Attrs& attrs, const Array<top::Tensor>& inputs,
const Type& out_type, const Target& target) {
const CastAttrs* param = attrs.as<CastAttrs>();
CHECK(param != nullptr);
return true;
}
-Array<Tensor> ExpandDimsCompute(const Attrs& attrs,
- const Array<Tensor>& inputs,
+Array<top::Tensor> ExpandDimsCompute(const Attrs& attrs,
+ const Array<top::Tensor>& inputs,
const Type& out_type,
const Target& target) {
const ExpandDimsAttrs *param = attrs.as<ExpandDimsAttrs>();
// relay.concatenate
TVM_REGISTER_NODE_TYPE(ConcatenateAttrs);
-Array<Tensor> ConcatenateCompute(const Attrs& attrs,
- const Array<Tensor>& inputs,
+Array<top::Tensor> ConcatenateCompute(const Attrs& attrs,
+ const Array<top::Tensor>& inputs,
const Type& out_type,
const Target& target) {
const ConcatenateAttrs *param = attrs.as<ConcatenateAttrs>();
return true;
}
-Array<Tensor> StackCompute(const Attrs& attrs,
- const Array<Tensor>& inputs,
+Array<top::Tensor> StackCompute(const Attrs& attrs,
+ const Array<top::Tensor>& inputs,
const Type& out_type,
const Target& target) {
const StackAttrs *param = attrs.as<StackAttrs>();
return true;
}
-Array<Tensor> TransposeCompute(const Attrs& attrs,
- const Array<Tensor>& inputs,
+Array<top::Tensor> TransposeCompute(const Attrs& attrs,
+ const Array<top::Tensor>& inputs,
const Type& out_type,
const Target& target) {
const auto* param = attrs.as<TransposeAttrs>();
CHECK(param != nullptr);
- return Array<Tensor>{ topi::transpose(inputs[0], param->axes) };
+ return Array<top::Tensor>{ topi::transpose(inputs[0], param->axes) };
}
Expr MakeTranspose(Expr data,
return true;
}
-Array<Tensor> ReshapeCompute(const Attrs& attrs,
- const Array<Tensor>& inputs,
+Array<top::Tensor> ReshapeCompute(const Attrs& attrs,
+ const Array<top::Tensor>& inputs,
const Type& out_type,
const Target& target) {
const auto* out_ttype = out_type.as<TensorTypeNode>();
return true;
}
-Array<Tensor> TakeCompute(const Attrs& attrs,
- const Array<Tensor>& inputs,
+Array<top::Tensor> TakeCompute(const Attrs& attrs,
+ const Array<top::Tensor>& inputs,
const Type& out_type,
const Target& target) {
const auto* param = attrs.as<TakeAttrs>();
CHECK(param != nullptr);
if (!param->axis.defined()) {
- return Array<Tensor>{ topi::take(inputs[0], inputs[1], param->mode) };
+ return Array<top::Tensor>{ topi::take(inputs[0], inputs[1], param->mode) };
} else {
- return Array<Tensor>{ topi::take(inputs[0], inputs[1], param->axis, param->mode) };
+ return Array<top::Tensor>{ topi::take(inputs[0], inputs[1], param->axis, param->mode) };
}
}
return true;
}
-Array<Tensor> FullCompute(const Attrs& attrs,
- const Array<Tensor>& inputs,
+Array<top::Tensor> FullCompute(const Attrs& attrs,
+ const Array<top::Tensor>& inputs,
const Type& out_type,
const Target& target) {
const auto* out_ttype = out_type.as<TensorTypeNode>();
return true;
}
-Array<Tensor> FullLikeCompute(const Attrs& attrs,
- const Array<Tensor>& inputs,
+Array<top::Tensor> FullLikeCompute(const Attrs& attrs,
+ const Array<top::Tensor>& inputs,
const Type& out_type,
const Target& target) {
return { topi::full_like(inputs[0], inputs[1]()) };
}
}
-inline Tensor DynamicArange(const tvm::Tensor& start,
- const tvm::Tensor& stop,
- const tvm::Tensor& step,
- tvm::DataType dtype,
- std::string name = "tensor",
- std::string tag = topi::kInjective) {
+inline top::Tensor DynamicArange(const top::Tensor& start,
+ const top::Tensor& stop,
+ const top::Tensor& step,
+ tvm::DataType dtype,
+ std::string name = "tensor",
+ std::string tag = topi::kInjective) {
tvm::PrimExpr num_elem = tvm::Var("num_elem");
- return tvm::compute({num_elem}, [&](const Array<tvm::Var>& indices) {
+ return top::compute({num_elem}, [&](const Array<tvm::Var>& indices) {
return tvm::cast(dtype, start[0] + step[0] * indices[0]);
}, name, tag);
}
-Array<Tensor> ArangeCompute(const Attrs& attrs,
- const Array<Tensor>& inputs,
- const Type& out_type,
- const Target& target) {
+Array<top::Tensor> ArangeCompute(const Attrs& attrs,
+ const Array<top::Tensor>& inputs,
+ const Type& out_type,
+ const Target& target) {
const ArangeAttrs* param = attrs.as<ArangeAttrs>();
- Tensor start = inputs[0];
- Tensor stop = inputs[1];
- Tensor step = inputs[2];
+ top::Tensor start = inputs[0];
+ top::Tensor stop = inputs[1];
+ top::Tensor step = inputs[2];
Array<tvm::PrimExpr> empty = {0};
return { DynamicArange(start, stop, step, param->dtype) };
}
return true;
}
-Array<Tensor> RepeatCompute(const Attrs& attrs,
- const Array<Tensor>& inputs,
+Array<top::Tensor> RepeatCompute(const Attrs& attrs,
+ const Array<top::Tensor>& inputs,
const Type& out_type,
const Target& target) {
const RepeatAttrs *param = attrs.as<RepeatAttrs>();
return true;
}
-Array<Tensor> TileCompute(const Attrs& attrs,
- const Array<Tensor>& inputs,
+Array<top::Tensor> TileCompute(const Attrs& attrs,
+ const Array<top::Tensor>& inputs,
const Type& out_type,
const Target& target) {
const TileAttrs *param = attrs.as<TileAttrs>();
return true;
}
-Array<Tensor> ReverseCompute(const Attrs& attrs,
- const Array<Tensor>& inputs,
+Array<top::Tensor> ReverseCompute(const Attrs& attrs,
+ const Array<top::Tensor>& inputs,
const Type& out_type,
const Target& target) {
const ReverseAttrs *param = attrs.as<ReverseAttrs>();
return CallNode::make(op, {condition, x, y});
}
-Array<Tensor> WhereCompute(const Attrs& attrs,
- const Array<Tensor>& inputs,
+Array<top::Tensor> WhereCompute(const Attrs& attrs,
+ const Array<top::Tensor>& inputs,
const Type& out_type,
const Target& target) {
return { topi::where(inputs[0], inputs[1], inputs[2]) };
return true;
}
-Array<Tensor> SqueezeCompute(const Attrs& attrs,
- const Array<Tensor>& inputs,
+Array<top::Tensor> SqueezeCompute(const Attrs& attrs,
+ const Array<top::Tensor>& inputs,
const Type& out_type,
const Target& target) {
const SqueezeAttrs *param = attrs.as<SqueezeAttrs>();
return CallNode::make(op, {data, collapse_type}, Attrs(), {});
}
-Array<Tensor> CollapseSumLikeCompute(const Attrs& attrs,
- const Array<Tensor>& inputs,
+Array<top::Tensor> CollapseSumLikeCompute(const Attrs& attrs,
+ const Array<top::Tensor>& inputs,
const Type& out_type,
const Target& target) {
const auto* out_ttype = out_type.as<TensorTypeNode>();
return CallNode::make(op, {data}, Attrs(attrs), {});
}
-Array<Tensor> BroadCastToCompute(const Attrs& attrs,
- const Array<Tensor>& inputs,
+Array<top::Tensor> BroadCastToCompute(const Attrs& attrs,
+ const Array<top::Tensor>& inputs,
const Type& out_type,
const Target& target) {
auto ioattrs = attrs.as<InitOpAttrs>();
return CallNode::make(op, {data, broadcast_type}, Attrs(), {});
}
-Array<Tensor> BroadCastToLikeCompute(const Attrs& attrs,
- const Array<Tensor>& inputs,
+Array<top::Tensor> BroadCastToLikeCompute(const Attrs& attrs,
+ const Array<top::Tensor>& inputs,
const Type& out_type,
const Target& target) {
const auto* out_ttype = out_type.as<TensorTypeNode>();
return CallNode::make(op, {data}, Attrs(attrs), {});
}
-Array<Tensor> StridedSliceCompute(const Attrs& attrs,
- const Array<Tensor>& inputs,
+Array<top::Tensor> StridedSliceCompute(const Attrs& attrs,
+ const Array<top::Tensor>& inputs,
const Type& out_type,
const Target& target) {
const StridedSliceAttrs *param = attrs.as<StridedSliceAttrs>();
CHECK(param != nullptr);
- return Array<Tensor>{
+ return Array<top::Tensor>{
topi::strided_slice(inputs[0], param->begin, param->end, param->strides)
};
}
return true;
}
-Array<Tensor> SplitCompute(const Attrs& attrs,
- const Array<Tensor>& inputs,
+Array<top::Tensor> SplitCompute(const Attrs& attrs,
+ const Array<top::Tensor>& inputs,
const Type& out_type,
const Target& target) {
const auto param = attrs.as<SplitAttrs>();
if (const IntImmNode* sections = param->indices_or_sections.as<IntImmNode>()) {
int64_t num_sections = sections->value;
- return Array<Tensor>{
+ return Array<top::Tensor>{
topi::split_sections(inputs[0], num_sections, param->axis) };
} else {
auto indices = Downcast<Array<Integer> >(param->indices_or_sections);
- return Array<Tensor>{ topi::split(inputs[0], indices, param->axis) };
+ return Array<top::Tensor>{ topi::split(inputs[0], indices, param->axis) };
}
}
return CallNode::make(op, {data, shape_like}, Attrs(attrs), {});
}
-Array<Tensor> SliceLikeCompute(const Attrs& attrs,
- const Array<Tensor>& inputs,
+Array<top::Tensor> SliceLikeCompute(const Attrs& attrs,
+ const Array<top::Tensor>& inputs,
const Type& out_type,
const Target& target) {
const auto* param = attrs.as<SliceLikeAttrs>();
<< topi::GetConstInt(src_shape[axis]);
}
}
- return Array<Tensor>{
+ return Array<top::Tensor>{
topi::strided_slice(inputs[0],
GetIntArray(begin_idx),
GetIntArray(end_idx),
// relay.layout_transform
TVM_REGISTER_NODE_TYPE(LayoutTransformAttrs);
-Array<Tensor> LayoutTransformCompute(const Attrs& attrs,
- const Array<Tensor>& inputs,
+Array<top::Tensor> LayoutTransformCompute(const Attrs& attrs,
+ const Array<top::Tensor>& inputs,
const Type& out_type,
const Target& target) {
const auto* param = attrs.as<LayoutTransformAttrs>();
CHECK(param != nullptr);
- return Array<Tensor>{
+ return Array<top::Tensor>{
topi::layout_transform(inputs[0], param->src_layout, param->dst_layout)
};
}
return true;
}
-Array<Tensor> GatherNDCompute(const Attrs& attrs,
- const Array<Tensor>& inputs,
+Array<top::Tensor> GatherNDCompute(const Attrs& attrs,
+ const Array<top::Tensor>& inputs,
const Type& out_type,
const Target& target) {
return { topi::gather_nd(inputs[0], inputs[1]) };
return true;
}
-Array<Tensor> SequenceMaskCompute(const Attrs& attrs,
- const Array<Tensor>& inputs,
+Array<top::Tensor> SequenceMaskCompute(const Attrs& attrs,
+ const Array<top::Tensor>& inputs,
const Type& out_type,
const Target& target) {
const auto* param = attrs.as<SequenceMaskAttrs>();
CHECK(param != nullptr);
- return Array<Tensor>{ topi::sequence_mask(inputs[0], inputs[1], param->mask_value, param->axis) };
+ return Array<top::Tensor>{
+ topi::sequence_mask(inputs[0], inputs[1], param->mask_value, param->axis) };
}
Expr MakeSequenceMask(Expr data,
return true;
}
-Array<Tensor> OneHotCompute(const Attrs& attrs,
- const Array<Tensor>& inputs,
+Array<top::Tensor> OneHotCompute(const Attrs& attrs,
+ const Array<top::Tensor>& inputs,
const Type& out_type,
const Target& target) {
const auto* param = attrs.as<OneHotAttrs>();
CHECK(param != nullptr);
- return Array<Tensor> {
+ return Array<top::Tensor> {
topi::one_hot(inputs[0],
inputs[1](),
inputs[2](),
#define RELAY_UNARY_COMPUTE(FTOPI) \
[] (const Attrs& attrs, \
- const Array<Tensor>& inputs, \
+ const Array<top::Tensor>& inputs, \
const Type& out_type, \
- const Target& target) -> Array<Tensor> { \
+ const Target& target) -> Array<top::Tensor> { \
return {FTOPI(inputs[0])}; \
} \
return true;
}
-Array<Tensor> ShapeOfCompute(const Attrs& attrs,
- const Array<Tensor>& inputs,
+Array<top::Tensor> ShapeOfCompute(const Attrs& attrs,
+ const Array<top::Tensor>& inputs,
const Type& out_type,
const Target& target) {
CHECK_EQ(inputs.size(), 1);
return true;
}
-Array<Tensor> NdarraySizeCompute(const Attrs& attrs,
- const Array<Tensor>& inputs,
+Array<top::Tensor> NdarraySizeCompute(const Attrs& attrs,
+ const Array<top::Tensor>& inputs,
const Type& out_type,
const Target& target) {
CHECK_EQ(inputs.size(), 1);
const auto* param = attrs.as<NdarraySizeAttrs>();
CHECK(param != nullptr);
- return Array<Tensor>{topi::ndarray_size(inputs[0], param->dtype)};
+ return Array<top::Tensor>{topi::ndarray_size(inputs[0], param->dtype)};
}
TVM_REGISTER_GLOBAL("relay.op.contrib._make.ndarray_size")
.set_attrs_type<YoloReorgAttrs>()
.add_type_rel("YoloReorg", YoloReorgRel)
.set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs,
- const Array<Tensor>& inputs,
+ const Array<top::Tensor>& inputs,
const Type& out_type,
const Target& target) {
const auto* params = attrs.as<YoloReorgAttrs>();
CHECK(params != nullptr);
- return Array<Tensor>{ topi::vision::reorg(inputs[0], params->stride) };
+ return Array<top::Tensor>{ topi::vision::reorg(inputs[0], params->stride) };
});
} // namespace relay
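Note: the relay compute callbacks above change only in the spelling of the tensor type. A minimal standalone sketch (not part of this diff; the op name IdentityCompute and the call to topi::identity are illustrative only) of an FTVMCompute body written against the relocated type:

Array<top::Tensor> IdentityCompute(const Attrs& attrs,
                                   const Array<top::Tensor>& inputs,
                                   const Type& out_type,
                                   const Target& target) {
  CHECK_EQ(inputs.size(), 1U);
  // The topi kernel is untouched; only Array<Tensor> becomes Array<top::Tensor>.
  return Array<top::Tensor>{ topi::identity(inputs[0]) };
}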
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/transform.h>
-#include <tvm/operation.h>
+#include <tvm/top/operation.h>
#include <tuple>
#include <vector>
#include <functional>
Expr new_e;
bool modified = false;
if (falter_layout.count(op)) {
- tvm::Array<tvm::Tensor> tinfos;
+ tvm::Array<tvm::top::Tensor> tinfos;
for (auto expr : ref_call->args) {
auto ttype = expr->type_as<TensorTypeNode>();
- tinfos.push_back(tvm::placeholder(ttype->shape, ttype->dtype));
+ tinfos.push_back(tvm::top::placeholder(ttype->shape, ttype->dtype));
}
Expr altered_value = falter_layout[op](ref_call->attrs, new_args, tinfos);
if (altered_value.defined()) {
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/transform.h>
-#include <tvm/operation.h>
+#include <tvm/top/operation.h>
#include <tuple>
#include <vector>
#include <functional>
Expr new_e;
bool modified = false;
if (fconvert_layout.count(op)) {
- tvm::Array<tvm::Tensor> tinfos;
+ tvm::Array<tvm::top::Tensor> tinfos;
for (auto expr : ref_call->args) {
auto ttype = expr->type_as<TensorTypeNode>();
- tinfos.push_back(tvm::placeholder(ttype->shape, ttype->dtype));
+ tinfos.push_back(tvm::top::placeholder(ttype->shape, ttype->dtype));
}
Expr altered_value =
fconvert_layout[op](ref_call->attrs, new_args, tinfos, operator->()->desired_layout_);
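Note: external passes that register layout hooks see the same change: the tinfos array handed to them now holds tvm::top::Tensor placeholders built by tvm::top::placeholder, as in the hunk above. A hypothetical hook sketch (the function name and the 4-D check are illustrative, not from this diff):

Expr MyAlterOpLayout(const Attrs& attrs,
                     const Array<Expr>& args,
                     const tvm::Array<tvm::top::Tensor>& tinfos) {
  // tinfos[i] carries the shape/dtype of the i-th argument.
  if (tinfos[0]->shape.size() != 4) {
    // Returning an undefined Expr keeps the original call
    // (the caller checks altered_value.defined()).
    return Expr();
  }
  // ... build and return the replacement call here ...
  return Expr();
}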
*/
#include <tvm/lowered_func.h>
-#include <tvm/operation.h>
+#include <tvm/top/operation.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/transform.h>
* shape, dtype or layout to another op or a sequence of ops.
*/
-#include <tvm/operation.h>
+#include <tvm/top/operation.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>
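Note: plain tensor-expression code follows the same migration as the include above: pull the operation header from tvm/top and qualify the DSL entry points with top::. A small sketch of updated user code (the names A, B, C and the function are illustrative; everything else mirrors renames shown in this diff):

#include <tvm/top/operation.h>  // was <tvm/operation.h>

void BuildElemwiseAdd() {
  using namespace tvm;
  Var n("n");
  // placeholder/compute are unchanged functions, now reached through top::
  top::Tensor A = top::placeholder({n}, DataType::Float(32), "A");
  top::Tensor B = top::placeholder({n}, DataType::Float(32), "B");
  top::Tensor C = top::compute({n}, [&](Var i) { return A(i) + B(i); }, "C");
  (void)C;
}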
+++ /dev/null
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file auto_inline_elem_wise.cc
- */
-#include <tvm/schedule_pass.h>
-#include <tvm/operation.h>
-#include <tvm/ir_functor_ext.h>
-
-namespace tvm {
-namespace schedule {
-
-using namespace ir;
-
-class ElemWiseDetector : public ir::ExprVisitor {
- public:
- explicit ElemWiseDetector(Array<IterVar> axis) : axis_(axis) {}
-
- void VisitExpr(const PrimExpr& e) final {
- if (!is_elem_wise_) return;
- ExprVisitor::VisitExpr(e);
- }
-
- void VisitExpr_(const CallNode* op) final {
- Array<PrimExpr> axis = op->args;
- if (axis_.size() != axis.size()) {
- is_elem_wise_ = false;
- return;
- }
-
- for (size_t i = 0; i < axis_.size(); ++i) {
- if (!axis[i].same_as(axis_[i]->var)) {
- is_elem_wise_ = false;
- return;
- }
- }
- ExprVisitor::VisitExpr_(op);
- }
-
- bool is_elem_wise_{true};
-
- private:
- Array<IterVar> axis_;
-};
-
-
-bool IsElemWise(const Operation& op) {
- if (const ComputeOpNode* compute = op.as<ComputeOpNode>()) {
- ElemWiseDetector v = ElemWiseDetector(compute->axis);
- for (auto& e : compute->body) v(e);
- return v.is_elem_wise_;
- }
- return false;
-}
-
-void AutoInlineElemWise(Schedule sch) {
- for (Stage s : sch->stages) {
- if (!s.is_scheduled() && IsElemWise(s->op) && !s->is_output) {
- s.compute_inline();
- }
- }
-}
-
-bool IsBroadcast(const Operation& op) {
- if (const ComputeOpNode* compute = op.as<ComputeOpNode>()) {
- if (compute->reduce_axis.size()) {
- return false;
- }
- // TODO(nicolasvasilache): Implement Me
- }
- return false;
-}
-
-void AutoInlineBroadcast(Schedule sch) {
- for (Stage s : sch->stages) {
- if (!s.is_scheduled() && IsBroadcast(s->op) && !s->is_output) {
- s.compute_inline();
- }
- }
-}
-
-bool IsInjective(const Operation& op) {
- if (const ComputeOpNode* compute = op.as<ComputeOpNode>()) {
- return compute->reduce_axis.size() == 0;
- }
- return false;
-}
-
-void AutoInlineInjective(Schedule sch) {
- for (Stage s : sch->stages) {
- if (!s.is_scheduled() && IsInjective(s->op) && !s->is_output) {
- s.compute_inline();
- }
- }
-}
-
-} // namespace schedule
-} // namespace tvm
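Note: the schedule-pass sources deleted from here on move along with the rest of the tensor-operation DSL code; their logic is unchanged. For orientation, a sketch of invoking the auto-inline pass; the include path and tvm::top qualification are assumptions mirroring the other relocated names, not something this hunk shows:

#include <tvm/top/schedule_pass.h>  // assumed post-move location

void InlineTrivialStages(tvm::top::Schedule sch) {
  // Inlines every un-scheduled, non-output stage whose body is injective
  // (no reduce axes) -- the condition IsInjective checks above.
  tvm::top::AutoInlineInjective(sch);
}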
+++ /dev/null
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file bound.cc
- * \brief The bound inference logic.
- */
-#include <tvm/schedule_pass.h>
-#include <tvm/operation.h>
-#include <tvm/ir_pass.h>
-#include <unordered_map>
-#include <unordered_set>
-#include "graph.h"
-#include "message_passing.h"
-#include "../runtime/thread_storage_scope.h"
-
-namespace tvm {
-namespace schedule {
-
-using runtime::StorageRank;
-using runtime::StorageScope;
-using runtime::ThreadScope;
-
-/*! \brief The graph context used during bound inference. */
-struct GraphContext {
- /*! \brief The feed graph */
- FeedGraph feed_graph;
- /*! \brief Attachment path */
- AttachPath attach_path;
- /*! \brief The bind map */
- std::unordered_map<IterVar, IterVar> bind_map;
- /*! \brief map from op to stage */
- std::unordered_map<const Object*, Stage> op2stage_;
-};
-
-bool NeedRelax(const IterVar& iv,
- bool found_attach,
- const std::unordered_map<IterVar, IterVar>& bind_map,
- const runtime::StorageScope& scope) {
- auto it = bind_map.find(iv);
- const std::string& tag = (
- it != bind_map.end() ? it->second->thread_tag : iv->thread_tag);
- if (tag.length() == 0 || tag == "pipeline") {
- return !found_attach;
- }
- ThreadScope ts = ThreadScope::make(tag);
-
- // When there is warp memory
- // threadIdx.x must be set to be warp index.
- if (scope.rank == StorageRank::kWarp &&
- ts.rank == 1 &&
- ts.dim_index == 0) {
- return true;
- }
- return static_cast<int>(scope.rank) <= ts.rank;
-}
-
-// infer storage scope, if not given
-StorageScope InferStorageScope(
- const Stage& stage, const GraphContext& ctx) {
- if (stage->scope.length() != 0) {
- return StorageScope::make(stage->scope);
- }
- int max_rank = -1;
- for (IterVar iv : ctx.attach_path.at(stage->op)) {
- auto it = ctx.bind_map.find(iv);
- const std::string& tag = (
- it != ctx.bind_map.end() ? it->second->thread_tag : iv->thread_tag);
- if (tag != "pipeline" && tag.length() != 0) {
- max_rank = std::max(max_rank, ThreadScope::make(tag).rank);
- }
- }
- StorageScope s;
- s.rank = runtime::DefaultStorageRank(max_rank);
- return s;
-}
-
-
-void InferRootBound(const Stage& stage,
- const GraphContext& ctx,
- std::unordered_map<IterVar, Range>* rmap) {
- CHECK_NE(stage->attach_type, kInline)
- << "call schedule.normalize before scheduleops";
- if (stage->attach_type == kInlinedAlready) return;
- if (stage->is_output) {
- // verify correctness.
- CHECK_EQ(stage.GetAttachSpec()->attach_type, kGroupRoot)
- << "Output must be attached at root";
- }
- if (stage->is_output || stage->op.as<PlaceholderOpNode>()) {
- for (auto iv : stage->op->root_iter_vars()) {
- CHECK(iv->dom.defined());
- CHECK(!rmap->count(iv));
- (*rmap)[iv] = iv->dom;
- }
- return;
- }
- // The tensor domain.
- std::unordered_map<Tensor, TensorDom> tmap;
- // The consumers of the op.
- std::unordered_set<Operation> consumers;
- for (int i = 0; i < stage->op->num_outputs(); ++i) {
- Tensor t = stage->op.output(i);
- tmap.emplace(t, TensorDom(static_cast<int>(t.ndim())));
- auto it = ctx.feed_graph.find(t);
- if (it != ctx.feed_graph.end()) {
- for (const Operation& op : it->second) {
- consumers.insert(op);
- }
- } else {
- LOG(INFO) << "not in feed graph consumer = " << stage->op;
- }
- }
- // storage scope.
- runtime::StorageScope scope = InferStorageScope(stage, ctx);
- // Bound prop by other consumers.
- // - Compute bound by relaxation rules: NeedRelax
-  //  - For normal index, use relative location of loop nest.
- // - For thread index, use the thread scope.
- //
- Array<IterVar> stage_attach = ctx.attach_path.at(stage->op);
- // The parent set.
- for (const Operation& op : consumers) {
- std::unordered_map<const VarNode*, IntSet> relax_set;
- std::unordered_map<IterVar, IntSet> up_state;
- bool found_attach = false;
- CHECK(ctx.op2stage_.count(op.get()));
- const Stage& op_stage = ctx.op2stage_.at(op.get());
- // Consumer nest
- for (size_t i = op_stage->leaf_iter_vars.size(); i != 0; --i) {
- IterVar iv = op_stage->leaf_iter_vars[i - 1];
- if (stage_attach.size() != 0 && iv == stage_attach[0]) {
- found_attach = true;
- }
- auto it = rmap->find(iv);
- CHECK(it != rmap->end());
- const Range& vrange = it->second;
- if (is_one(vrange->extent)) {
- up_state[iv] = IntSet::single_point(vrange->min);
- } else if (!NeedRelax(iv, found_attach, ctx.bind_map, scope)) {
- CHECK(is_zero(vrange->min))
- << "InferBound requires every leaf iter var's min equals 0, "
- << " call schedule.normalize to achieve this. ";
- if (ctx.bind_map.count(iv)) {
- up_state[iv] = IntSet::single_point(ctx.bind_map.at(iv)->var);
- } else {
- up_state[iv] = IntSet::single_point(iv->var);
- }
- } else {
- up_state[iv] = IntSet::range(vrange);
- }
- }
- // Consumer's attach nest
- for (IterVar iv : ctx.attach_path.at(op)) {
- if (stage_attach.size() != 0 && iv == stage_attach[0]) {
- found_attach = true;
- }
- Range vrange = rmap->at(iv);
- CHECK(is_zero(vrange->min))
- << "InferBound requires every leaf iter var's min equals 0, "
- << "call schedule.normalize to achieve this.";
- if (NeedRelax(iv, found_attach, ctx.bind_map, scope)) {
- relax_set[iv->var.get()] = IntSet::range(vrange);
- if (ctx.bind_map.count(iv)) {
- relax_set[ctx.bind_map.at(iv)->var.get()] = IntSet::range(vrange);
- }
- }
- }
- CHECK(found_attach || stage_attach.size() == 0)
- << "Invalid Schedule, cannot find the producer " << stage->op
- << " along the loop nest specified by compute_at of consumer " << op;
- // Get the domain of the consumer
- PassUpDomain(op_stage, *rmap, &up_state);
- // Relax if needed.
- std::unordered_map<const VarNode*, IntSet> dom_map;
- arith::Analyzer analyzer;
- for (auto iv : op->root_iter_vars()) {
- Range r;
- if (up_state.count(iv)) {
- r = up_state.at(iv).cover_range(iv->dom);
- } else {
- r = iv->dom;
- }
- if (relax_set.size() != 0) {
- dom_map[iv->var.get()] = EvalSet(r, relax_set);
- } else {
- dom_map[iv->var.get()] = IntSet::range(r);
- }
- analyzer.Bind(iv->var, r);
- }
- op->PropBoundToInputs(op, &analyzer, dom_map, &tmap);
- }
- stage->op->GatherBound(stage->op, tmap, rmap);
-}
-
-Map<IterVar, Range> InferBound(const Schedule& sch) {
- // Prepare context
- GraphContext ctx;
- Array<Operation> roots;
- arith::Analyzer analyzer;
-
- for (Operation op : sch->outputs) {
- roots.push_back(sch->stage_map[op]->op);
- }
- ctx.feed_graph = CreateFeedGraph(CreateReadGraph(roots));
-
- for (Stage stage : sch->stages) {
- for (auto kv : stage->iter_var_attrs) {
- if (kv.second->bind_thread.defined()) {
- CHECK(!ctx.bind_map.count(kv.first));
- ctx.bind_map[kv.first] = kv.second->bind_thread;
- }
- }
- ctx.op2stage_[stage->op.get()] = stage;
- }
- ctx.attach_path = CreateAttachPath(sch);
- // Run inference.
- std::unordered_map<IterVar, Range> ret;
- for (size_t i = sch->stages.size(); i != 0; --i) {
- const Stage& stage = sch->stages[i - 1];
- InferRootBound(stage, ctx, &ret);
-
- // bind bound of root iter vars.
- for (auto iv : stage->op->root_iter_vars()) {
- auto it = ret.find(iv);
- if (it != ret.end()) {
- analyzer.Bind(iv->var, it->second);
- }
- }
-
- // pass down to get bound of all iter vars.
- PassDownDomain(stage, &ret, &analyzer);
- for (IterVar iv : stage->env_threads) {
- CHECK(iv->dom.defined());
- ret[iv] = iv->dom;
- }
- }
- for (auto& p : ret) {
- ret[p.first] = Range::make_by_min_extent(
- analyzer.Simplify(p.second->min),
- analyzer.Simplify(p.second->extent));
- }
- return Map<IterVar, Range>(ret.begin(), ret.end());
-}
-
-} // namespace schedule
-} // namespace tvm
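Note: bound.cc is likewise only relocated; InferBound still returns the map from each IterVar to its inferred Range (see the signature above). A minimal usage sketch; as in the previous note, the include path and tvm::top qualification are assumed post-move spellings:

#include <tvm/top/schedule_pass.h>  // assumed post-move location

tvm::Map<tvm::IterVar, tvm::Range> InferBoundsFor(tvm::top::Schedule sch) {
  // InferBound expects every leaf iter var to start at 0; Schedule::normalize
  // establishes that (matching the CHECKs in the pass above).
  sch = sch.normalize();
  return tvm::top::InferBound(sch);
}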
+++ /dev/null
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file graph.cc
- * \brief Utilities to get information about schedule graph.
- */
-#include <tvm/ir.h>
-#include <tvm/ir_functor_ext.h>
-#include <tvm/operation.h>
-#include <utility>
-#include <unordered_set>
-#include <unordered_map>
-#include "graph.h"
-
-namespace tvm {
-namespace schedule {
-// key to specific tensor dimension.
-struct TensorDimKey {
- ir::FunctionRef f;
- int value_index;
- int dim;
- TensorDimKey() {}
- TensorDimKey(const ir::CallNode* op, int dim)
- : f(op->func), value_index(op->value_index), dim(dim) {
- }
- TensorDimKey(const Tensor& t, int dim)
- : f(t->op), value_index(t->value_index), dim(dim) {
- }
- TensorDimKey(const Tensor& t, size_t dim)
- : f(t->op), value_index(t->value_index), dim(static_cast<int>(dim)) {
- }
- inline bool operator==(const TensorDimKey& other) const {
- return f == other.f &&
- value_index == other.value_index &&
- dim == other.dim;
- }
- inline bool operator!=(const TensorDimKey& other) const {
- return !operator==(other);
- }
-};
-} // namespace schedule
-} // namespace tvm
-
-namespace std {
-template <>
-struct hash<::tvm::schedule::TensorDimKey> {
- std::size_t operator()(const ::tvm::schedule::TensorDimKey& k) const {
- size_t lhs = ::tvm::ObjectHash()(k.f);
- size_t rhs = static_cast<size_t>(k.value_index) << 16UL |
- static_cast<size_t>(k.dim);
- lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2);
- return lhs;
- }
-};
-} // namespace std
-
-
-namespace tvm {
-namespace schedule {
-
-// construct a read graph that gives readers of each operation
-// that the roots depend on
-ReadGraph CreateReadGraph(const Array<Operation>& roots) {
- ReadGraph rmap;
- std::vector<Operation> stack;
- std::unordered_set<const Object*> visited;
- // initialize the roots
- for (Operation op : roots) {
- stack.push_back(op);
- visited.insert(op.get());
- }
-
- while (!stack.empty()) {
- Operation op = stack.back();
- stack.pop_back();
- Array<Tensor> deps = op->InputTensors();
- rmap.Set(op, deps);
- for (Tensor t : deps) {
- if (t->op.defined() && visited.count(t->op.get()) == 0) {
- visited.insert(t->op.get());
- stack.push_back(t->op);
- }
- }
- }
- return rmap;
-}
-
-// Do DFS visit to get the subgraph.
-// Return if op is inside the subgraph.
-bool GetSubGraphByPostDFS_(
- const Operation& op,
- const std::unordered_set<const Object*>& boundary,
-    bool include_boundary,
- std::unordered_map<const Object*, bool>* visited,
- Array<Operation>* result) {
- if (visited->count(op.get())) {
- return visited->at(op.get());
- }
- if (boundary.count(op.get())) {
- (*visited)[op.get()] = true;
-    if (include_boundary) {
- result->push_back(op);
- }
- return true;
- }
- // mark to avoid loop
- // Not necessary for DAG.
- (*visited)[op.get()] = false;
- // check if we can reach boundary.
- bool reach_boundary = false;
- for (Tensor t : op->InputTensors()) {
-    if (GetSubGraphByPostDFS_(t->op, boundary,
-                              include_boundary,
-                              visited, result)) {
- reach_boundary = true;
- }
- }
- (*visited)[op.get()] = reach_boundary;
- if (reach_boundary) {
- result->push_back(op);
- }
- return reach_boundary;
-}
-
-Array<Operation> GetSubGraph(const Array<Tensor>& outputs,
- const Array<Tensor>& inputs,
- bool include_inputs) {
- Array<Operation> result;
- std::unordered_set<const Object*> boundary;
- for (Tensor t : inputs) {
- boundary.insert(t->op.get());
- }
- std::unordered_map<const Object*, bool> visited;
- for (Tensor t : outputs) {
- GetSubGraphByPostDFS_(t->op, boundary, include_inputs,
- &visited, &result);
- }
- return result;
-}
-
-
-void PostDFSOrder(const Operation& op,
- const ReadGraph& g,
- std::unordered_set<Operation>* visited,
- Array<Operation>* post_order) {
- if (visited->count(op)) return;
- visited->insert(op);
- for (const auto& t : g.at(op)) {
- PostDFSOrder(t->op, g, visited, post_order);
- }
- post_order->push_back(op);
-}
-
-Array<Operation> PostDFSOrder(
- const Array<Operation>& roots,
- const ReadGraph& g) {
- std::unordered_set<Operation> visited;
- Array<Operation> post_order;
- for (Operation op : roots) {
- PostDFSOrder(op, g, &visited, &post_order);
- }
- return post_order;
-}
-
-FeedGraph CreateFeedGraph(const ReadGraph& g) {
- FeedGraph fg;
- for (auto kv : g) {
- for (Tensor t : kv.second) {
- fg[t].push_back(kv.first);
- }
- }
- return fg;
-}
-
-AttachPath CreateAttachPath(Schedule sch) {
- AttachPath ret;
- for (Stage stage : sch->stages) {
- std::unordered_set<const Object*> visited;
- Array<IterVar> path;
- for (Stage s = stage; s.defined();) {
- CHECK(!visited.count(s.get()))
- << "Find loop in compute_at attach group";
- visited.insert(s.get());
- Stage spec = s.GetAttachSpec();
- bool start_attach;
- IterVar attach_ivar;
- if (spec->attach_type == kScope) {
- attach_ivar = spec->attach_ivar;
- s = spec->attach_stage;
- start_attach = false;
- CHECK(attach_ivar.defined());
- } else if (spec->attach_type == kScanUpdate) {
- s = spec->attach_stage;
- start_attach = true;
- } else {
- break;
- }
- CHECK(s.defined());
- for (size_t i = s->leaf_iter_vars.size(); i != 0; --i) {
- IterVar iv = s->leaf_iter_vars[i - 1];
- if (!start_attach && iv.same_as(attach_ivar)) {
- start_attach = true;
- }
- if (start_attach) path.push_back(iv);
- }
- CHECK(start_attach)
- << "Invalid Schedule: cannot find attach point " << attach_ivar
- << " in the schedule of " << s->op;
- }
- if (!ret.count(stage->op)) {
- ret.Set(stage->op, path);
- }
- }
- return ret;
-}
-
-// graph of push reach relation of tensor dimensions
-using ReachGraph = std::unordered_map<TensorDimKey, std::vector<TensorDimKey> >;
-
-ReachGraph GetReachGraph(const Array<Operation>& ops) {
- ReachGraph reach;
- std::unordered_set<const Object*> bset;
- for (size_t i = 0; i < ops.size(); ++i) {
- bset.insert(ops[i].get());
- }
-
- for (Operation op : ops) {
- if (const auto* scan_op = op.as<ScanOpNode>()) {
- const auto& update = scan_op->update;
- const auto& init = scan_op->init;
- for (size_t i = 0; i < update.size(); ++i) {
- Tensor t = op.output(i);
- for (int k = 1; k < static_cast<int>(update[i]->shape.size()); ++k) {
- reach[TensorDimKey(t, k)].emplace_back(
- TensorDimKey(update[i], k));
- reach[TensorDimKey(t, k)].emplace_back(
- TensorDimKey(init[i], k));
- }
- }
- } else if (const auto* compute_op = op.as<ComputeOpNode>()) {
- std::unordered_map<const Object*, TensorDimKey> vmap;
- const auto& axis = compute_op->axis;
- Tensor t = op.output(0);
- for (size_t i = 0; i < axis.size(); ++i) {
- vmap[axis[i]->var.get()] = TensorDimKey(t, i);
- reach[TensorDimKey(t, i)] = {};
- }
- auto fvisit = [&vmap, &reach, &bset](const ObjectRef& n) {
- const ir::CallNode *call = n.as<ir::CallNode>();
- if (call != nullptr && call->func.defined()) {
- if (!bset.count(call->func.get())) return;
- for (size_t i = 0; i < call->args.size(); ++i) {
- TensorDimKey dkey(call, static_cast<int>(i));
- auto fpush = [&dkey, &vmap, &reach](const ObjectRef& node) {
- const VarNode *v = node.as<VarNode>();
- auto it = vmap.find(v);
- if (it != vmap.end()) {
- reach[it->second].push_back(dkey);
- }
- };
- ir::PostOrderVisit(call->args[i], fpush);
- }
- }
- };
- for (auto& e : compute_op->body) {
- ir::PostOrderVisit(e, fvisit);
- }
- }
- }
- return reach;
-}
-
-Array<Operation> ScanGetBody(const Operation& scan_op) {
- const ScanOpNode* scan = scan_op.as<ScanOpNode>();
- // Get the body.
- Array<Tensor> inputs;
- for (Tensor t : scan->state_placeholder) {
- inputs.push_back(t);
- }
- for (Tensor t : scan->inputs) {
- inputs.push_back(t);
- }
- return GetSubGraph(scan->update, inputs, false);
-}
-
-Map<IterVar, PrimExpr> ScanFixPointAnalysis(const Operation& scan_op) {
- const ScanOpNode* scan = scan_op.as<ScanOpNode>();
- Array<Operation> body = ScanGetBody(scan_op);
-
- std::unordered_map<TensorDimKey, const Object*> exact_reach;
- std::unordered_set<const Object*> fail_set;
-
- for (size_t i = 0, sp_idx = 0; i < scan->update.size(); ++i) {
- for (size_t k = 1; k < scan->update[i]->shape.size(); ++k, ++sp_idx) {
- TensorDimKey key(scan->state_placeholder[i], k);
- exact_reach[key] = scan->spatial_axis_[sp_idx].get();
- }
- }
- // merge exact reach
- auto f_merge_key = [&exact_reach, &fail_set](
- const TensorDimKey& dst, const TensorDimKey& src) {
- auto sit = exact_reach.find(src);
- if (sit == exact_reach.end()) return;
- auto dit = exact_reach.find(dst);
- if (dit == exact_reach.end()) {
- exact_reach[dst] = sit->second;
- } else {
- if (dit->second != sit->second) {
- fail_set.insert(dit->second);
- fail_set.insert(sit->second);
- }
- }
- };
- // prop exact reach back.
- for (size_t i = 0; i < body.size(); ++i) {
- const Operation& op = body[i];
- if (const auto* scan_op = op.as<ScanOpNode>()) {
- const auto& update = scan_op->update;
- const auto& init = scan_op->init;
- for (size_t i = 0; i < update.size(); ++i) {
- Tensor t = op.output(i);
- for (size_t k = 1; k < update[i]->shape.size(); ++k) {
- f_merge_key(TensorDimKey(t, k), TensorDimKey(update[i], k));
- f_merge_key(TensorDimKey(t, k), TensorDimKey(init[i], k));
- }
- }
- } else if (const auto* compute_op = op.as<ComputeOpNode>()) {
- std::unordered_map<const Object*, std::vector<TensorDimKey> > vmap;
- const auto& axis = compute_op->axis;
- for (size_t i = 0; i < axis.size(); ++i) {
- std::vector<TensorDimKey> keys;
- for (int j = 0; j < op->num_outputs(); ++j) {
- keys.emplace_back(op.output(j), i);
- }
- vmap[axis[i]->var.get()] = std::move(keys);
- }
- auto fvisit = [&vmap, &f_merge_key, &exact_reach, &fail_set](
- const ObjectRef& n) {
- const ir::CallNode *call = n.as<ir::CallNode>();
- if (call != nullptr && call->func.defined()) {
- for (size_t i = 0; i < call->args.size(); ++i) {
- auto it = vmap.find(call->args[i].get());
- TensorDimKey src(call, static_cast<int>(i));
- if (it != vmap.end()) {
- const std::vector<TensorDimKey>& keys = it->second;
- for (const auto& key : keys) {
- f_merge_key(key, src);
- }
- } else {
- if (exact_reach.count(src)) {
- fail_set.insert(exact_reach.at(src));
- }
- }
- }
- }
- };
- for (auto& e : compute_op->body) {
- ir::PostOrderVisit(e, fvisit);
- }
- }
- }
- ReachGraph reach;
- Map<IterVar, PrimExpr> ret;
- std::unordered_set<TensorDimKey> place_holder_ref;
- for (size_t i = 0; i < scan->state_placeholder.size(); ++i) {
- for (size_t k = 0; k < scan->state_placeholder[i]->shape.size(); ++k) {
- place_holder_ref.insert(TensorDimKey(scan->state_placeholder[i], k));
- }
- }
-
- for (size_t i = 0, sp_idx = 0; i < scan->update.size(); ++i) {
- for (size_t k = 1; k < scan->update[i]->shape.size(); ++k, ++sp_idx) {
- TensorDimKey key(scan->update[i], k);
- TensorDimKey target(scan->state_placeholder[i], k);
- IterVar sp_iv = scan->spatial_axis_[sp_idx];
- if (fail_set.count(sp_iv.get()) ||
- !exact_reach.count(key) ||
- exact_reach.at(key) != sp_iv.get()) {
- ret.Set(sp_iv, make_const(DataType::Int(32), 0));
- } else {
- // now we proved exact match, need to prove no interference with other graph.
- if (reach.size() == 0) reach = GetReachGraph(body);
- // do a DFS
- std::unordered_set<TensorDimKey> visited;
- std::vector<TensorDimKey> stack{key};
- visited.insert(key);
- while (!stack.empty()) {
- TensorDimKey k = stack.back();
- if (k != target && place_holder_ref.count(k)) break;
- stack.pop_back();
- if (!reach.count(k)) {
- LOG(FATAL) << "cannot find reach of " << k.f << "-" << k.dim;
- }
-
- for (TensorDimKey kk : reach.at(k)) {
- if (visited.count(kk)) {
- continue;
- }
- visited.insert(kk);
- stack.push_back(kk);
- }
- }
- if (!stack.empty()) {
-          // failed the proof.
- ret.Set(sp_iv, make_const(DataType::Int(32), 0));
- } else {
- ret.Set(sp_iv, make_const(DataType::Int(32), 1));
- }
- }
- }
- }
- return ret;
-}
-
-} // namespace schedule
-} // namespace tvm
+++ /dev/null
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file graph.h
- * \brief Utilities to get information about schedule graph.
- */
-#ifndef TVM_SCHEDULE_GRAPH_H_
-#define TVM_SCHEDULE_GRAPH_H_
-
-#include <tvm/expr.h>
-#include <tvm/schedule.h>
-#include <tvm/operation.h>
-#include <unordered_map>
-#include <unordered_set>
-#include <vector>
-
-namespace tvm {
-namespace schedule {
-
-/*!
- * \brief data structure of Operation->Tensors it reads
- */
-using ReadGraph = Map<Operation, Array<Tensor> >;
-
-/*!
- * \brief AttachPath maps op-> a list of IterVar
- */
-using AttachPath = Map<Operation, Array<IterVar> >;
-
-/*!
- * \brief The map between tensor and operation it feeds to.
- */
-using FeedGraph = std::unordered_map<Tensor, std::vector<Operation> >;
-
-/*!
- * \brief Get read graph of each operation to all the
- * Tensors that it directly depends on.
- *
- * The result map contains Operations needed to finish root Operation.
- * \param roots The root operation.
- * \return The result map.
- */
-ReadGraph CreateReadGraph(const Array<Operation>& roots);
-
-/*!
- * \brief Get minimum subgraph between outputs and inputs.
- *  The result contains the operations that are reachable from the inputs
- *  and that reach the outputs.
- *
- * The inputs won't be included in the subgraph, the outputs will be included.
- *
- * \param outputs The outputs of the subgraph
- * \param inputs The inputs to the subgraph.
- * \param include_inputs Whether to include inputs
- *
- * \return The subgraph.
- */
-Array<Operation> GetSubGraph(const Array<Tensor>& outputs,
- const Array<Tensor>& inputs,
- bool include_inputs);
-
-/*!
- * \brief Get a post DFS ordered of operations in the graph.
- * \param roots The root of the graph.
- * \param g The read graph.
- * \return vector order of Operations in PostDFS order.
- *
- * \note PostDFSOrder is a special case of topological order,
- *       and can be used when a topological order is needed.
- */
-Array<Operation> PostDFSOrder(
- const Array<Operation>& roots, const ReadGraph& g);
-
-/*!
- * \brief Create feedgraph for given Schedule
- * \param g The read graph.
- * \return The created feedgraph.
- */
-FeedGraph CreateFeedGraph(const ReadGraph& g);
-
-/*!
- * \brief Create an AttachPath that maps op -> a list of IterVar
- *  representing the loop nest the op sits in, from innermost to outermost.
- *  Also inserts attach_stage for scan updates when needed.
- *
- * \param sch The schedule.
- * \return The attach path.
- */
-AttachPath CreateAttachPath(Schedule sch);
-
-/*!
- * \brief Get all operations inside the recursion of scan.
- * \param scan_op The scan node ops.
- * \return The body operations, in read dependency order.
- */
-Array<Operation> ScanGetBody(const Operation& scan_op);
-
-/*!
- * \brief Analyze each spatial dimension of scan's result.
- *  Check whether each dimension is a fixed point.
- *  An axis is a fixed point if the recursion only refers back to itself on
- *  that axis and the axis is not used in the axes of other recursion fields.
- *
- * next_state[t, ..., axis, ...] = f(prev_state[t-1, ..., axis, ...])
- *
- * \param scan The scan node.
- * \return Map of spatial_axis -> IntImm
- */
-Map<IterVar, PrimExpr> ScanFixPointAnalysis(const Operation& scan);
-
-} // namespace schedule
-} // namespace tvm
-
-#endif // TVM_SCHEDULE_GRAPH_H_
+++ /dev/null
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file message_passing.cc
- * \brief The message passing domain.
- */
-#include <tvm/arith/analyzer.h>
-#include <tvm/ir.h>
-#include <tvm/ir_pass.h>
-#include "message_passing.h"
-#include "../arith/compute_expr.h"
-
-namespace tvm {
-namespace schedule {
-
-using namespace ir;
-
-void Update(std::unordered_map<IterVar, Range>* p_state,
- const IterVar& iv,
- Range r,
- arith::Analyzer* analyzer) {
- auto it = p_state->find(iv);
- if (it == p_state->end()) {
- (*p_state)[iv] = r;
- analyzer->Bind(iv->var, r);
- } else {
- bool match = is_zero(it->second->min) &&
- analyzer->CanProve(r->extent - it->second->extent == 0);
- CHECK(match)
- << iv
- << " domain already inferred,"
- << " cannot prove their extents are the same "
- << it->second->extent << " vs " << r->extent;
- }
-}
-
-void PassDownDomain(const Stage& stage,
- std::unordered_map<IterVar, Range>* p_state,
- arith::Analyzer* actx,
- bool allow_missing) {
- auto ceil_div = [actx](PrimExpr a, PrimExpr b) {
- if (actx->CanProve(indexmod(a, b) == 0)) {
- return actx->Simplify(indexdiv(a, b));
- }
- return actx->Simplify(indexdiv(a + (b - 1), b));
- };
-
- auto& state = *p_state;
-  // forward iteration on relations
- for (IterVarRelation rel : stage->relations) {
- if (const SplitNode* r = rel.as<SplitNode>()) {
- if (!state.count(r->parent)) {
- CHECK(allow_missing);
- continue;
- }
- CHECK(!state.count(r->inner));
- const Range& range_parent = state.at(r->parent);
- if (r->factor.defined()) {
- Update(p_state, r->inner,
- Range::make_by_min_extent(0, r->factor), actx);
- Update(p_state, r->outer,
- Range::make_by_min_extent(
- 0, ceil_div(range_parent->extent, r->factor)), actx);
- } else {
- Update(p_state, r->outer, Range::make_by_min_extent(0, r->nparts), actx);
- Update(p_state, r->inner,
- Range::make_by_min_extent(
- 0, ceil_div(range_parent->extent, r->nparts)), actx);
- }
- } else if (const FuseNode* r = rel.as<FuseNode>()) {
- if (!state.count(r->outer) || !state.count(r->inner)) {
- CHECK(allow_missing);
- continue;
- }
- const Range& range_outer = state.at(r->outer);
- const Range& range_inner = state.at(r->inner);
- state[r->fused] = Range::make_by_min_extent(
- 0, range_outer->extent * range_inner->extent);
- } else if (const RebaseNode* r = rel.as<RebaseNode>()) {
- if (!state.count(r->parent)) {
- CHECK(allow_missing);
- continue;
- }
- Update(p_state, r->rebased,
- Range::make_by_min_extent(
- 0, state.at(r->parent)->extent), actx);
- } else if (const SingletonNode* s = rel.as<SingletonNode>()) {
- Update(p_state, s->iter, Range::make_by_min_extent(0, 1), actx);
- } else {
- LOG(FATAL) << "unknown relation type";
- }
- }
-  // update the extents of bound threads.
- for (auto kv : stage->iter_var_attrs) {
- if (kv.second->bind_thread.defined()) {
- CHECK(state.count(kv.first));
- Update(p_state, kv.second->bind_thread, state.at(kv.first), actx);
- }
- }
-}
-
-void PassUpIndex(const Stage& stage,
- const Map<IterVar, Range>& dom_map,
- std::unordered_map<IterVar, PrimExpr>* p_state,
- bool allow_missing) {
- auto& state = *p_state;
- for (size_t i = stage->relations.size(); i != 0; --i) {
- IterVarRelation rel = stage->relations[i - 1];
- if (const SplitNode* s = rel.as<SplitNode>()) {
- if (!state.count(s->outer) || !state.count(s->inner)) {
- CHECK(allow_missing);
- continue;
- }
- PrimExpr outer = state.at(s->outer);
- PrimExpr inner = state.at(s->inner);
- PrimExpr factor = dom_map.at(s->inner)->extent;
- PrimExpr parent_min = dom_map.at(s->parent)->min;
- state[s->parent] = inner + outer * factor;
- // add min if they exist
- if (!is_zero(parent_min)) {
- state[s->parent] = state[s->parent] + parent_min;
- }
- } else if (const FuseNode* s = rel.as<FuseNode>()) {
- if (!state.count(s->fused)) {
- CHECK(allow_missing);
- continue;
- }
- PrimExpr value = state.at(s->fused);
- PrimExpr factor = dom_map.at(s->inner)->extent;
- PrimExpr outer_min = dom_map.at(s->outer)->min;
- PrimExpr inner_min = dom_map.at(s->inner)->min;
- state[s->outer] = indexdiv(value, factor);
- state[s->inner] = indexmod(value, factor);
- // add min if they exist
- if (!is_zero(outer_min)) {
- state[s->outer] = state[s->outer] + outer_min;
- }
- if (!is_zero(inner_min)) {
- state[s->inner] = state[s->inner] + inner_min;
- }
- } else if (const RebaseNode* s = rel.as<RebaseNode>()) {
- if (!state.count(s->rebased)) {
- CHECK(allow_missing);
- continue;
- }
- PrimExpr value = state.at(s->rebased);
- PrimExpr parent_min = dom_map.at(s->parent)->min;
- // add min if they exist
- if (!is_zero(parent_min)) {
- state[s->parent] = value + parent_min;
- } else {
- state[s->parent] = value;
- }
- } else if (rel.as<SingletonNode>()) {
- } else {
- LOG(FATAL) << "unknown relation type";
- }
- }
-}
-
-void PassDownIndex(const Stage& stage,
- const Map<IterVar, Range>& dom_map,
- std::unordered_map<IterVar, PrimExpr>* p_state,
- bool allow_missing) {
- auto& state = *p_state;
- for (IterVarRelation rel : stage->relations) {
- if (const SplitNode* s = rel.as<SplitNode>()) {
- if (!state.count(s->parent)) {
- CHECK(allow_missing);
- continue;
- }
- Range r = dom_map.at(s->inner);
- CHECK(is_zero(r->min));
- PrimExpr parent = state.at(s->parent);
- PrimExpr factor = r->extent;
- state[s->outer] = indexdiv(parent, factor);
- state[s->inner] = indexmod(parent, factor);
- } else if (const FuseNode* s = rel.as<FuseNode>()) {
- if (!state.count(s->inner) && !state.count(s->outer)) {
- CHECK(allow_missing);
- continue;
- }
- PrimExpr factor = dom_map.at(s->inner)->extent;
- PrimExpr outer_min = dom_map.at(s->outer)->min;
- PrimExpr inner_min = dom_map.at(s->inner)->min;
- PrimExpr inner = state.at(s->inner);
- PrimExpr outer = state.at(s->outer);
- CHECK(is_zero(outer_min));
- CHECK(is_zero(inner_min));
- state[s->fused] = outer * factor + inner;
- } else if (const RebaseNode* s = rel.as<RebaseNode>()) {
- if (!state.count(s->rebased)) {
- CHECK(allow_missing);
- continue;
- }
- PrimExpr value = state.at(s->parent);
- PrimExpr parent_min = dom_map.at(s->parent)->min;
- CHECK(is_zero(parent_min));
- state[s->rebased] = value;
- } else if (const SingletonNode* s = rel.as<SingletonNode>()) {
- state[s->iter] = make_zero(s->iter->var.dtype());
- } else {
- LOG(FATAL) << "unknown relation type";
- }
- }
-}
-
-// Domain message passing.
-void PassUpDomain(const SplitNode* s,
- const std::unordered_map<IterVar, Range>& dom_map,
- const IntSet& outer,
- const IntSet& inner,
- IntSet* parent) {
- if (dom_map.count(s->outer) &&
- dom_map.count(s->inner) &&
- dom_map.count(s->parent) &&
- outer.match_range(dom_map.at(s->outer)) &&
- inner.match_range(dom_map.at(s->inner))) {
- *parent = IntSet::range(dom_map.at(s->parent));
- return;
- }
- PrimExpr factor = dom_map.at(s->inner)->extent;
- PrimExpr parent_min = dom_map.at(s->parent)->min;
- CHECK(outer.defined());
- CHECK(inner.defined());
- CHECK(factor.defined());
- *parent = arith::EvalSet(
- s->outer->var * factor + s->inner->var + parent_min,
- {{s->outer, outer}, {s->inner, inner}});
-}
-
-void PassUpDomain(const FuseNode* s,
- const std::unordered_map<IterVar, Range>& dom_map,
- const IntSet& fused,
- IntSet* outer,
- IntSet* inner) {
- CHECK(dom_map.count(s->outer));
- CHECK(dom_map.count(s->inner));
- CHECK(dom_map.count(s->fused));
-
- if (fused.match_range(dom_map.at(s->fused))) {
- *outer = IntSet::range(dom_map.at(s->outer));
- *inner = IntSet::range(dom_map.at(s->inner));
- return;
- }
- PrimExpr outer_min = dom_map.at(s->outer)->min;
- PrimExpr inner_min = dom_map.at(s->inner)->min;
-
- if (fused.is_single_point()) {
- PrimExpr value = fused.point_value();
- PrimExpr factor = dom_map.at(s->inner)->extent;
- PrimExpr v_outer = indexdiv(value, factor);
- PrimExpr v_inner = indexmod(value, factor);
- if (!is_zero(outer_min)) v_outer = v_outer + outer_min;
- if (!is_zero(inner_min)) v_inner = v_inner + inner_min;
- *outer = IntSet::single_point(v_outer);
- *inner = IntSet::single_point(v_inner);
- } else {
- PrimExpr fused_extent = (fused.max() - fused.min() + 1);
- PrimExpr inner_extent = dom_map.at(s->inner)->extent;
- *outer = IntSet::interval(
- outer_min + indexdiv(fused.min(), inner_extent),
- outer_min + indexdiv(fused.max(), inner_extent));
- if (is_zero(Simplify(indexmod(inner_extent, fused_extent))) &&
- is_zero(Simplify(indexmod(fused.min(), fused_extent)))) {
- // fused never spans multiple rows, make a tight bounding box
- // there may be other cases when bounding box could be tightened
- *inner = IntSet::interval(inner_min + indexmod(fused.min(), inner_extent),
- inner_min + indexmod(fused.max(), inner_extent));
- } else { // fused may span multiple rows, use full row widths
- if (!is_zero(Simplify(indexmod(fused_extent, inner_extent))) ||
- !is_zero(Simplify(indexmod(fused.min(), inner_extent)))) {
- LOG(WARNING) <<
- "fused and original axes are not aligned, this may cause redundant computations";
- }
- *inner = IntSet::range(dom_map.at(s->inner));
- }
- return;
- }
-}
-
-void PassUpDomain(const RebaseNode* s,
- const std::unordered_map<IterVar, Range>& dom_map,
- const IntSet& rebased,
- IntSet* parent) {
- CHECK(dom_map.count(s->parent));
- if (rebased.match_range(dom_map.at(s->rebased))) {
- *parent = IntSet::range(dom_map.at(s->parent));
- return;
- }
- PrimExpr parent_min = dom_map.at(s->parent)->min;
- *parent = arith::EvalSet(s->rebased->var + parent_min,
- {{s->rebased, rebased}});
-}
-
-void PassUpDomain(const Stage& stage,
- const std::unordered_map<IterVar, Range>& dom_map,
- std::unordered_map<IterVar, IntSet>* p_state) {
- auto& state = *p_state;
- for (size_t i = stage->relations.size(); i != 0; --i) {
- IterVarRelation rel = stage->relations[i - 1];
- if (const SplitNode* r = rel.as<SplitNode>()) {
- IntSet parent;
- PassUpDomain(r, dom_map,
- state.at(r->outer), state.at(r->inner),
- &parent);
- state[r->parent] = parent;
- } else if (const FuseNode* r = rel.as<FuseNode>()) {
- IntSet outer, inner;
- PassUpDomain(r, dom_map,
- state.at(r->fused),
- &outer, &inner);
- state[r->outer] = outer;
- state[r->inner] = inner;
- } else if (const RebaseNode* r = rel.as<RebaseNode>()) {
- IntSet parent;
- PassUpDomain(r, dom_map,
- state.at(r->rebased),
- &parent);
- state[r->parent] = parent;
- } else if (rel.as<SingletonNode>()) {
- } else {
- LOG(FATAL) << "unknown relation type";
- }
- }
-}
-
-// Pass up bit mask with or relation.
-void PassUpBitMaskOr(const Stage& stage,
- std::unordered_map<IterVar, int>* p_state,
- bool allow_missing) {
- auto& state = *p_state;
- for (size_t i = stage->relations.size(); i != 0; --i) {
- IterVarRelation rel = stage->relations[i - 1];
- if (const SplitNode* s = rel.as<SplitNode>()) {
- if (!state.count(s->inner) && !state.count(s->outer)) {
- CHECK(allow_missing);
- continue;
- }
- int res = 0;
- if (!state.count(s->parent)) res |= state[s->parent];
- if (!state.count(s->inner)) res |= state[s->inner];
- if (!state.count(s->outer)) res |= state[s->outer];
- state[s->parent] = res;
- } else if (const FuseNode* s = rel.as<FuseNode>()) {
- if (!state.count(s->fused)) {
- CHECK(allow_missing);
- continue;
- }
- if (!state.count(s->outer)) {
- state[s->outer] = state[s->fused];
- } else {
- state[s->outer] |= state[s->fused];
- }
- if (!state.count(s->inner)) {
- state[s->inner] = state[s->fused];
- } else {
- state[s->inner] |= state[s->fused];
- }
- } else if (const RebaseNode* s = rel.as<RebaseNode>()) {
- if (!state.count(s->rebased)) {
- CHECK(allow_missing);
- continue;
- }
- if (!state.count(s->parent)) {
- state[s->parent] = state[s->rebased];
- } else {
- state[s->parent] |= state[s->rebased];
- }
- } else if (rel.as<SingletonNode>()) {
- } else {
- LOG(FATAL) << "unknown relation type";
- }
- }
-}
-
-void PassDownBitMaskOr(const Stage& stage,
- std::unordered_map<IterVar, int>* p_state,
- bool allow_missing) {
- auto& state = *p_state;
- for (IterVarRelation rel : stage->relations) {
- if (const SplitNode* s = rel.as<SplitNode>()) {
- if (!state.count(s->parent)) {
- CHECK(allow_missing);
- continue;
- }
- if (!state.count(s->outer)) {
- state[s->outer] = state.at(s->parent);
- } else {
- state[s->outer] |= state.at(s->parent);
- }
- if (!state.count(s->inner)) {
- state[s->inner] = state.at(s->parent);
- } else {
- state[s->inner] |= state.at(s->parent);
- }
- } else if (const FuseNode* s = rel.as<FuseNode>()) {
- if (!state.count(s->outer) && !state.count(s->inner)) {
- CHECK(allow_missing);
- continue;
- }
- int res = 0;
- if (state.count(s->outer)) res |= state.at(s->outer);
- if (state.count(s->inner)) res |= state.at(s->inner);
- if (state.count(s->fused)) res |= state.at(s->fused);
- state[s->fused] = res;
- } else if (const RebaseNode* s = rel.as<RebaseNode>()) {
- if (!state.count(s->parent)) {
- CHECK(allow_missing);
- continue;
- }
- if (!state.count(s->rebased)) {
- state[s->rebased] = state.at(s->parent);
- } else {
- state[s->rebased] |= state.at(s->parent);
- }
- } else if (const SingletonNode* s = rel.as<SingletonNode>()) {
- state[s->iter] = 0;
- } else {
- LOG(FATAL) << "unknown relation type";
- }
- }
-}
-
-
-/*!
- * \brief Message passing to find if boundary checking on an IterVar is needed.
- * \param s The stage to be used.
- * \param dom_map The domain map of each IterVar.
- * \param p_state The message passing state
- *     IterVar->flag
- * \param analyzer The analyzer used to prove relations between the bounds.
- */
-void PassUpBoundCheck(const Stage& s,
- const Map<IterVar, Range>& dom_map,
- std::unordered_map<IterVar, bool>* p_state,
- arith::Analyzer* analyzer) {
- auto& state = *p_state;
- for (size_t i = s->relations.size(); i != 0; --i) {
- IterVarRelation rel = s->relations[i - 1];
- if (const SplitNode* s = rel.as<SplitNode>()) {
- bool outer = state.at(s->outer);
- bool inner = state.at(s->inner);
-
- if (dom_map.count(s->inner) && dom_map.count(s->outer)) {
- PrimExpr factor = dom_map.at(s->inner)->extent;
- PrimExpr step = dom_map.at(s->outer)->extent;
- if (outer || inner) {
- state[s->parent] = true;
- } else {
- if (analyzer->CanProve(dom_map.at(s->parent)->extent == factor * step)) {
- state[s->parent] = false;
- } else {
- state[s->parent] = true;
- }
- }
- } else {
- state[s->parent] = true;
- }
- } else if (const FuseNode* s = rel.as<FuseNode>()) {
- bool fused = state.at(s->fused);
- state[s->outer] = fused;
- state[s->inner] = fused;
- } else if (const RebaseNode* s = rel.as<RebaseNode>()) {
- state[s->parent] = state.at(s->rebased);
- } else if (rel.as<SingletonNode>()) {
- // nop
- } else {
- LOG(FATAL) << "unknown relation type";
- }
- }
-}
-
-std::vector<PrimExpr> MakeBoundCheck(
- const Stage& stage,
- const Map<IterVar, Range>& dom_map,
- const std::unordered_map<IterVar, PrimExpr>& value_map,
- bool skip_ivar_domain,
- const std::unordered_set<IterVar>& skip_iter) {
- arith::Analyzer analyzer;
-
- std::unordered_map<IterVar, bool> bound_state;
- for (IterVar iv : stage->leaf_iter_vars) {
- bound_state[iv] = false;
- }
- PassUpBoundCheck(stage, dom_map, &bound_state, &analyzer);
-
- std::vector<PrimExpr> preds;
- std::unordered_map<const VarNode*, IntSet> iset_dmap;
-
- // setup domain map for set analysis
- for (const auto& kv : dom_map) {
- iset_dmap[kv.first->var.get()] = IntSet::range(kv.second);
- }
-
- for (const IterVar& iv : stage->all_iter_vars) {
- if (skip_iter.count(iv) || iv->iter_type == kOpaque) continue;
- if (bound_state.at(iv)) {
- Range dom = dom_map.at(iv);
- PrimExpr value = value_map.at(iv) - dom->min;
- PrimExpr vmax = EvalSet(value, iset_dmap).max();
- if (vmax.dtype() != value.dtype() || !analyzer.CanProve(vmax < dom->extent)) {
- preds.emplace_back(value < dom->extent);
- }
- }
- }
- for (const IterVar& iv : stage->op->root_iter_vars()) {
- if (skip_iter.count(iv) || iv->iter_type == kOpaque) continue;
- Range dom = dom_map.at(iv);
- CHECK(iv->dom.defined());
- if (!skip_ivar_domain && !iv->dom.same_as(dom)) {
- PrimExpr value = value_map.at(iv) - iv->dom->min;
- IntSet s = EvalSet(value, iset_dmap);
- PrimExpr vmin = s.min();
- PrimExpr vmax = s.max();
- // The range of `value` resides in [vmin, vmax]
- if (vmin.dtype() != value.dtype() || !analyzer.CanProve(vmin >= 0)) {
- preds.emplace_back(value >= 0);
- }
- if (vmax.dtype() != value.dtype() || !analyzer.CanProve(vmax < iv->dom->extent)) {
- preds.emplace_back(value < iv->dom->extent);
- }
- }
- }
- return preds;
-}
-} // namespace schedule
-} // namespace tvm
+++ /dev/null
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file message_passing.h
- * \brief Common utilities to do message passing
- * on the schedule hyper graph.
- */
-#ifndef TVM_SCHEDULE_MESSAGE_PASSING_H_
-#define TVM_SCHEDULE_MESSAGE_PASSING_H_
-
-#include <tvm/expr.h>
-#include <tvm/schedule.h>
-#include <tvm/operation.h>
-#include <tvm/arith/analyzer.h>
-#include <unordered_map>
-#include <unordered_set>
-#include <vector>
-
-namespace tvm {
-namespace schedule {
-/*!
- * \brief Downward inference of domain of each IterVar.
- *  The caller sets the range of the root, then the function
- * propagates it towards the leaves.
- *
- * \param stage The stage to operate on.
- * \param p_state The state of the message passing.
- * \param analyzer Analyzer context, storing information about bounds in p_state.
- * \param allow_missing Whether to allow missing values.
- */
-void PassDownDomain(
- const Stage& stage,
- std::unordered_map<IterVar, Range>* p_state,
- arith::Analyzer* analyzer,
- bool allow_missing = false);
-
-/*!
- * \brief Upward inference of the index of each IterVar,
- *  given the index assignment of the leaves.
- *
- * \param stage The stage to operate on.
- * \param dom_map The domain map of each iteration variable's domain.
- * \param p_state The index state of each IterVar.
- * \param allow_missing Whether to allow missing values.
- */
-void PassUpIndex(const Stage& stage,
- const Map<IterVar, Range>& dom_map,
- std::unordered_map<IterVar, PrimExpr>* p_state,
- bool allow_missing = false);
-
-/*!
- * \brief Downward inference of the index of each IterVar,
- *  given the index assignment of the roots.
- *
- * \param stage The stage to operate on.
- * \param dom_map The domain map of each iteration variable's domain.
- * \param p_state The index state of each IterVar.
- * \param allow_missing Whether to allow missing values.
- */
-void PassDownIndex(const Stage& stage,
- const Map<IterVar, Range>& dom_map,
- std::unordered_map<IterVar, PrimExpr>* p_state,
- bool allow_missing = false);
-
-/*!
- * \brief Upward inference of the domain set of each IterVar,
- *  given the domain assignment of the leaves.
- *
- * \param stage The stage to operate on.
- * \param dom_map The domain map of each iteration variable's maximum domain.
- * \param p_state The domain state of each IterVar.
- */
-void PassUpDomain(const Stage& stage,
- const std::unordered_map<IterVar, Range>& dom_map,
- std::unordered_map<IterVar, IntSet>* p_state);
-
-/*!
- * \brief Upward message passing of bitmask with or relation.
- * \param stage The stage to operate on.
- * \param p_state The index state of each IterVar.
- * \param allow_missing Whether to allow missing values.
- */
-void PassUpBitMaskOr(const Stage& stage,
- std::unordered_map<IterVar, int>* p_state,
- bool allow_missing = false);
-
-/*!
- * \brief Downward message passing of bitmask with or relation.
- * \param stage The stage to operate on.
- * \param p_state The index state of each IterVar.
- * \param allow_missing Whether to allow missing values.
- */
-void PassDownBitMaskOr(const Stage& stage,
- std::unordered_map<IterVar, int>* p_state,
- bool allow_missing = false);
-
-/*!
- * \brief Create boundary check predicates given remapped value of root
- * \param stage The stage we operate on
- * \param dom_map The domain map of each value.
- * \param value_map The value map of the root iter var.
- * \param skip_ivar_domain Whether we skip check for IterVar's original domain.
- * \param skip_iter The set of variables to skip bound condition.
- * \return List of predicates that we need to check.
- */
-std::vector<PrimExpr>
-MakeBoundCheck(
- const Stage& stage,
- const Map<IterVar, Range>& dom_map,
- const std::unordered_map<IterVar, PrimExpr>& value_map,
- bool skip_ivar_domain,
- const std::unordered_set<IterVar>& skip_iter);
-
-} // namespace schedule
-} // namespace tvm
-#endif // TVM_SCHEDULE_MESSAGE_PASSING_H_
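To make the moved message-passing API easier to follow, here is a minimal usage sketch (illustrative only, not part of the moved sources; `stage`, `dom_map`, `value_map`, and `preds` are hypothetical names) that mirrors how `PrepareAxisMapping` in `schedule_dataflow_rewrite.cc` below chains these helpers: infer leaf ranges downward, push leaf index values back up to the roots, then emit the boundary predicates.

// Sketch only, assuming we are inside namespace tvm and `stage` is a scheduled Stage.
arith::Analyzer analyzer;
std::unordered_map<IterVar, Range> dom_map;
for (IterVar iv : stage->op->root_iter_vars()) {
  dom_map[iv] = iv->dom;                      // seed the roots with their declared ranges
  analyzer.Bind(iv->var, iv->dom);
}
schedule::PassDownDomain(stage, &dom_map, &analyzer, /*allow_missing=*/true);

std::unordered_map<IterVar, PrimExpr> value_map;
for (IterVar iv : stage->leaf_iter_vars) {
  value_map[iv] = iv->var;                    // index assignment of the leaves
}
schedule::PassUpIndex(stage, dom_map, &value_map, /*allow_missing=*/true);

std::vector<PrimExpr> preds = schedule::MakeBoundCheck(
    stage, dom_map, value_map, /*skip_ivar_domain=*/false, /*skip_iter=*/{});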
+++ /dev/null
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file schedule_dataflow_rewrite.cc
- */
-#include <tvm/schedule.h>
-#include <tvm/operation.h>
-#include <tvm/ir_functor_ext.h>
-#include <tvm/ir_pass.h>
-#include <unordered_set>
-#include "message_passing.h"
-#include "../pass/ir_util.h"
-#include "../arith/compute_expr.h"
-
-namespace tvm {
-
-// find first occurrence location in leaf
-template<typename T>
-size_t FindNodeRef(ArrayNode* array_node, const T& v) {
- const Object* n = v.get();
- for (size_t i = 0; i < array_node->data.size(); ++i) {
- if (array_node->data[i].get() == n) return i;
- }
- return array_node->data.size();
-}
-
-// The replacer of cache.
-class VarReplacer : public ir::StmtExprMutator {
- public:
- explicit VarReplacer(
- const std::unordered_map<const VarNode*, PrimExpr>& vsub)
- : vsub_(vsub) {}
- PrimExpr VisitExpr_(const VarNode* op) final {
- auto it = vsub_.find(op);
- if (it != vsub_.end()) return it->second;
- return GetRef<PrimExpr>(op);
- }
-
- ir::CommReducer MutateCommReducer(ir::CommReducer combiner) {
- // Replace free variables in combiner
- auto new_identity = ir::UpdateArray(combiner->identity_element, [this] (const PrimExpr& e) {
- return this->VisitExpr(e);
- });
- auto new_result = ir::UpdateArray(combiner->result, [this] (const PrimExpr& e) {
- return this->VisitExpr(e);
- });
-
- if (combiner->identity_element.same_as(new_identity) &&
- combiner->result.same_as(new_result)) {
- return combiner;
- } else {
- return ir::CommReducerNode::make(
- combiner->lhs, combiner->rhs, new_result, new_identity);
- }
- }
-
- PrimExpr VisitExpr_(const ir::ReduceNode* op) final {
- PrimExpr new_e = StmtExprMutator::VisitExpr_(op);
- const ir::ReduceNode* new_reduce = new_e.as<ir::ReduceNode>();
- ir::CommReducer new_combiner = MutateCommReducer(op->combiner);
- if (op->combiner.same_as(new_combiner)) {
- return new_e;
- } else {
- return ir::ReduceNode::make(
- new_combiner,
- new_reduce->source,
- new_reduce->axis,
- new_reduce->condition,
- new_reduce->value_index);
- }
- }
-
- private:
- const std::unordered_map<const VarNode*, PrimExpr>& vsub_;
-};
-
-PrimExpr InjectPredicate(const Array<PrimExpr>& predicates,
- PrimExpr body) {
- using ir::ReduceNode;
- using ir::SelectNode;
- if (predicates.size() == 0) return body;
- const ReduceNode* reduce = body.as<ReduceNode>();
- if (reduce) {
- auto n = make_object<ReduceNode>(*reduce);
- n->condition = n->condition && arith::ComputeReduce<ir::AndNode>(predicates, PrimExpr());
- return PrimExpr(n);
- }
- return SelectNode::make(arith::ComputeReduce<ir::AndNode>(predicates, PrimExpr()),
- body,
- make_zero(body.dtype()));
-}
-
-// Replace the dataflow that appears in all stages, given the tensor change.
-// Also update vmap if subsequent dataflow needs to be replaced.
-// The reverse map rvmap keeps the transitive-closure property of vmap up to date.
-void ReplaceDataFlow(const Array<Stage>& stages,
- std::unordered_map<Tensor, Tensor>* vmap,
- std::unordered_map<Tensor, Tensor>* rvmap) {
- for (Stage s : stages) {
- Operation op = s->op->ReplaceInputs(s->op, *vmap);
- if (!op.same_as(s->op)) {
- for (int i = 0; i < op->num_outputs(); ++i) {
- auto it = rvmap->find(s->op.output(i));
- if (it != rvmap->end()) {
- (*vmap)[it->second] = op.output(i);
- } else {
- (*vmap)[s->op.output(i)] = op.output(i);
- (*rvmap)[op.output(i)] = s->op.output(i);
- }
- }
- s->op = op;
- }
- }
-}
-
-inline bool ReduceEqual(const ir::ReduceNode* a, const ir::ReduceNode* b) {
- return (a->combiner.same_as(b->combiner)) &&
- (a->source.same_as(b->source)) &&
- (a->axis.same_as(b->axis)) &&
- (a->condition.same_as(b->condition));
-}
-
-Tensor Schedule::cache_read(const Tensor& tensor,
- const std::string& scope,
- const Array<Operation>& readers) {
- (*this)->InvalidateCache();
- // create identity mapping.
- std::ostringstream os;
- os << tensor->op->name;
- if (tensor->op->num_outputs() != 1) {
- os << ".v" << tensor->value_index;
- }
- os << "." << scope;
-
- std::unordered_map<Tensor, Tensor> vsub;
- Stage s = operator[](tensor->op);
- Tensor sugar_tensor = s->op.output(tensor->value_index);
- Tensor cache = compute(sugar_tensor->shape, [&sugar_tensor](const Array<Var>& i) {
- return sugar_tensor(Array<PrimExpr>(i.begin(), i.end()));
- }, os.str());
- vsub[sugar_tensor] = cache;
-
- std::unordered_map<Tensor, Tensor> vmap;
- std::unordered_map<Tensor, Tensor> rvmap;
- for (Operation op : readers) {
- Stage s = operator[](op);
- Operation repl_op = s->op->ReplaceInputs(s->op, vsub);
- CHECK(!repl_op.same_as(s->op))
- << "Cannot find " << tensor
- << " in the inputs of " << s->op;
- vmap[s->op.output(0)] = repl_op.output(0);
- rvmap[repl_op.output(0)] = s->op.output(0);
- s->op = repl_op;
- }
- ReplaceDataFlow((*this)->stages, &vmap, &rvmap);
- ArrayNode* stages = (*this)->stages.CopyOnWrite();
- Stage op_stage = operator[](tensor->op);
- size_t pos = FindNodeRef(stages, op_stage);
- Stage cache_stage = Stage(cache->op);
- cache_stage.set_scope(scope);
- CHECK_LT(pos, stages->data.size());
- stages->data.insert(stages->data.begin() + pos + 1,
- cache_stage);
- (*this)->stage_map.Set(cache->op, cache_stage);
- // Update group
- cache_stage->group = op_stage->group;
- if (cache_stage->group.defined()) {
- ++cache_stage->group->num_child_stages;
- }
- return cache;
-}
-
-template<typename OpType>
-void PrepareAxisMapping(Stage orig_stage,
- OpType* op,
- std::unordered_set<IterVar>* p_red_axis,
- Array<IterVar>* p_new_axis,
- std::unordered_map<IterVar, Range>* p_dom_map,
- std::unordered_map<const VarNode*, PrimExpr>* p_vsub,
- std::unordered_map<const VarNode*, PrimExpr>* p_vsub2newvar,
- std::vector<PrimExpr>* p_predicates) {
- auto& red_axis = *p_red_axis;
- auto& new_axis = *p_new_axis;
- auto& dom_map = *p_dom_map;
- auto& vsub = *p_vsub;
- auto& vsub2newvar = *p_vsub2newvar;
- auto& predicates = *p_predicates;
- arith::Analyzer analyzer;
-
- for (IterVar iv : op->reduce_axis) {
- red_axis.insert(iv);
- }
- for (IterVar iv : op->axis) {
- dom_map[iv] = iv->dom;
- analyzer.Bind(iv->var, iv->dom);
- }
- schedule::PassDownDomain(orig_stage, &dom_map, &analyzer, true);
- {
- // The source->cache
- std::unordered_map<IterVar, PrimExpr> value_map;
- for (IterVar iv : orig_stage->leaf_iter_vars) {
- if (red_axis.count(iv)) continue;
- CHECK_EQ(iv->iter_type, kDataPar)
- << "Can only relayout with in data parallel dimensions";
- Range dom = dom_map.at(iv);
- IterVar new_iv = IterVarNode::make(
- dom, iv->var.copy_with_suffix(".c"), iv->iter_type);
- new_axis.push_back(new_iv);
- if (is_one(dom->extent)) {
- value_map[iv] = dom->min;
- } else {
- value_map[iv] = iv->var;
- vsub2newvar[iv->var.get()] = new_iv->var;
- }
- }
- // skip reduction iteration.
- std::unordered_set<IterVar> skip_bound_check;
- for (IterVar iv : op->reduce_axis) {
- skip_bound_check.insert(iv);
- }
- schedule::PassUpIndex(orig_stage, dom_map, &value_map, true);
- predicates = schedule::MakeBoundCheck(
- orig_stage, dom_map, value_map, true, skip_bound_check);
- // The root axis
- for (IterVar iv : op->axis) {
- if (value_map.count(iv)) {
- vsub[iv->var.get()] = value_map.at(iv);
- } // to handle tensor axis
- }
- }
-}
-
-Array<Tensor> ReplaceOriginalOp(Schedule sch,
- Stage orig_stage,
- const std::string& scope,
- Operation cache_op,
- Operation orig_new_op,
- size_t tensor_size) {
- Array<Tensor> cache_tensor_list;
- for (size_t i = 0; i < tensor_size; i++) {
- Tensor cache_tensor = cache_op.output(i);
- cache_tensor_list.push_back(cache_tensor);
- }
- // The replace of the dataflow
- std::unordered_map<Tensor, Tensor> vmap;
- std::unordered_map<Tensor, Tensor> rvmap;
- vmap[orig_stage->op.output(0)] = orig_new_op.output(0);
- rvmap[orig_new_op.output(0)] = orig_stage->op.output(0);
- for (size_t i = 0; i < tensor_size; i++) {
- vmap[orig_stage->op.output(0)] = orig_new_op.output(0);
- rvmap[orig_new_op.output(0)] = orig_stage->op.output(0);
- }
- ReplaceDataFlow(sch->stages, &vmap, &rvmap);
- // mutate orig stage
- orig_stage->op = orig_new_op;
- orig_stage->all_iter_vars = orig_stage->op->root_iter_vars();
- orig_stage->leaf_iter_vars = orig_stage->all_iter_vars;
- orig_stage->relations = Array<IterVarRelation>();
- // create schedule for new cached stage.
- ArrayNode* stages = sch->stages.CopyOnWrite();
- size_t pos = FindNodeRef(stages, orig_stage);
- Stage cache_stage = Stage(cache_op);
- cache_stage.set_scope(scope);
- CHECK_LT(pos, stages->data.size());
- stages->data.insert(stages->data.begin() + pos,
- cache_stage);
- sch->stage_map.Set(cache_op, cache_stage);
- // Update group
- cache_stage->group = orig_stage->group;
- if (cache_stage->group.defined()) {
- ++cache_stage->group->num_child_stages;
- }
- return cache_tensor_list;
-}
-
-
-// Cache write and relayout the data according to loop pattern
-Array<Tensor> CacheWriteWithReLayout(Schedule sch,
- const Array<Tensor>& tensor_array,
- const std::string& scope) {
- size_t tensor_size = tensor_array.size();
- sch->InvalidateCache();
- Tensor tensor = tensor_array[0];
- Stage orig_stage = sch[tensor->op];
- const ComputeOpNode* compute = orig_stage->op.as<ComputeOpNode>();
-
- std::unordered_set<IterVar> red_axis;
- Array<IterVar> new_axis;
- std::unordered_map<IterVar, Range> dom_map;
-
- std::unordered_map<const VarNode*, PrimExpr> vsub;
- std::unordered_map<const VarNode*, PrimExpr> vsub2newvar;
- std::vector<PrimExpr> predicates;
-
- PrepareAxisMapping(orig_stage, compute,
- &red_axis, &new_axis, &dom_map, &vsub, &vsub2newvar, &predicates);
-
- PrimExpr body;
- Array<PrimExpr> body_list;
- const ir::ReduceNode* first_reduce = nullptr;
- for (auto cbody : compute->body) {
- body = VarReplacer(vsub)(cbody);
- body = InjectPredicate(predicates, body);
- body = VarReplacer(vsub2newvar)(body);
- // Reduce nodes in ONE ComputeOp must be the same except for value_index.
- // This holds only if the original body ensures the Reduce nodes are the same.
- if (body->IsInstance<ir::ReduceNode>()) {
- const ir::ReduceNode* reduce_body = body.as<ir::ReduceNode>();
- if (first_reduce != nullptr) {
- CHECK(ReduceEqual(reduce_body, first_reduce));
- body = ir::ReduceNode::make(first_reduce->combiner,
- first_reduce->source,
- first_reduce->axis,
- first_reduce->condition,
- reduce_body->value_index);
- } else {
- first_reduce = reduce_body;
- }
- } else {
- CHECK(first_reduce == nullptr)
- << "cannot mix reduce and other node in ONE compute bodys";
- }
- body_list.push_back(body);
- }
- // The reader args
- Array<PrimExpr> args;
- {
- // cache->compute
- std::unordered_map<IterVar, PrimExpr> value_map;
- for (IterVar iv : compute->axis) {
- value_map[iv] = iv->var;
- }
- schedule::PassDownIndex(orig_stage, dom_map, &value_map, true);
- for (IterVar iv : orig_stage->leaf_iter_vars) {
- if (red_axis.count(iv)) continue;
- args.push_back(value_map.at(iv));
- }
- }
- Operation cache_op = ComputeOpNode::make(
- compute->name + "." + scope, compute->tag, compute->attrs,
- new_axis, body_list);
-
- Array<PrimExpr> cache_expr_list;
- for (size_t i = 0; i < tensor_size; i++) {
- Tensor cache_tensor = cache_op.output(i);
- cache_expr_list.push_back(cache_tensor(args));
- }
- Operation orig_new_op = ComputeOpNode::make(
- compute->name, compute->tag, compute->attrs,
- compute->axis, cache_expr_list);
- return ReplaceOriginalOp(sch, orig_stage, scope,
- cache_op, orig_new_op, tensor_size);
-}
-
-
-// for tensor compute op
-Array<Tensor> CacheWriteWithReLayoutTensor(Schedule sch,
- const Array<Tensor>& tensor_array,
- const std::string& scope) {
- size_t tensor_size = tensor_array.size();
- sch->InvalidateCache();
- Tensor tensor = tensor_array[0];
- Stage orig_stage = sch[tensor->op];
- const TensorComputeOpNode* tensor_op = orig_stage->op.as<TensorComputeOpNode>();
- CHECK_EQ(tensor_op->num_outputs(), 1)
- << "cache write only support single output tensor_compute_op";
-
- std::unordered_set<IterVar> red_axis;
- Array<IterVar> new_axis;
- std::unordered_map<IterVar, Range> dom_map;
-
- std::unordered_map<const VarNode*, PrimExpr> vsub;
- std::unordered_map<const VarNode*, PrimExpr> vsub2newvar;
- std::vector<PrimExpr> predicates;
-
- PrepareAxisMapping(orig_stage, tensor_op,
- &red_axis, &new_axis, &dom_map, &vsub, &vsub2newvar, &predicates);
-
-
- for (int i = tensor_op->schedulable_ndim; i < static_cast<int>(tensor_op->axis.size()); ++i) {
- IterVar iv = tensor_op->axis[i];
- IterVar new_iv = IterVarNode::make(
- iv->dom, iv->var.copy_with_suffix(".c"), iv->iter_type);
- new_axis.push_back(new_iv);
- }
- Array<Region> new_regions;
- for (Region old_region : tensor_op->input_regions) {
- Region region;
- for (Range r : old_region) {
- PrimExpr min = VarReplacer(vsub2newvar)(r->min);
- PrimExpr extent = VarReplacer(vsub2newvar)(r->extent);
- region.push_back(Range::make_by_min_extent(min, extent));
- }
- new_regions.push_back(region);
- }
-
- Array<PrimExpr> new_scalar_inputs;
- for (PrimExpr old_input : tensor_op->scalar_inputs) {
- new_scalar_inputs.push_back(VarReplacer(vsub2newvar)(old_input));
- }
-
- Operation cache_op = TensorComputeOpNode::make(
- tensor_op->name + "." + scope, tensor_op->tag, new_axis,
- tensor_op->reduce_axis, tensor_op->schedulable_ndim,
- tensor_op->intrin, tensor_op->inputs, new_regions, new_scalar_inputs);
-
- // axis will be used in generating compute op
- Array<IterVar> compute_axis = tensor_op->axis;
- for (size_t i = tensor_op->schedulable_ndim; i < tensor_op->axis.size(); ++i) {
- IterVar iv = tensor_op->axis[i];
- IterVar aiv = IterVarNode::make(iv->dom, iv->var, kDataPar);
- compute_axis.Set(i, aiv);
- }
-
- // The reader args
- Array<PrimExpr> args;
- {
- // cache->compute
- std::unordered_map<IterVar, PrimExpr> value_map;
- for (IterVar iv : compute_axis) {
- value_map[iv] = iv->var;
- }
- schedule::PassDownIndex(orig_stage, dom_map, &value_map, true);
- for (IterVar iv : orig_stage->leaf_iter_vars) {
- if (red_axis.count(iv)) continue;
- args.push_back(value_map.at(iv));
- }
- // tensorized region axis
- for (size_t i = tensor_op->schedulable_ndim; i < tensor_op->axis.size(); ++i) {
- IterVar iv = compute_axis[i];
- args.push_back(value_map.at(iv));
- }
- }
-
- Array<PrimExpr> cache_expr_list;
- for (size_t i = 0; i < tensor_size; i++) {
- Tensor cache_tensor = cache_op.output(i);
- cache_expr_list.push_back(cache_tensor(args));
- }
- Operation orig_new_op = ComputeOpNode::make(
- tensor_op->name, tensor_op->tag, {},
- compute_axis, cache_expr_list);
- return ReplaceOriginalOp(sch, orig_stage, scope,
- cache_op, orig_new_op, tensor_size);
-}
-
-
-Array<Tensor> Schedule::cache_write(const Array<Tensor>& tensor_array,
- const std::string& scope) {
- (*this)->InvalidateCache();
- CHECK(tensor_array.size() > 0)
- << "size of tensor_array must be greater than 0";
- Tensor tensor = tensor_array[0];
- Stage orig_stage = operator[](tensor->op);
- const ComputeOpNode* compute = tensor->op.as<ComputeOpNode>();
- CHECK(static_cast<size_t>(compute->num_outputs()) == tensor_array.size())
- << "size of input tensor list must be same as number of stage outputs";
- for (size_t i = 1; i < tensor_array.size(); i++) {
- Stage tmp_stage = operator[](tensor_array[i]->op);
- CHECK(orig_stage.same_as(tmp_stage))
- << "Input tensor list must be generated by ONE computeOp";
- }
- return CacheWriteWithReLayout(*this, tensor_array, scope);
-}
-
-
-Tensor Schedule::cache_write(const Tensor& tensor,
- const std::string& scope) {
- // support original compute and tensor compute both
- (*this)->InvalidateCache();
- if (tensor->op.as<ComputeOpNode>()) {
- return (CacheWriteWithReLayout(*this, {tensor}, scope))[0];
- } else if (tensor->op.as<TensorComputeOpNode>()) {
- return (CacheWriteWithReLayoutTensor(*this, {tensor}, scope))[0];
- } else {
- LOG(FATAL) << "cache write only take ComputeOp or TensorComputeOp as writers";
- return Tensor();
- }
-}
-
-
-void RebaseNonZeroMinLoop(const Schedule& sch) {
- std::unordered_map<IterVar, IterVar> rebase_map;
- for (Stage s : sch->stages) {
- if (s->attach_type == kInlinedAlready) continue;
-
- auto root_iter_vars = s->op->root_iter_vars();
- ArrayNode* leaf_vars = s->leaf_iter_vars.CopyOnWrite();
- for (IterVar iv : root_iter_vars) {
- size_t idx = FindNodeRef(leaf_vars, iv);
- auto it = s->iter_var_attrs.find(iv);
- // don't need to rebase paths that are bound.
- if (it != s->iter_var_attrs.end() &&
- (*it).second->bind_thread.defined()) {
- continue;
- }
- if (idx < leaf_vars->data.size()) {
- // insert rebase
- IterVar rebased = IterVarNode::make(
- Range(), iv->var.copy_with_suffix(""), iv->iter_type);
- s->relations.push_back(RebaseNode::make(iv, rebased));
- if (s->iter_var_attrs.count(iv)) {
- s->iter_var_attrs.Set(rebased, s->iter_var_attrs.at(iv));
- }
- leaf_vars->data[idx] = rebased;
- rebase_map[iv] = rebased;
- }
- }
- }
- // remap the parent relation
- for (Stage s : sch->stages) {
- if (s->attach_type != kScope) continue;
- if (rebase_map.count(s->attach_ivar)) {
- s->attach_ivar = rebase_map.at(s->attach_ivar);
- }
- }
- for (Stage s : sch->groups) {
- if (s->attach_type != kScope) continue;
- if (rebase_map.count(s->attach_ivar)) {
- s->attach_ivar = rebase_map.at(s->attach_ivar);
- }
- }
-}
-
-void InjectInline(ScheduleNode* sch) {
- sch->InvalidateCache();
-
- std::vector<Array<PrimExpr> > new_body(sch->stages.size());
- std::vector<bool> changed(sch->stages.size(), false);
- std::vector<Stmt> new_hybrid_body(sch->stages.size());
- std::vector<bool> hybrid_changed(sch->stages.size(), false);
- // inline all the ops
- for (size_t i = sch->stages.size(); i != 0; --i) {
- Stage stage = sch->stages[i - 1];
- if (stage->attach_type == kInline) {
- stage->attach_type = kInlinedAlready;
- Array<Var> args;
- PrimExpr body;
- {
- // setup args
- const ComputeOpNode* compute = stage->op.as<ComputeOpNode>();
- CHECK(compute)
- << "can only inline compute op";
- for (auto iv : compute->axis) {
- args.push_back(iv->var);
- }
- CHECK_EQ(compute->body.size(), 1U)
- << "can only inline compute op with 1 output";
- body = compute->body[0];
- }
- for (size_t j = i; j < sch->stages.size(); ++j) {
- Stage s = sch->stages[j];
- const ComputeOpNode* compute = s->op.as<ComputeOpNode>();
- const HybridOpNode* hybrid = s->op.as<HybridOpNode>();
- if (compute) {
- if (!new_body[j].size()) {
- new_body[j] = compute->body;
- }
- if (new_body[j][0]->IsInstance<ir::ReduceNode>()) {
- // specially handle reduction inline for multiple reductions.
- const ir::ReduceNode* reduce = new_body[j][0].as<ir::ReduceNode>();
- for (size_t k = 1; k < new_body[j].size(); ++k) {
- const ir::ReduceNode* reduce_ = new_body[j][k].as<ir::ReduceNode>();
- CHECK(reduce_);
- CHECK(ReduceEqual(reduce_, reduce))
- << "The Reduce inputs of ComputeOp should "
- << "have the same attribute except value_index";
- }
- PrimExpr new_value = ir::Inline(ir::EvaluateNode::make(new_body[j][0]),
- stage->op, args, body).as<ir::EvaluateNode>()->value;
- if (!new_value.same_as(new_body[j][0])) {
- changed[j] = true;
- const ir::ReduceNode* r = new_value.as<ir::ReduceNode>();
- CHECK(r != nullptr);
- CHECK_EQ(new_body[j].size(), r->source.size());
- for (size_t k = 0; k < new_body[j].size(); ++k) {
- auto n = make_object<ir::ReduceNode>(*r);
- n->value_index = static_cast<int>(k);
- n->dtype = r->source[k].dtype();
- new_body[j].Set(k, PrimExpr(n));
- }
- }
- } else {
- for (size_t k = 0; k < new_body[j].size(); ++k) {
- PrimExpr new_value = ir::Inline(ir::EvaluateNode::make(new_body[j][k]),
- stage->op, args, body).as<ir::EvaluateNode>()->value;
- if (!new_value.same_as(new_body[j][k])) {
- new_body[j].Set(k, new_value);
- changed[j] = true;
- }
- }
- }
- } else if (hybrid) {
- if (!new_hybrid_body[j].defined()) {
- new_hybrid_body[j] = hybrid->body;
- }
- Stmt new_stmt = ir::Inline(new_hybrid_body[j], stage->op, args, body);
- if (!new_stmt.same_as(new_hybrid_body[j])) {
- new_hybrid_body[j] = new_stmt;
- hybrid_changed[j] = true;
- }
- }
- }
- }
- }
- std::unordered_map<Tensor, Tensor> repl;
- // rewrite dataflow
- for (size_t i = 0; i < sch->stages.size(); ++i) {
- Stage s = sch->stages[i];
- if (s->attach_type == kInlinedAlready) continue;
- if (new_body[i].size()) {
- // Logic from ReplaceDataFlow
- const ComputeOpNode* compute = sch->stages[i]->op.as<ComputeOpNode>();
- CHECK(compute);
- Operation op = s->op;
- if (changed[i]) {
- op = ComputeOpNode::make(
- compute->name, compute->tag, compute->attrs,
- compute->axis, new_body[i]);
- }
- op = op->ReplaceInputs(op, repl);
- if (!op.same_as(s->op)) {
- for (int idx = 0; idx < s->op->num_outputs(); ++idx) {
- repl[s->op.output(idx)] = op.output(idx);
- }
- s->op = op;
- }
- } else if (hybrid_changed[i]) {
- const HybridOpNode* hybrid = sch->stages[i]->op.as<HybridOpNode>();
- CHECK(hybrid);
- Operation op = HybridOpNode::make(
- hybrid->name, hybrid->tag, hybrid->attrs, hybrid->inputs,
- hybrid->outputs, new_hybrid_body[i]);
- op = op->ReplaceInputs(op, repl);
- for (int idx = 0; idx < s->op->num_outputs(); ++idx) {
- repl[s->op.output(idx)] = op.output(idx);
- }
- s->op = op;
- } else {
- Operation op = s->op->ReplaceInputs(s->op, repl);
- if (!op.same_as(s->op)) {
- for (int j = 0; j < op->num_outputs(); ++j) {
- repl[s->op.output(j)] = op.output(j);
- }
- s->op = op;
- }
- }
- }
-}
-
-Schedule Schedule::normalize() {
- Schedule sn = copy();
- InjectInline(sn.operator->());
- RebaseNonZeroMinLoop(sn);
- return sn;
-}
-
-// Handle reduction factor.
-Array<Tensor> Schedule::rfactor(const Tensor& tensor,
- const IterVar& axis,
- int factor_axis) {
- (*this)->InvalidateCache();
- using ir::ReduceNode;
- CHECK_EQ(axis->iter_type, kCommReduce)
- << "Can only factor reduction axis";
- Stage reduce_stage = operator[](tensor->op);
- const ComputeOpNode* compute_op = reduce_stage->op.as<ComputeOpNode>();
- CHECK(compute_op) << "Can only factor ComputeOp";
- ArrayNode* leaf_vars = reduce_stage->leaf_iter_vars.CopyOnWrite();
- {
- size_t axis_pos = FindNodeRef(leaf_vars, axis);
- CHECK_NE(axis_pos, leaf_vars->data.size())
- << "Cannot find IterVar " << axis << " in leaf iter vars";
- }
- // Find touched reduction axis.
- std::unordered_map<IterVar, int> touch_map;
- touch_map[axis] = 1;
- schedule::PassUpBitMaskOr(reduce_stage, &touch_map, true);
- schedule::PassDownBitMaskOr(reduce_stage, &touch_map, true);
- // skip reduction iteration.
- std::unordered_set<IterVar> skip_bound_check;
- // Verify that normal axes are not touched.
- for (IterVar iv : compute_op->axis) {
- CHECK(!touch_map.count(iv))
- << "Factor axis touches normal axis.";
- skip_bound_check.insert(iv);
- }
- // get analyzer.
- arith::Analyzer analyzer;
- // Get the replace index
- std::unordered_map<IterVar, Range> dom_map;
- std::unordered_map<IterVar, PrimExpr> value_map;
- for (IterVar iv : compute_op->reduce_axis) {
- if (touch_map.count(iv)) {
- dom_map[iv] = iv->dom;
- } else {
- skip_bound_check.insert(iv);
- }
- analyzer.Bind(iv->var, iv->dom);
- }
- schedule::PassDownDomain(reduce_stage, &dom_map, &analyzer, true);
- for (IterVar iv : reduce_stage->leaf_iter_vars) {
- if (touch_map.count(iv)) {
- Range dom = dom_map.at(iv);
- if (is_one(dom->extent)) {
- value_map[iv] = dom->min;
- } else {
- value_map[iv] = iv->var;
- }
- }
- }
- schedule::PassUpIndex(reduce_stage, dom_map, &value_map, true);
- std::vector<PrimExpr> predicates = schedule::MakeBoundCheck(
- reduce_stage, dom_map, value_map, true, skip_bound_check);
-
- // Get the factored op node.
- const int factor_axis_pos = \
- factor_axis >= 0 ? factor_axis : static_cast<int>(compute_op->axis.size() + 1) + factor_axis;
- CHECK_LE(factor_axis_pos, compute_op->axis.size());
- auto n = make_object<ComputeOpNode>();
- n->name = compute_op->name + ".rf";
- {
- // axis replacement.
- auto iv_node = make_object<IterVarNode>();
- iv_node->dom = dom_map.at(axis);
- CHECK(is_zero(iv_node->dom->min))
- << "Can only factor reduction domain starting from 0";
- iv_node->var = axis->var;
- iv_node->iter_type = kDataPar;
-
- const int size = compute_op->axis.size();
- for (int idx = 0; idx < size; ++idx) {
- if (factor_axis_pos == idx) {
- n->axis.push_back(IterVar(iv_node));
- }
- n->axis.push_back(compute_op->axis[idx]);
- }
- if (factor_axis_pos == size) {
- n->axis.push_back(IterVar(iv_node));
- }
- }
- // predicate generation, copy untouched axes.
- int idx = tensor->value_index;
- const ReduceNode* reduce = compute_op->body[idx].as<ReduceNode>();
- CHECK(reduce) << "Can only rfactor non-inline reductions";
- predicates.push_back(reduce->condition);
- PrimExpr predicate = likely(arith::ComputeReduce<ir::AndNode>(predicates, PrimExpr()));
-
- std::unordered_map<const VarNode*, PrimExpr> vsub;
-
- for (IterVar iv : compute_op->reduce_axis) {
- if (!touch_map.count(iv)) {
- n->reduce_axis.push_back(iv);
- } else {
- CHECK(value_map.count(iv));
- PrimExpr index = value_map.at(iv);
- vsub[iv->var.get()] = index;
- }
- }
-
- // Copy touched axes.
- for (IterVar iv : reduce_stage->leaf_iter_vars) {
- if (touch_map.count(iv) && !iv.same_as(axis)) {
- CHECK_EQ(iv->iter_type, kCommReduce);
- auto ncpy = make_object<IterVarNode>(*iv.operator->());
- ncpy->dom = dom_map.at(iv);
- n->reduce_axis.push_back(IterVar(ncpy));
- }
- }
- VarReplacer replacer(vsub);
- Array<PrimExpr> new_source = ir::UpdateArray(reduce->source,
- [&replacer] (const PrimExpr& e) { return replacer(e); });
-
- PrimExpr new_pred = replacer(predicate);
-
- std::vector<PrimExpr> body;
- for (size_t idx = 0; idx < reduce->source.size(); ++idx) {
- body.emplace_back(ReduceNode::make(reduce->combiner,
- new_source,
- n->reduce_axis,
- new_pred,
- idx));
- }
- n->body = Array<PrimExpr>(body);
- // refresh relations, keep the un-touched relations.
- Array<IterVarRelation> rels;
- for (IterVarRelation rel : reduce_stage->relations) {
- bool touched = false;
- if (const SplitNode* r = rel.as<SplitNode>()) {
- if (touch_map.count(r->parent)) touched = true;
- } else if (const FuseNode* r = rel.as<FuseNode>()) {
- if (touch_map.count(r->fused)) touched = true;
- } else if (const RebaseNode* r = rel.as<RebaseNode>()) {
- if (touch_map.count(r->parent)) touched = true;
- } else {
- LOG(FATAL) << "unknown relation type";
- }
- if (!touched) {
- rels.push_back(rel);
- }
- }
- // initialize the factored stage.
- Operation factor_op(n);
- ArrayNode* stages = (*this)->stages.CopyOnWrite();
- size_t stage_pos = FindNodeRef(stages, reduce_stage);
- Stage factor_stage = Stage(factor_op);
- factor_stage->relations = rels;
- CHECK_LT(stage_pos, stages->data.size());
- stages->data.insert(stages->data.begin() + stage_pos,
- factor_stage);
- (*this)->stage_map.Set(factor_op, factor_stage);
- factor_stage->group = reduce_stage->group;
- if (factor_stage->group.defined()) {
- ++factor_stage->group->num_child_stages;
- }
- // Replace the old reduction.
- IterVar repl_red_axis = reduce_axis(
- dom_map.at(axis), axis->var->name_hint + ".v");
- Array<Tensor> factor_tensors;
- Array<Tensor> old_tensors;
- int size = factor_op->num_outputs();
- for (int idx = 0; idx < size; ++idx) {
- factor_tensors.push_back(factor_op.output(idx));
- old_tensors.push_back(reduce_stage->op.output(idx));
- }
- Array<Tensor> repl_tensors = compute(old_tensors[0]->shape,
- [&](const Array<Var>& i) {
- Array<PrimExpr> indices;
- const int idx_size = static_cast<int>(i.size());
- for (int idx = 0; idx < idx_size; ++idx) {
- if (factor_axis_pos == idx) {
- indices.push_back(repl_red_axis->var);
- }
- indices.push_back(i[idx]);
- }
- if (factor_axis_pos == idx_size) {
- indices.push_back(repl_red_axis->var);
- }
- Array<PrimExpr> factor_exprs;
- for (int idx = 0; idx < size; ++idx) {
- factor_exprs.push_back(factor_tensors[idx](indices));
- }
- Array<PrimExpr> reductions;
- Array<IterVar> axis = {repl_red_axis};
- PrimExpr cond = const_true();
- for (int idx = 0; idx < size; ++idx) {
- reductions.push_back(ReduceNode::make(reduce->combiner,
- factor_exprs, axis, cond, idx));
- }
- return reductions;
- }, reduce_stage->op->name + ".repl");
-
- std::unordered_map<Tensor, Tensor> vmap;
- std::unordered_map<Tensor, Tensor> rvmap;
- for (int idx = 0; idx < size; ++idx) {
- vmap[old_tensors[idx]] = repl_tensors[idx];
- rvmap[repl_tensors[idx]] = old_tensors[idx];
- }
- ReplaceDataFlow((*this)->stages, &vmap, &rvmap);
- // revamp the reduction stage.
- reduce_stage->op = repl_tensors[0]->op;
- reduce_stage->all_iter_vars = repl_tensors[0]->op->root_iter_vars();
- reduce_stage->leaf_iter_vars = reduce_stage->all_iter_vars;
- reduce_stage->relations = Array<IterVarRelation>();
- return factor_tensors;
-}
-
-} // namespace tvm
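As a complement to the primitives above, the sketch below (illustrative only, not part of the moved sources; it assumes the usual DSL entry points `placeholder`, `compute`, `reduce_axis`, `sum`, and `create_schedule`) shows the user-facing flow that `Schedule::rfactor` implements: split a reduction axis and factor the reduction over the inner part, which exercises `PassUpBitMaskOr`/`PassDownBitMaskOr` and the dataflow replacement above.

// Sketch only: sum-reduce a length-1024 vector, split the reduction axis,
// then rfactor over the inner 32-wide piece (factor_axis = 0).
Tensor A = placeholder({1024}, DataType::Float(32), "A");
IterVar k = reduce_axis(Range(0, 1024), "k");
Tensor B = compute({1}, [&](const Array<Var>& i) {
  return sum(A(k), {k});
}, "B");

Schedule s = create_schedule({B->op});
IterVar ko, ki;
s[B->op].split(k, 32, &ko, &ki);         // k -> (ko, ki)
Array<Tensor> BF = s.rfactor(B, ki, 0);  // BF reduces over ko with ki lifted to a
                                         // spatial axis; B then reduces BF over ki's range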
+++ /dev/null
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file schedule_lang.cc
- */
-#include <tvm/schedule.h>
-#include <tvm/operation.h>
-#include <unordered_set>
-#include "graph.h"
-
-namespace tvm {
-
-namespace {
-
-// find first occurrence location in leaf
-template<typename T>
-size_t FindNodeRef(ArrayNode* array_node, const T& v) {
- const Object* n = v.get();
- for (size_t i = 0; i < array_node->data.size(); ++i) {
- if (array_node->data[i].get() == n) return i;
- }
- return array_node->data.size();
-}
-
-size_t FindLeafVar(ArrayNode* all_vars, ArrayNode* leaf_vars, const IterVar& v) {
- size_t pos = FindNodeRef(leaf_vars, v);
- if (pos < leaf_vars->data.size()) return pos;
-
- if (FindNodeRef(all_vars, v) < all_vars->data.size()) {
- LOG(FATAL) << "Operate on iter var " << v
- << "that has already been split";
- } else {
- LOG(FATAL) << "Operate on iter var " << v
- << "that is not part of the schedule";
- }
- return 0;
-}
-
-void Split(StageNode* self,
- IterVar parent,
- PrimExpr factor,
- PrimExpr nparts,
- IterVar* p_outer,
- IterVar* p_inner) {
- // Check if split is valid.
- CHECK(parent->iter_type == kDataPar ||
- parent->iter_type == kCommReduce ||
- parent->iter_type == kOrdered)
- << "Cannot split on " << IterVarType2String(parent->iter_type);
- IterVar outer = IterVarNode::make(
- Range(), parent->var.copy_with_suffix(".outer"), parent->iter_type);
- IterVar inner = IterVarNode::make(
- Range(), parent->var.copy_with_suffix(".inner"), parent->iter_type);
- *p_outer = outer;
- *p_inner = inner;
- // The splits
- ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
- ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
- size_t pos = FindLeafVar(all_vars, leaf_vars, parent);
- self->relations.push_back(SplitNode::make(parent, outer, inner, factor, nparts));
- // add vars to all vars
- all_vars->data.push_back(outer);
- all_vars->data.push_back(inner);
- // replace the position.
- leaf_vars->data.erase(leaf_vars->data.begin() + pos);
- leaf_vars->data.insert(leaf_vars->data.begin() + pos, inner);
- leaf_vars->data.insert(leaf_vars->data.begin() + pos, outer);
-}
-
-} // namespace
-
-Stage::Stage(Operation op) {
- auto n = make_object<StageNode>();
- n->op = op;
- n->origin_op = op;
- n->all_iter_vars = op->root_iter_vars();
- // remove opaque var from leaf.
- Array<IterVar> clean;
- for (IterVar iv : n->all_iter_vars) {
- if (iv->iter_type != kOpaque) clean.push_back(iv);
- }
- if (clean.size() == n->all_iter_vars.size()) {
- n->leaf_iter_vars = n->all_iter_vars;
- } else {
- n->leaf_iter_vars = clean;
- }
- data_ = std::move(n);
-}
-
-bool Stage::is_scheduled() const {
- const StageNode* n = operator->();
- return !(n->relations.empty() && n->attach_type == kGroupRoot &&
- n->all_iter_vars.same_as(n->leaf_iter_vars));
-}
-
-Stage Stage::GetAttachSpec() const {
- Stage attach_spec = *this;
- while (attach_spec->attach_type == kGroupRoot &&
- attach_spec->group.defined()) {
- attach_spec = attach_spec->group;
- }
- return attach_spec;
-}
-
-Stage& Stage::set_scope(std::string scope) { // NOLINT(*)
- (*this)->scope = scope;
- return *this;
-}
-
-Stage& Stage::compute_at(Stage parent, IterVar scope) { // NOLINT(*)
- CHECK_NE((*this)->attach_type, kScanUpdate)
- << "Cannot specify compute_at for scan updates";
- // Group constraint checking.
- Stage group = (*this)->group;
- if (group.defined()) {
- Stage pg = parent->group;
- while (pg.defined() && !pg.same_as(group)) {
- pg = pg->group;
- }
- CHECK(pg.same_as(group))
- << "Can only assign compute_at to stages within the same group";
- }
-
- (*this)->attach_type = kScope;
- (*this)->attach_ivar = scope;
- (*this)->attach_stage = parent;
- bool found = false;
- for (size_t i = 0; i < parent->leaf_iter_vars.size(); ++i) {
- if (scope == parent->leaf_iter_vars[i]) {
- found = true; break;
- }
- }
- CHECK(found)
- << "Cannot find the axis " << scope
- << " in parent's leaf_iter_vars"
- << " parent=" << parent;
- return *this;
-}
-
-Stage& Stage::compute_inline() { // NOLINT(*)
- CHECK_NE((*this)->attach_type, kScanUpdate)
- << "Cannot specify compute_at for scan updates";
- (*this)->attach_type = kInline;
- return *this;
-}
-
-Stage& Stage::compute_root() { // NOLINT(*)
- CHECK_NE((*this)->attach_type, kScanUpdate)
- << "Cannot specify compute_at for scan updates";
- (*this)->attach_type = kGroupRoot;
- return *this;
-}
-
-Stage& Stage::bind(IterVar ivar, IterVar thread_ivar) { // NOLINT(*)
- StageNode* self = operator->();
- CHECK(ivar->iter_type == kDataPar ||
- ivar->iter_type == kCommReduce)
- << "Cannot bind " << IterVarType2String(ivar->iter_type) << " to thread";
- CHECK(thread_ivar->iter_type == kThreadIndex)
- << "Cannot rebase by " << IterVarType2String(ivar->iter_type)
- << ", only thread axis is allowed so far";
- ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
- ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
- FindLeafVar(all_vars, leaf_vars, ivar);
-
- auto it = self->iter_var_attrs.find(ivar);
- ObjectPtr<IterVarAttrNode> n;
- if (it != self->iter_var_attrs.end()) {
- n = make_object<IterVarAttrNode>(*(*it).second.operator->());
- if (n->bind_thread.defined() &&
- !n->bind_thread.same_as(thread_ivar)) {
- LOG(WARNING) << "Axis " << ivar
- << " is already bind to another thread " << n->bind_thread;
- }
- } else {
- n = make_object<IterVarAttrNode>();
- }
- n->bind_thread = thread_ivar;
- self->iter_var_attrs.Set(ivar, IterVarAttr(n));
- return *this;
-}
-
-Stage& Stage::env_threads(Array<IterVar> threads) {
- StageNode* self = operator->();
- CHECK(self->op.defined() && self->op.as<ScanOpNode>())
- << "env_threads is only valid for composite ops such as ScanOp";
- CHECK_EQ(self->env_threads.size(), 0U)
- << "Already set env_threads";
- ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
- ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
- std::vector<ObjectRef> temp;
- for (IterVar iv : threads) {
- temp.push_back(iv);
- }
- leaf_vars->data.insert(
- leaf_vars->data.begin(), temp.begin(), temp.end());
- all_vars->data.insert(
- all_vars->data.end(), temp.begin(), temp.end());
- self->env_threads = threads;
- return *this;
-}
-
-Stage& Stage::set_store_predicate(PrimExpr predicate) {
- StageNode* self = operator->();
- self->store_predicate = predicate;
- return *this;
-}
-
-Stage& Stage::split(
- IterVar parent, PrimExpr factor, IterVar* p_outer, IterVar* p_inner) { // NOLINT(*)
- Split(operator->(), parent, factor, PrimExpr(), p_outer, p_inner);
- return *this;
-}
-
-Stage& Stage::split_by_nparts(
- IterVar parent, PrimExpr nparts, IterVar* p_outer, IterVar* p_inner) { // NOLINT(*)
- Split(operator->(), parent, PrimExpr(), nparts, p_outer, p_inner);
- return *this;
-}
-
-Stage& Stage::fuse(IterVar outer, IterVar inner, IterVar* p_target) { // NOLINT(*)
- StageNode* self = operator->();
- CHECK(outer->iter_type == kDataPar ||
- outer->iter_type == kCommReduce ||
- outer->iter_type == kOrdered)
- << "Cannot fuse " << IterVarType2String(outer->iter_type);
- CHECK(inner->iter_type == kDataPar ||
- inner->iter_type == kCommReduce ||
- inner->iter_type == kOrdered)
- << "Cannot fuse " << IterVarType2String(inner->iter_type);
-
- IterVarType iter_type = outer->iter_type;
- if (inner->iter_type > iter_type) iter_type = inner->iter_type;
- std::string fused_name =
- outer->var->name_hint + "." + inner->var->name_hint + ".fused";
-
- IterVar fused = IterVarNode::make(
- Range(), Var(fused_name, outer->var.dtype()), iter_type);
-
- ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
- ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
-
- size_t pos_inner = FindLeafVar(all_vars, leaf_vars, inner);
- size_t pos_outer = FindLeafVar(all_vars, leaf_vars, outer);
- if (pos_inner + 1 == pos_outer) {
- std::swap(outer, inner);
- std::swap(pos_inner, pos_outer);
- }
- self->relations.push_back(FuseNode::make(outer, inner, fused));
- all_vars->data.push_back(fused);
- CHECK_EQ(pos_inner, pos_outer + 1)
- << "Can only fuse iterations that are consecutive between each other";
- leaf_vars->data.erase(leaf_vars->data.begin() + pos_outer,
- leaf_vars->data.begin() + pos_inner + 1);
- leaf_vars->data.insert(leaf_vars->data.begin() + pos_outer,
- fused);
- *p_target = fused;
- return *this;
-}
-
-Stage& Stage::fuse(const Array<IterVar>& axes, IterVar* p_target) { // NOLINT(*)
- if (axes.size() != 0) {
- IterVar fused = axes[0];
- for (size_t i = 1; i < axes.size(); ++i) {
- this->fuse(fused, axes[i], &fused);
- }
- *p_target = std::move(fused);
- } else {
- StageNode* self = operator->();
- // specially handle fusing an empty array:
- // insert a singleton at the outermost loop.
- IterVar singleton = IterVarNode::make(
- Range::make_by_min_extent(0, 1),
- Var("singleton", DataType::Int(32)), kDataPar);
- self->relations.push_back(SingletonNode::make(singleton));
- ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
- ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
- all_vars->data.push_back(singleton);
- leaf_vars->data.insert(leaf_vars->data.begin(), singleton);
- *p_target = singleton;
- }
- return *this;
-}
-
-Stage& Stage::reorder(const Array<IterVar>& order) { // NOLINT(*)
- std::unordered_set<IterVar> seen_var;
- StageNode* self = operator->();
- for (IterVar iv : order) {
- CHECK(iv->iter_type == kDataPar ||
- iv->iter_type == kCommReduce ||
- iv->iter_type == kThreadIndex)
- << "Cannot reorder IterVar("
- << IterVarType2String(iv->iter_type) << ")";
-
- CHECK_EQ(seen_var.count(iv), 0)
- << "Same axis can not appear more than once " << iv;
- seen_var.insert(iv);
- }
- ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
- ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
- std::vector<size_t> pos;
-
- for (size_t i = 0; i < order.size(); ++i) {
- pos.push_back(FindLeafVar(all_vars, leaf_vars, order[i]));
- }
- std::vector<ObjectRef> temp;
- for (size_t i = 0; i < pos.size(); ++i) {
- temp.emplace_back(leaf_vars->data[pos[i]]);
- }
- std::sort(pos.begin(), pos.end());
- for (size_t i = 0; i < pos.size(); ++i) {
- leaf_vars->data[pos[i]] = temp[i];
- }
- return *this;
-}
-
-Stage& Stage::tile(IterVar x_parent, IterVar y_parent,
- PrimExpr x_factor, PrimExpr y_factor,
- IterVar* p_x_outer, IterVar* p_y_outer,
- IterVar* p_x_inner, IterVar* p_y_inner) {
- split(x_parent, x_factor, p_x_outer, p_x_inner);
- split(y_parent, y_factor, p_y_outer, p_y_inner);
- reorder(Array<IterVar>({*p_x_outer, *p_y_outer, *p_x_inner, *p_y_inner}));
- return *this;
-}
-
-template<typename FUpdate>
-inline void UpdateIterVarAttr(StageNode* self,
- IterVar var,
- FUpdate fupdate,
- bool need_leaf = true) {
- if (need_leaf) {
- ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
- ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
- FindLeafVar(all_vars, leaf_vars, var);
- }
- auto it = self->iter_var_attrs.find(var);
- ObjectPtr<IterVarAttrNode> n;
- if (it != self->iter_var_attrs.end()) {
- n = make_object<IterVarAttrNode>(*(*it).second.operator->());
- } else {
- n = make_object<IterVarAttrNode>();
- }
- fupdate(n.get());
- self->iter_var_attrs.Set(var, IterVarAttr(n));
-}
-
-inline void SetAttrIterType(StageNode* self, IterVar var, IterVarType iter_type) {
- UpdateIterVarAttr(self, var, [iter_type](IterVarAttrNode* n) {
- n->iter_type = iter_type;
- });
-}
-
-Stage& Stage::vectorize(IterVar var) { // NOLINT(*)
- CHECK(var->iter_type == kDataPar ||
- var->iter_type == kOpaque ||
- var->iter_type == kUnrolled ||
- var->iter_type == kVectorized ||
- var->iter_type == kTensorized ||
- var->iter_type == kParallelized)
- << "Cannot vectorize on " << IterVarType2String(var->iter_type);
- SetAttrIterType(operator->(), var, kVectorized);
- return *this;
-}
-
-Stage& Stage::tensorize(IterVar var, TensorIntrin f) { // NOLINT(*)
- UpdateIterVarAttr(operator->(), var, [f](IterVarAttrNode* n) {
- n->iter_type = kTensorized;
- n->tensor_intrin = f;
- });
- return *this;
-}
-
-Stage& Stage::unroll(IterVar var) { // NOLINT(*)
- SetAttrIterType(operator->(), var, kUnrolled);
- return *this;
-}
-
-Stage& Stage::parallel(IterVar var) { // NOLINT(*)
- SetAttrIterType(operator->(), var, kParallelized);
- return *this;
-}
-
-Stage& Stage::pragma(IterVar var,
- const std::string& pragma_type,
- const PrimExpr& pragma_value) { // NOLINT(*)
- if (pragma_type == "unroll") {
- this->unroll(var);
- } else if (pragma_type == "vectorize") {
- this->vectorize(var);
- } else {
- UpdateIterVarAttr(
- operator->(), var, [pragma_type, pragma_value](IterVarAttrNode* n) {
- n->pragma_keys.push_back(ir::StringImmNode::make(pragma_type));
- n->pragma_values.push_back(pragma_value);
- });
- }
- return *this;
-}
-
-Stage& Stage::prefetch(const Tensor &tensor, IterVar var, PrimExpr offset) {
- StageNode *self = operator->();
- ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
- ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
- FindLeafVar(all_vars, leaf_vars, var);
- auto it = self->iter_var_attrs.find(var);
- ObjectPtr<IterVarAttrNode> n;
- if (it != self->iter_var_attrs.end()) {
- n = make_object<IterVarAttrNode>(*(*it).second.operator->());
- } else {
- n = make_object<IterVarAttrNode>();
- }
- n->prefetch_data.push_back(tensor);
- n->prefetch_offset.push_back(offset);
- self->iter_var_attrs.Set(var, IterVarAttr(n));
- return *this;
-}
-
-Stage& Stage::storage_align(IterVar axis, int factor, int offset) {
- StageNode *self = operator->();
- UpdateIterVarAttr(self, axis, [factor, offset](IterVarAttrNode* n) {
- n->dim_align_factor = factor;
- n->dim_align_offset = offset;
- }, false);
- return *this;
-}
-
-Stage& Stage::double_buffer() {
- StageNode *self = operator->();
- CHECK(!self->is_output) << "Cannot apply double buffer on output";
- self->double_buffer = true;
- return *this;
-}
-
-Stage& Stage::opengl() {
- CHECK(!is_scheduled()) << "Must be a fresh schedule";
- StageNode *self = operator->();
-
- auto all_iter_vars = self->all_iter_vars; // curr version of all_iter_vars
- CHECK(!all_iter_vars.empty()) << "At least one iter var";
-
- // Fuse all data parallel dimensions to 1.
- IterVar fused = all_iter_vars[0];
- for (size_t i = 1; i != all_iter_vars.size(); ++i) {
- auto iter_var = all_iter_vars[i];
- switch (iter_var->iter_type) {
- case IterVarType::kDataPar: {
- fuse(fused, all_iter_vars[i], &fused);
- break;
- }
- case IterVarType::kThreadIndex: {
- LOG(ERROR) << "A fresh schedule shouldn't have thread index iter var";
- break;
- }
- case IterVarType::kCommReduce:
- case IterVarType::kOrdered:
- case IterVarType::kOpaque: {
- break;
- }
- default: {
- LOG(ERROR) << "Invalid iter var type "
- << IterVarType2String(iter_var->iter_type);
- break;
- }
- }
- }
-
- // Bind the only dimension to threadIdx.x.
- bind(fused, thread_axis(Range(nullptr), "threadIdx.x"));
-
- // Mark this stage as OpenGL.
- (*this)->is_opengl = true;
-
- return *this;
-}
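To illustrate how the Stage primitives above compose (a sketch only, not part of the moved sources; `X`, `C`, and the tile factors are illustrative and rely on the same assumed DSL entry points as the earlier sketches): `tile` is just two `split`s plus a `reorder`, and the resulting leaf itervars can then be bound to GPU threads.

// Sketch only: tile a 2-D elementwise copy and bind the tiles to GPU threads.
Tensor X = placeholder({1024, 1024}, DataType::Float(32), "X");
Tensor C = compute({1024, 1024}, [&](const Array<Var>& idx) {
  return X(idx[0], idx[1]);            // identity copy, enough to show the scheduling
}, "C");

Schedule s = create_schedule({C->op});
const ComputeOpNode* cop = C->op.as<ComputeOpNode>();
IterVar xo, yo, xi, yi;
s[C->op].tile(cop->axis[0], cop->axis[1], 32, 32, &xo, &yo, &xi, &yi);
s[C->op].bind(xo, thread_axis(Range(nullptr), "blockIdx.x"));
s[C->op].bind(xi, thread_axis(Range(nullptr), "threadIdx.x"));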
-
-Stage CopyStage(const Stage& s) {
- ObjectPtr<StageNode> n =
- make_object<StageNode>(*s.operator->());
- return Stage(n);
-}
-
-Schedule Schedule::copy() const {
- // map of stages.
- const ScheduleNode* self = operator->();
- std::unordered_map<Stage, Stage, ObjectHash, ObjectEqual> smap;
- ObjectPtr<ScheduleNode> n = make_object<ScheduleNode>();
- n->outputs = self->outputs;
- // Copy the stages.
- for (Stage s : self->stages) {
- Stage scopy = CopyStage(s);
- smap[s] = scopy;
- n->stages.push_back(scopy);
- }
- for (Stage g : self->groups) {
- Stage gcopy = CopyStage(g);
- smap[g] = gcopy;
- n->groups.push_back(gcopy);
- }
- // Remaps the reference relations.
- for (auto kv : self->stage_map) {
- n->stage_map.Set(kv.first, smap.at(kv.second));
- }
- for (Stage s : n->stages) {
- if (s->attach_stage.defined()) {
- CHECK(smap.find(s->attach_stage) != smap.end())
- << s->attach_stage << " not found in " << (*this);
- s->attach_stage = smap.at(s->attach_stage);
- }
- if (s->group.defined()) {
- CHECK(smap.find(s->group) != smap.end())
- << s->group << " not found in " << (*this);
- s->group = smap.at(s->group);
- }
- }
- for (Stage s : n->groups) {
- if (s->attach_stage.defined()) {
- CHECK(smap.find(s->attach_stage) != smap.end())
- << s->attach_stage << " not found in " << (*this);
- s->attach_stage = smap.at(s->attach_stage);
- }
- if (s->group.defined()) {
- CHECK(smap.find(s->group) != smap.end())
- << s->group << " not found in " << (*this);
- s->group = smap.at(s->group);
- }
- }
- return Schedule(n);
-}
-
-Stage Schedule::operator[](const Operation& op) {
- auto it = (*this)->stage_map.find(op);
- CHECK(it != (*this)->stage_map.end())
- << "Cannot find Stage for operator " << op
- << " in the schedule";
- return (*it).second;
-}
-
-Stage LeastCommonAncestor(Stage g1, Stage g2) {
- if (!g1.defined()) return g1;
- if (!g2.defined()) return g2;
- if (g1.same_as(g2)) return g1;
- Stage g = g1;
- while (g.defined()) {
- if (g.same_as(g2)) return g2;
- g = g->group;
- }
- g = g2;
- while (g.defined()) {
- if (g.same_as(g1)) return g1;
- g = g->group;
- }
- return g;
-}
-
-Array<Tensor> RemapTensor(ScheduleNode* self,
- const Array<Tensor>& arr) {
- self->InitCache();
- const auto& op2stage_cache = self->op2stage_cache_;
- Array<Tensor> ret;
- for (Tensor t : arr) {
- if (!op2stage_cache.count(t->op.get())) {
- CHECK(self->stage_map.count(t->op))
- << "Given tensor is not in the schedule plan";
- t = self->stage_map[t->op]->op.output(t->value_index);
- }
- ret.push_back(t);
- }
- return ret;
-}
-
-// Group the schedule stages.
-Stage Schedule::create_group(const Array<Tensor>& outputs,
- const Array<Tensor>& inputs,
- bool include_inputs) {
- ScheduleNode* self = operator->();
- self->InitCache();
- const auto& op2stage_cache = self->op2stage_cache_;
- // Get the ops.
- Array<Operation> ops = schedule::GetSubGraph(
- RemapTensor(self, outputs),
- RemapTensor(self, inputs),
- include_inputs);
- // local counter entry
- // Automatically initialize to 0 during creation.
- struct Entry {
- int count{0};
- };
- // Map of group->touched counter
- std::unordered_map<Stage, Entry, ObjectHash, ObjectEqual> counter;
- // The parent group;
- Stage parent_group;
- // Detect common parent and child.
- for (size_t i = 0; i < ops.size(); ++i) {
- Operation op = ops[i];
- auto it = op2stage_cache.find(op.get());
- CHECK(it != op2stage_cache.end());
- Stage op_group = it->second->group;
- if (i == 0) {
- parent_group = op_group;
- } else {
- parent_group = LeastCommonAncestor(parent_group, op_group);
- }
- if (op_group.defined()) {
- ++counter[op_group].count;
- }
- }
- // Create the new group stage.
- Stage gstage(make_object<StageNode>());
- gstage->group = parent_group;
- if (parent_group.defined()) {
- ++parent_group->num_child_stages;
- }
- // Propagate the counter statistics from by checking if subgroup
- // Is full and propagate.
- std::vector<Stage> stack;
- for (auto &kv : counter) {
- if (!kv.first.same_as(parent_group)) {
- if (kv.first->num_child_stages == kv.second.count) {
- stack.push_back(kv.first);
- }
- }
- }
- while (!stack.empty()) {
- Stage g = stack.back();
- stack.pop_back();
- if (g->group.defined() && !g->group.same_as(parent_group)) {
- Entry& e = counter[g->group];
- ++e.count;
- if (e.count == g->group->num_child_stages) {
- stack.push_back(g->group);
- }
- }
- }
- // Verification and remappig the subgroups.
- for (auto &kv : counter) {
- if (kv.first.same_as(parent_group)) continue;
- CHECK_EQ(kv.first->num_child_stages, kv.second.count)
- << "Trying to group region that intersect with an already existed group";
- if (kv.first->group.same_as(parent_group)) {
- Stage s = kv.first;
- s->group = gstage;
- ++gstage->num_child_stages;
- if (parent_group.defined()) {
- --parent_group->num_child_stages;
- }
- }
- }
- // Remap the group of op stages.
- for (Operation op : ops) {
- auto it = op2stage_cache.find(op.get());
- CHECK(it != op2stage_cache.end());
- Stage s = it->second;
- if (s->group.same_as(parent_group)) {
- s->group = gstage;
- ++gstage->num_child_stages;
- if (parent_group.defined()) {
- --parent_group->num_child_stages;
- }
- }
- }
- // Correct the attach to keep everything in group.
- for (Operation op : ops) {
- auto it = op2stage_cache.find(op.get());
- CHECK(it != op2stage_cache.end());
- Stage s = it->second;
- if (s->attach_type == kScope) {
- Stage cg = LeastCommonAncestor(s->attach_stage->group, gstage);
- if (!cg.same_as(gstage)) {
- LOG(WARNING) << "group invalidates some previous compute_at relation "
- << " and keeps things to be computed inside the group";
- s.compute_root();
- }
- }
- }
-
- self->groups.push_back(gstage);
- return gstage;
-}
-
-void ScheduleNode::InvalidateCache() {
- op2stage_cache_.clear();
-}
-
-void ScheduleNode::InitCache() {
- if (op2stage_cache_.size() == stages.size()) return;
- InvalidateCache();
- for (Stage s : stages) {
- if (s->op.defined()) {
- op2stage_cache_[s->op.get()] = s;
- }
- }
- CHECK_EQ(op2stage_cache_.size(), stages.size());
-}
-
-bool ScheduleNode::Contain(const Operation& op) const {
- return stage_map.find(op) != stage_map.end();
-}
-
-Schedule ScheduleNode::make(Array<Operation> ops) {
- auto n = make_object<ScheduleNode>();
- Schedule sch(n);
- n->outputs = ops;
- auto g = schedule::CreateReadGraph(n->outputs);
- Array<Operation> post_order = schedule::PostDFSOrder(n->outputs, g);
- // output set.
- std::unordered_set<Operation> output_set;
- for (Operation x : ops) {
- output_set.insert(x);
- }
- for (Operation op : post_order) {
- Stage stage(op);
- stage->is_output = output_set.count(op) != 0;
- n->stages.push_back(stage);
- n->stage_map.Set(op, stage);
- // mark scan updates.
- if (const ScanOpNode* scan = op.as<ScanOpNode>()) {
- Array<Tensor> inputs;
- for (Tensor t : scan->state_placeholder) {
- inputs.push_back(t);
- }
- for (Tensor t : scan->inputs) {
- inputs.push_back(t);
- }
- // Create the scan group.
- Stage scan_group = sch.create_group(scan->update, inputs, false);
- scan_group->attach_type = kScanUpdate;
- scan_group->attach_stage = stage;
-
- for (size_t i = 0; i < scan->update.size(); ++i) {
- Stage s = n->stage_map[scan->update[i]->op];
- CHECK(scan_group.same_as(s->group));
- }
- }
- }
- return sch;
-}
-
-IterVarRelation SplitNode::make(IterVar parent,
- IterVar outer,
- IterVar inner,
- PrimExpr factor,
- PrimExpr nparts) {
- auto n = make_object<SplitNode>();
- n->parent = parent;
- n->outer = outer;
- n->inner = inner;
- n->factor = factor;
- n->nparts = nparts;
- return IterVarRelation(n);
-}
-
-IterVarRelation FuseNode::make(
- IterVar outer, IterVar inner, IterVar fused) {
- auto n = make_object<FuseNode>();
- n->outer = outer;
- n->inner = inner;
- n->fused = fused;
- return IterVarRelation(n);
-}
-
-IterVarRelation RebaseNode::make(IterVar parent, IterVar rebased) {
- auto n = make_object<RebaseNode>();
- n->parent = parent;
- n->rebased = rebased;
- return IterVarRelation(n);
-}
-
-IterVarRelation SingletonNode::make(IterVar iter) {
- auto n = make_object<SingletonNode>();
- n->iter = iter;
- return IterVarRelation(n);
-}
-
-TVM_REGISTER_NODE_TYPE(StageNode);
-TVM_REGISTER_NODE_TYPE(IterVarAttrNode);
-TVM_REGISTER_NODE_TYPE(SplitNode);
-TVM_REGISTER_NODE_TYPE(FuseNode);
-TVM_REGISTER_NODE_TYPE(RebaseNode);
-TVM_REGISTER_NODE_TYPE(SingletonNode);
-TVM_REGISTER_NODE_TYPE(ScheduleNode);
-
-// Printer
-TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<StageNode>([](const ObjectRef& node, NodePrinter* p) {
- auto* op = static_cast<const StageNode*>(node.get());
- if (op->op.defined()) {
- p->stream << "stage(" << op->origin_op->name << ", " << op << ")";
- } else {
- p->stream << "group-stage(" << op << ")";
- }
-})
-.set_dispatch<IterVarAttrNode>([](const ObjectRef& node, NodePrinter* p) {
- auto* op = static_cast<const IterVarAttrNode*>(node.get());
- p->stream << IterVarType2String(op->iter_type);
-})
-.set_dispatch<SplitNode>([](const ObjectRef& node, NodePrinter* p) {
- auto* op = static_cast<const SplitNode*>(node.get());
- p->stream << "split(parent=";
- p->Print(op->parent);
- p->stream << ", outer=";
- p->Print(op->outer);
- p->stream << ", inner=";
- p->Print(op->inner);
- p->stream << ')';
-})
-.set_dispatch<FuseNode>([](const ObjectRef& node, NodePrinter* p) {
- auto* op = static_cast<const FuseNode*>(node.get());
- p->stream << "split(";
- p->stream << "outer=";
- p->Print(op->outer);
- p->stream << ", inner=";
- p->Print(op->inner);
- p->stream << ", fused=";
- p->Print(op->fused);
- p->stream << ')';
-})
-.set_dispatch<RebaseNode>([](const ObjectRef& node, NodePrinter* p) {
- auto* op = static_cast<const RebaseNode*>(node.get());
- p->stream << "rebase(";
- p->stream << "parent=";
- p->Print(op->parent);
- p->stream << ", rebased=";
- p->Print(op->rebased);
- p->stream << ')';
-})
-.set_dispatch<SingletonNode>([](const ObjectRef& node, NodePrinter* p) {
- auto* op = static_cast<const SingletonNode*>(node.get());
- p->stream << "singleton(";
- p->Print(op->iter);
- p->stream << ')';
-})
-.set_dispatch<ScheduleNode>([](const ObjectRef& node, NodePrinter* p) {
- auto* op = static_cast<const ScheduleNode*>(node.get());
- p->stream << "schedule(" << op << ")";
- });
-} // namespace tvm
+++ /dev/null
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file schedule_ops.cc
- */
-#include <tvm/ir.h>
-#include <tvm/ir_pass.h>
-#include <tvm/ir_functor_ext.h>
-#include <tvm/operation.h>
-#include <tvm/schedule_pass.h>
-#include <utility>
-#include <unordered_map>
-#include <unordered_set>
-#include "graph.h"
-#include "../op/op_util.h"
-#include "../pass/ir_util.h"
-
-namespace tvm {
-namespace schedule {
-
-using namespace ir;
-
-Stmt MakePipeline(const Stage& s,
- const std::unordered_map<IterVar, Range>& dom_map,
- Stmt consumer,
- bool debug_keep_trivial_loop) {
- Stmt producer = s->op->BuildProvide(s, dom_map, debug_keep_trivial_loop);
- if (producer.defined()) {
- producer = ProducerConsumerNode::make(s->op, true, producer);
- }
- if (s->double_buffer) {
- producer = AttrStmtNode::make(
- s->op, ir::attr::double_buffer_scope, 1, producer);
- }
- Stmt pipeline = producer;
-
- if (consumer.defined() && !is_no_op(consumer)) {
- consumer = ProducerConsumerNode::make(s->op, false, consumer);
- pipeline = SeqStmt({producer, consumer});
- }
- pipeline = s->op->BuildRealize(s, dom_map, pipeline);
- // use attribute to mark scope of the operation.
- pipeline = AttrStmtNode::make(
- s->op, ir::attr::realize_scope,
- StringImmNode::make(s->scope),
- pipeline);
-
- if (s->is_opengl) {
- pipeline = AttrStmtNode::make(
- s->op, ir::attr::opengl_stage_scope, StringImmNode::make(""), pipeline);
- }
- return pipeline;
-}
-
-// inject the operator's realization on the stmt.
-class InjectAttach : public StmtMutator {
- public:
- InjectAttach(const Stage& stage,
- const Stage& attach_spec,
- const std::unordered_map<IterVar, Range>& dom_map,
- bool debug_keep_trivial_loop)
- : stage_(stage), attach_spec_(attach_spec), dom_map_(dom_map),
- debug_keep_trivial_loop_(debug_keep_trivial_loop) {}
-
- Stmt VisitStmt(const Stmt& input_stmt) final {
- CHECK(input_stmt.defined());
- auto stmt = StmtMutator::VisitStmt(input_stmt);
- const AttrStmtNode* op = stmt.as<AttrStmtNode>();
- if (op != nullptr &&
- op->attr_key == attr::loop_scope) {
- if (attach_spec_->attach_type == kScope &&
- op->node == attach_spec_->attach_ivar) {
- CHECK(!found_attach)
- << "Find IterVar" << attach_spec_->attach_ivar
- << " in multiple places in the IR";
- found_attach = true;
- stmt = AttrStmtNode::make(
- op->node, op->attr_key, op->value,
- MakePipeline(stage_, dom_map_, op->body, debug_keep_trivial_loop_));
- }
- }
- return stmt;
- }
- // whether attach point is found
- bool found_attach{false};
-
- private:
- // The stage.
- const Stage& stage_;
- // The attach spec, may not contain op.
- const Stage& attach_spec_;
- // domain map
- const std::unordered_map<IterVar, Range>& dom_map_;
- // Whether keep trivial loops with extent of 1 during lowering.
- // This is a debug feature for dataflow/axis analysis
- bool debug_keep_trivial_loop_;
-};
-
-// inject the operator's realization on the stmt.
-class InjectScanStep : public StmtMutator {
- public:
- InjectScanStep(const Stage& stage,
- const Operation& scan_op,
- const std::unordered_map<IterVar, Range>& dom_map,
- bool is_init,
- bool debug_keep_trivial_loop)
- : stage_(stage), scan_op_(scan_op),
- dom_map_(dom_map), is_init_(is_init), debug_keep_trivial_loop_(debug_keep_trivial_loop) {}
-
- Stmt VisitStmt(const Stmt& input_stmt) final {
- CHECK(input_stmt.defined());
- auto stmt = StmtMutator::VisitStmt(input_stmt);
- // update
- const AttrStmtNode* op = stmt.as<AttrStmtNode>();
- if (op != nullptr &&
- ((op->attr_key == attr::scan_update_scope && !is_init_) ||
- (op->attr_key == attr::scan_init_scope && is_init_))) {
- if (op->node.same_as(scan_op_)) {
- found_attach = true;
- stmt = AttrStmtNode::make(
- op->node, op->attr_key, op->value,
- MakePipeline(stage_, dom_map_, op->body, debug_keep_trivial_loop_));
- }
- }
- return stmt;
- }
-
- // whether attach point is found
- bool found_attach{false};
-
- private:
- // the operations to be carried
- const Stage& stage_;
- const Operation& scan_op_;
- // domain map
- const std::unordered_map<IterVar, Range>& dom_map_;
- // whether it is init.
- bool is_init_;
- // Whether keep trivial loops with extent of 1 during lowering.
- // This is a debug feature for dataflow/axis analysis
- bool debug_keep_trivial_loop_;
-};
-
-// Postprocessing of schedule op
-// Replace the init and update's expression by scan's buffer.
-class SchedulePostProc : public StmtExprMutator {
- public:
- Stmt VisitStmt_(const ProducerConsumerNode* op) final {
- auto it = replace_op_.find(op->func.get());
- if (it != replace_op_.end()) {
- Stmt body = this->VisitStmt(op->body);
- if (it->second.defined()) {
- return ProducerConsumerNode::make(
- it->second, op->is_producer, body);
- } else {
- return body;
- }
- } else {
- return StmtExprMutator::VisitStmt_(op);
- }
- }
- Stmt VisitStmt_(const LetStmtNode* op) final {
- if (!HasSideEffect(op->value)) {
- var_value_[op->var.get()] = this->VisitExpr(op->value);
- return this->VisitStmt(op->body);
- } else {
- return StmtExprMutator::VisitStmt_(op);
- }
- }
-
- Stmt VisitStmt_(const AttrStmtNode* op) final {
- if (op->attr_key == attr::loop_scope ||
- op->attr_key == attr::scan_init_scope) {
- return this->VisitStmt(op->body);
- } else if (op->attr_key == attr::scan_update_scope) {
- const ScanOpNode* scan = op->node.as<ScanOpNode>();
- CHECK(scan);
- var_value_[scan->scan_axis->var.get()] = op->value;
- return this->VisitStmt(op->body);
- } else if (op->attr_key == attr::thread_extent) {
- // delete duplicated thread extent attr
- auto it = thread_extent_scope_.find(op->node.get());
- if (it != thread_extent_scope_.end()) {
- CHECK(is_zero(ir::Simplify(it->second - op->value)));
- return this->VisitStmt(op->body);
- } else {
- thread_extent_scope_[op->node.get()] = op->value;
- Stmt ret = StmtExprMutator::VisitStmt_(op);
- thread_extent_scope_.erase(op->node.get());
- return ret;
- }
- } else if (op->attr_key == ir::attr::realize_scope ||
- op->attr_key == ir::attr::double_buffer_scope) {
- auto it = replace_op_.find(op->node.get());
- if (it != replace_op_.end()) {
- if (it->second.defined()) {
- Stmt ret = AttrStmtNode::make(
- it->second, op->attr_key, op->value, op->body);
- return this->VisitStmt(ret);
- } else {
- return this->VisitStmt(op->body);
- }
- }
- } else if (op->attr_key == ir::attr::buffer_bind_scope) {
- Array<ObjectRef> tuple = Downcast<Array<ObjectRef> >(op->node);
- Tensor tensor = Downcast<Tensor>(tuple[1]);
- auto it = replace_op_.find(tensor->op.get());
- if (it != replace_op_.end()) {
- if (it->second.defined()) {
- return AttrStmtNode::make(
- Array<ObjectRef>{tuple[0], it->second.output(tensor->value_index)},
- op->attr_key, op->value, this->VisitStmt(op->body));
- } else {
- return this->VisitStmt(op->body);
- }
- }
- } else if (op->attr_key == ir::attr::buffer_dim_align) {
- Tensor tensor = Downcast<Tensor>(op->node);
- auto it = replace_op_.find(tensor->op.get());
- if (it != replace_op_.end()) {
- if (it->second.defined()) {
- return AttrStmtNode::make(
- it->second.output(tensor->value_index),
- op->attr_key, op->value, this->VisitStmt(op->body));
- } else {
- return this->VisitStmt(op->body);
- }
- }
- }
- return StmtExprMutator::VisitStmt_(op);
- }
-
- Stmt VisitStmt_(const RealizeNode* op) final {
- TensorKey key{op->func, op->value_index};
- auto it = replace_realize_.find(key);
- if (it != replace_realize_.end()) {
- if (it->second.defined()) {
- Stmt ret = RealizeNode::make(
- it->second->op, it->second->value_index,
- op->dtype, op->bounds, op->condition, op->body);
- return this->VisitStmt(ret);
- } else {
- return this->VisitStmt(op->body);
- }
- } else {
- return StmtExprMutator::VisitStmt_(op);
- }
- }
-
- Stmt VisitStmt_(const ProvideNode* op) final {
- TensorKey key{op->func, op->value_index};
- auto it = replace_buffer_.find(key);
- if (it != replace_buffer_.end()) {
- const Tensor& dst = it->second;
- Stmt ret = ProvideNode::make(
- dst->op, dst->value_index, op->value, op->args);
- return this->VisitStmt(ret);
- } else {
- return StmtExprMutator::VisitStmt_(op);
- }
- }
-
- PrimExpr VisitExpr_(const CallNode* op) final {
- if (op->call_type == CallNode::Halide) {
- TensorKey key{op->func, op->value_index};
- auto it = replace_buffer_.find(key);
- if (it != replace_buffer_.end()) {
- const Tensor& dst = it->second;
- PrimExpr ret = CallNode::make(
- op->dtype, dst->op->name, op->args,
- op->call_type, dst->op, dst->value_index);
- return this->VisitExpr(ret);
- }
- }
- return StmtExprMutator::VisitExpr_(op);
- }
-
- PrimExpr VisitExpr_(const VarNode* op) final {
- auto it = var_value_.find(op);
- if (it != var_value_.end()) {
- return it->second;
- } else {
- return GetRef<PrimExpr>(op);
- }
- }
-
- void Init(const Schedule& sch) {
- for (Stage s : sch->stages) {
- for (auto kv : s->iter_var_attrs) {
- // Update bind thread information.
- if (kv.second->bind_thread.defined()) {
- const Var& from = kv.first->var;
- const Var& to = kv.second->bind_thread->var;
- CHECK(!var_value_.count(from.get()));
- var_value_[from.get()] = to;
- }
- }
- // This must be checked for all ops, including scan.
- if (!s->op.same_as(s->origin_op)) {
- for (int i = 0; i < s->op->num_outputs(); ++i) {
- Tensor target = s->origin_op.output(i);
- AddReplace(s->op.output(i), target,
- target, s->origin_op);
- }
- }
- // Specially add replacements for scan op.
- if (const ScanOpNode* scan = s->op.as<ScanOpNode>()) {
- for (size_t i = 0; i < scan->update.size(); ++i) {
- Tensor t = s->origin_op.output(i);
- AddReplace(scan->init[i], t);
- AddReplace(scan->update[i], t);
- AddReplace(scan->state_placeholder[i], t);
- }
- }
- }
- }
-
- private:
- void AddReplace(Tensor src,
- Tensor dst,
- Tensor repl_realize = Tensor(),
- Operation repl_op = Operation()) {
- TensorKey key{src->op, src->value_index};
- replace_buffer_[key] = dst;
- replace_realize_[key] = repl_realize;
- replace_op_[src->op.get()] = repl_op;
- }
- // The thread extent scope.
- std::unordered_map<const Object*, PrimExpr> thread_extent_scope_;
- // The scan value
- std::unordered_map<const VarNode*, PrimExpr> var_value_;
- // buffer replacement
- std::unordered_map<TensorKey, Tensor> replace_buffer_;
- // buffere realization to be replaced
- std::unordered_map<TensorKey, Tensor> replace_realize_;
- // replace producer consumer.
- std::unordered_map<const Object*, Operation> replace_op_;
-};
-
-Stmt ScheduleOps(
- Schedule sch, Map<IterVar, Range> dom_map_, bool debug_keep_trivial_loop) {
- Stmt body = Stmt();
- std::unordered_map<IterVar, Range> dom_map = as_unordered_map(dom_map_);
- // scan init and scan updates
- std::unordered_map<Operation, Operation> scan_init;
- for (Stage s : sch->stages) {
- const ScanOpNode* scan = s->op.as<ScanOpNode>();
- if (!scan) continue;
- for (Tensor t : scan->init) {
- if (scan_init.count(t->op)) {
- CHECK(scan_init.at(t->op).same_as(s->op))
- << "Scan init tensor can only belong to one scan";
- } else {
- scan_init[t->op] = s->op;
- }
- }
- }
- // verify correctness of group.
- for (Stage g : sch->groups) {
- CHECK(!g->op.defined());
- CHECK_EQ(g->leaf_iter_vars.size(), 0U);
- }
- // reverse the post DFS order.
- for (size_t i = sch->stages.size(); i != 0; --i) {
- Stage s = sch->stages[i - 1];
- CHECK_NE(s->attach_type, kInline)
- << "call schedule.normalize before scheduleops";
- CHECK(s->op.defined());
- // no need to specify place holder op.
- if (s->op.as<PlaceholderOpNode>()) continue;
- // Remove grouping sugar, get the real attach spec.
- Stage attach_spec = s.GetAttachSpec();
-
- if (scan_init.count(s->op)) {
- CHECK(body.defined());
- InjectScanStep mu(s, scan_init.at(s->op), dom_map, true, debug_keep_trivial_loop);
- body = mu(std::move(body));
- CHECK(mu.found_attach)
- << "did not find attachment point for scan.init";
- } else if (attach_spec->attach_type == kScanUpdate) {
- // Handle scan update
- CHECK(body.defined());
- InjectScanStep mu(s, attach_spec->attach_stage->op, dom_map, false, debug_keep_trivial_loop);
- body = mu(std::move(body));
- CHECK(mu.found_attach)
- << "did not find attachment point for scan.update";
- } else if (attach_spec->attach_type == kInlinedAlready) {
- // do nothing
- } else if (attach_spec->attach_type == kGroupRoot) {
- CHECK(!s->group.defined());
- body = MakePipeline(s, dom_map, body, debug_keep_trivial_loop);
- } else {
- CHECK_EQ(attach_spec->attach_type, kScope);
- CHECK(body.defined());
- InjectAttach mutator(s, attach_spec, dom_map, debug_keep_trivial_loop);
- body = mutator(std::move(body));
- CHECK(mutator.found_attach)
- << "did not find attachment point for " << s << " in "
- << attach_spec->attach_stage->op << " x " << attach_spec->attach_ivar
- << ", body:\n"
- << body;
- }
- }
- SchedulePostProc post_proc;
- post_proc.Init(sch);
- return post_proc(std::move(body));
-}
-
-} // namespace schedule
-} // namespace tvm
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \brief Compute Op.
+ * \file compute_op.cc
+ */
+#include <tvm/top/operation.h>
+#include <tvm/arith/analyzer.h>
+#include <tvm/ir.h>
+#include <tvm/ir_pass.h>
+#include <tvm/ir_functor_ext.h>
+#include <unordered_set>
+#include <string>
+#include <utility>
+#include "compute_op.h"
+#include "op_util.h"
+#include "../schedule/message_passing.h"
+#include "../../arith/compute_expr.h"
+#include "../../arith/interval_set.h"
+
+namespace tvm {
+namespace top {
+using namespace ir;
+
+TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
+.set_dispatch<ComputeOpNode>([](const ObjectRef& node, NodePrinter* p) {
+ auto* op = static_cast<const ComputeOpNode*>(node.get());
+ p->stream << "compute(" << op->name << ", " << op << ")";
+});
+
+TVM_REGISTER_NODE_TYPE(ComputeOpNode);
+
+/// Verify if ComputeOp is valid with respect to Reduce operations.
+static void VerifyComputeOp(const ComputeOpNode *op);
+
+inline bool ReduceEqual(const ir::ReduceNode* a, const ir::ReduceNode* b) {
+ return (a->combiner.same_as(b->combiner)) &&
+ (a->source.same_as(b->source)) &&
+ (a->axis.same_as(b->axis)) &&
+ (a->condition.same_as(b->condition));
+}
+
+int ComputeOpNode::num_outputs() const {
+ return body.size();
+}
+
+Array<IterVar> BaseComputeOpNode::root_iter_vars() const {
+ if (reduce_axis.size() == 0) return axis;
+ Array<IterVar> ret = axis;
+ for (IterVar iv : reduce_axis) {
+ ret.push_back(iv);
+ }
+ return ret;
+}
+
+DataType ComputeOpNode::output_dtype(size_t idx) const {
+ CHECK_LT(idx, num_outputs());
+ return body[idx].dtype();
+}
+
+Array<PrimExpr> BaseComputeOpNode::output_shape(size_t idx) const {
+ CHECK_LT(idx, num_outputs());
+ // for now, all outputs of a BaseComputeOp have the same shape
+ Array<PrimExpr> shape;
+ for (const auto& ivar : this->axis) {
+ const Range& r = ivar->dom;
+ shape.push_back(r->extent);
+ }
+ return shape;
+}
+
+Tensor compute(Array<PrimExpr> shape,
+ FCompute fcompute,
+ std::string name,
+ std::string tag,
+ Map<std::string, ObjectRef> attrs) {
+ auto op_node = make_object<ComputeOpNode>();
+ // compute dimension.
+ size_t ndim = shape.size();
+ std::vector<IterVar> axis;
+ std::vector<Var> args;
+ for (size_t i = 0; i < ndim; ++i) {
+ std::ostringstream os;
+ os << "ax" << i;
+ axis.emplace_back(IterVarNode::make(
+ Range(0, shape[i]), Var(os.str(), shape[i].dtype()), kDataPar));
+ args.push_back(axis.back()->var);
+ }
+
+ return ComputeOpNode::make(
+ name, tag, attrs, axis, {fcompute(args)}).output(0);
+}
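+
+// Illustrative usage sketch (not part of this patch; placeholder(), the m/n
+// vars and the float constant are assumptions for the example only): the
+// functional form above is typically called with a lambda mapping the
+// generated axis variables to an expression.
+//
+//   Tensor A = placeholder({m, n}, DataType::Float(32), "A");
+//   Tensor B = compute({m, n}, [&](const Array<Var>& i) {
+//     return A(i[0], i[1]) * 2.0f;
+//   }, "B");
+//
+// Each shape dimension becomes a kDataPar IterVar named "ax0", "ax1", ...,
+// and the single body expression becomes output 0 of the new ComputeOp.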
+
+Array<Tensor> compute(Array<PrimExpr> shape,
+ FBatchCompute fcompute,
+ std::string name,
+ std::string tag,
+ Map<std::string, ObjectRef> attrs) {
+ auto op_node = make_object<ComputeOpNode>();
+ // compute dimension.
+ size_t ndim = shape.size();
+ std::vector<IterVar> axis;
+ std::vector<Var> args;
+ for (size_t i = 0; i < ndim; ++i) {
+ std::ostringstream os;
+ os << "ax" << i;
+ axis.emplace_back(IterVarNode::make(
+ Range(0, shape[i]), Var(os.str(), shape[i].dtype()), kDataPar));
+ args.push_back(axis.back()->var);
+ }
+
+ Operation op = ComputeOpNode::make(name, tag, attrs, axis, fcompute(args));
+ Array<Tensor> outputs;
+ for (int idx = 0; idx < op->num_outputs(); ++idx) {
+ outputs.push_back(op.output(idx));
+ }
+ return outputs;
+}
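+
+// Illustrative sketch for the batched overload (assumption, not in the
+// original sources): FBatchCompute returns one expression per output, so a
+// single ComputeOp can produce several tensors that share the same axes.
+//
+//   Array<Tensor> outs = compute({m}, [&](const Array<Var>& i) {
+//     return Array<PrimExpr>{A(i[0]) + B(i[0]), A(i[0]) - B(i[0])};
+//   }, "sum_diff");
+//   // outs[0] and outs[1] are outputs 0 and 1 of the same operation.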
+
+Operation ComputeOpNode::make(std::string name,
+ std::string tag,
+ Map<std::string, ObjectRef> attrs,
+ Array<IterVar> axis,
+ Array<PrimExpr> body) {
+ if (!attrs.defined()) {
+ attrs = Map<std::string, ObjectRef>();
+ }
+ auto n = make_object<ComputeOpNode>();
+ n->name = std::move(name);
+ n->tag = std::move(tag);
+ n->attrs = std::move(attrs);
+ n->axis = std::move(axis);
+ n->body = std::move(body);
+ if (n->body[0]->IsInstance<ir::ReduceNode>()) {
+ const ir::ReduceNode* reduce = n->body[0].as<ir::ReduceNode>();
+ n->reduce_axis = reduce->axis;
+ }
+ VerifyComputeOp(n.get());
+ return Operation(n);
+}
+
+// The schedule-related logic.
+Array<Tensor> ComputeOpNode::InputTensors() const {
+ Array<Tensor> ret;
+ std::unordered_set<Tensor> visited;
+ for (auto& e : body) {
+ ir::PostOrderVisit(e, [&ret, &visited](const ObjectRef& n) {
+ const ir::CallNode *call = n.as<ir::CallNode>();
+ if (call != nullptr && call->func.defined()) {
+ Tensor t = Downcast<Operation>(call->func).output(call->value_index);
+ if (!visited.count(t)) {
+ ret.push_back(t);
+ visited.insert(t);
+ }
+ }
+ });
+ }
+ return ret;
+}
+
+Operation ComputeOpNode::ReplaceInputs(
+ const Operation& self,
+ const std::unordered_map<Tensor, Tensor>& rmap) const {
+ CHECK_EQ(self.operator->(), this);
+ VerifyComputeOp(this);
+ Array<PrimExpr> arr;
+ if (this->body[0]->IsInstance<ir::ReduceNode>()) {
+    // Specially handle reduce so that the replaced op
+    // still shares all the components
+ PrimExpr new_reduce = top::ReplaceTensor(this->body[0], rmap);
+ if (!new_reduce.same_as(this->body[0])) {
+ const ir::ReduceNode* r = new_reduce.as<ir::ReduceNode>();
+ for (size_t k = 0; k < this->body.size(); ++k) {
+ auto n = make_object<ir::ReduceNode>(*r);
+ n->value_index = static_cast<int>(k);
+ n->dtype = r->source[k].dtype();
+ arr.push_back(PrimExpr(n));
+ }
+ } else {
+ arr = this->body;
+ }
+ } else {
+ arr = UpdateArray(this->body, [&rmap] (const PrimExpr& e) {
+ return top::ReplaceTensor(e, rmap);
+ });
+ }
+ if (!arr.same_as(this->body)) {
+ return ComputeOpNode::make(
+ this->name, this->tag, this->attrs, this->axis, arr);
+ } else {
+ return self;
+ }
+}
+
+void ComputeOpNode::PropBoundToInputs(
+ const Operation& self,
+ arith::Analyzer* analyzer,
+ const std::unordered_map<const VarNode*, IntSet>& dom_map,
+ std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
+ CHECK_EQ(self.operator->(), this);
+ auto fvisit = [&dom_map, out_dom_map, analyzer](const ObjectRef& n) {
+ auto *call = n.as<ir::CallNode>();
+ if (call != nullptr && call->func.defined()) {
+ Tensor t = Downcast<Operation>(call->func).output(call->value_index);
+ if (t->op.defined() && out_dom_map->count(t)) {
+ TensorDom& dom = out_dom_map->at(t);
+ for (size_t i = 0; i < t.ndim(); ++i) {
+ // We assume that the value of the argument cannot be out of bounds (otherwise it is
+ // undefined behaviour), so we can intersect the estimated set of the argument with the
+ // range expected by the tensor. However, intersection may result in overly complex
+ // expressions, so we perform a more relaxed form of intersection.
+ IntSet arg_intset = EvalSet(call->args[i], dom_map);
+ const arith::IntervalSetNode* arg_interval = arg_intset.as<arith::IntervalSetNode>();
+ if (arg_interval) {
+ PrimExpr shape_i_min_value = make_zero(t->shape[i].dtype());
+ PrimExpr shape_i_max_value = t->shape[i] - 1;
+ PrimExpr min_value = arg_interval->min_value;
+ PrimExpr max_value = arg_interval->max_value;
+ // Prefer the shape bounds only when we can prove they are tighter.
+ if (arith::is_neg_inf(min_value) ||
+ analyzer->CanProve(shape_i_min_value >= min_value)) {
+ min_value = shape_i_min_value;
+ }
+ if (arith::is_pos_inf(max_value) ||
+ analyzer->CanProve(shape_i_max_value <= max_value)) {
+ max_value = shape_i_max_value;
+ }
+ dom.data[i].push_back(IntSet::interval(min_value, max_value));
+ } else {
+ dom.data[i].push_back(arg_intset);
+ }
+ }
+ }
+ }
+ };
+ for (auto& e : body) ir::PostOrderVisit(e, fvisit);
+}
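+
+// Worked example for the relaxed intersection above (illustrative numbers,
+// not from the sources): suppose the evaluated argument set for one dimension
+// is [-1, extent] while the tensor's shape bound is [0, extent - 1]. Both
+// 0 >= -1 and extent - 1 <= extent are provable, so the propagated domain
+// becomes [0, extent - 1]; when neither shape bound can be proven tighter,
+// the original argument interval is kept to avoid overly complex expressions.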
+
+void BaseComputeOpNode::GatherBound(
+ const Operation& self,
+ const std::unordered_map<Tensor, TensorDom>& tensor_dom,
+ std::unordered_map<IterVar, Range>* out_dom_map) const {
+ CHECK_EQ(self.operator->(), this);
+ const TensorDom& tdom = tensor_dom.at(self.output(0));
+ for (size_t i = 0; i < this->axis.size(); ++i) {
+ Range r = arith::Union(tdom.data.at(i)).cover_range(this->axis[i]->dom);
+ CHECK(!out_dom_map->count(this->axis[i]));
+ (*out_dom_map)[this->axis[i]] = r;
+ }
+ for (size_t i = 0; i < this->reduce_axis.size(); ++i) {
+ CHECK(!out_dom_map->count(this->reduce_axis[i]));
+ (*out_dom_map)[this->reduce_axis[i]] = this->reduce_axis[i]->dom;
+ }
+}
+
+Stmt BaseComputeOpNode::BuildRealize(
+ const Stage& stage,
+ const std::unordered_map<IterVar, Range>& realize_map,
+ const Stmt& body) const {
+ CHECK_EQ(stage->op.get(), this);
+ Region bounds;
+ for (IterVar iv : this->axis) {
+ bounds.push_back(realize_map.at(iv));
+ }
+ Stmt realize = body;
+ for (int i = this->num_outputs(); i > 0; --i) {
+ Tensor t = stage->op.output(i-1);
+ realize = ir::RealizeNode::make(t->op, t->value_index,
+ t->dtype, bounds, const_true(), realize);
+ // alignment requirement, only useful for compute
+ for (size_t i = 0; i < num_schedulable_dims(); ++i) {
+ auto it = stage->iter_var_attrs.find(this->axis[i]);
+ if (it != stage->iter_var_attrs.end()) {
+ IterVarAttr attr = (*it).second;
+ if (attr->dim_align_factor != 0) {
+ Array<PrimExpr> tuple = {static_cast<int>(i),
+ attr->dim_align_factor,
+ attr->dim_align_offset};
+ realize = ir::AttrStmtNode::make(
+ t, ir::attr::buffer_dim_align,
+ CallNode::make(DataType::Handle(),
+ ir::intrinsic::tvm_tuple,
+ tuple, CallNode::Intrinsic),
+ realize);
+ }
+ }
+ }
+ }
+ return realize;
+}
+
+size_t ComputeOpNode::num_schedulable_dims() const {
+ return axis.size();
+}
+
+// Build a reduction body.
+void MakeReduction(const ComputeOpNode* op,
+ const Array<Tensor>& tensors,
+ Stmt* init,
+ Stmt* provide) {
+ Array<PrimExpr> args;
+ for (IterVar iv : op->axis) {
+ args.push_back(iv->var);
+ }
+ std::vector<Stmt> inits, provides;
+
+ size_t size = op->body.size();
+ const ReduceNode* reduce = op->body[0].as<ReduceNode>();
+ CHECK(reduce);
+ const CommReducerNode* combiner = reduce->combiner.as<CommReducerNode>();
+ CHECK(combiner);
+ Array<PrimExpr> lhs;
+ for (size_t i = 0; i < size; ++i) {
+ lhs.push_back(tensors[i](args));
+ }
+ Array<PrimExpr> init_value = combiner->identity_element;
+ Array<PrimExpr> update_value = (*combiner)(lhs, reduce->source);
+ for (size_t i = 0; i < size; ++i) {
+ Tensor t = tensors[i];
+ inits.emplace_back(ProvideNode::make(
+ t->op, t->value_index, init_value[i], args));
+ provides.emplace_back(ProvideNode::make(
+ t->op, t->value_index, update_value[i], args));
+ }
+ *init = SeqStmt::Flatten(inits);
+ *provide = SeqStmt::Flatten(provides);
+ if (!is_one(reduce->condition)) {
+ *provide = IfThenElseNode::make(reduce->condition, *provide);
+ }
+}
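+
+// For a sum reduction such as B(i) = sum(A(i, k), k), the statements built
+// above are roughly (sketch, assuming the usual Add combiner):
+//   init:    B(i) = 0
+//   provide: B(i) = B(i) + A(i, k)
+// with the provide wrapped in an IfThenElse when the reduce condition is not 1.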
+
+// Normal computation.
+Stmt MakeProvide(const ComputeOpNode* op,
+ const Tensor& t) {
+ Array<PrimExpr> args;
+ for (IterVar iv : op->axis) {
+ args.push_back(iv->var);
+ }
+ return ProvideNode::make(t->op, t->value_index, op->body[t->value_index], args);
+}
+
+Stmt MakeComputeStmt(const ComputeOpNode* self,
+ const Stage& stage,
+ const std::unordered_map<IterVar, Range>& dom_map,
+ bool debug_keep_trivial_loop) {
+ // grab the nest structure
+ ComputeLoopNest n = ComputeLoopNest::make(self, stage, dom_map, debug_keep_trivial_loop);
+ // Normal loop structure
+ n.init_nest.emplace_back(MakeIfNest(n.init_predicates));
+ n.main_nest.emplace_back(MakeIfNest(n.main_predicates));
+ if (self->reduce_axis.size() != 0) {
+ // make reduction.
+ Stmt init, provide;
+ Array<Tensor> source;
+ for (size_t i = 0; i < self->body.size(); ++i) {
+ source.push_back(stage->op.output(i));
+ }
+ MakeReduction(self, source, &init, &provide);
+ init = MergeNest(n.init_nest, init);
+ init = Substitute(init, n.init_vmap);
+ // common nest
+ std::vector<std::vector<Stmt> > common(
+ n.main_nest.begin(), n.main_nest.begin() + n.num_common_loop + 1);
+ std::vector<std::vector<Stmt> > reduce(
+ n.main_nest.begin() + n.num_common_loop + 1, n.main_nest.end());
+ provide = MergeNest(reduce, provide);
+ if (debug_keep_trivial_loop) {
+ provide = MergeNest(common, provide);
+ } else {
+ provide = MergeNest(common, SeqStmt::Flatten(init, provide));
+ }
+    // run substitution on the full nest, because the loop condition
+ // could depend on outer loops.
+ return Substitute(provide, n.main_vmap);
+ } else {
+ std::vector<Stmt> provides;
+ for (size_t i = 0; i < self->body.size(); ++i) {
+ provides.emplace_back(MakeProvide(self, stage->op.output(i)));
+ }
+ Stmt provide = SeqStmt::Flatten(provides);
+ provide = MergeNest(n.main_nest, provide);
+    // run substitution on the full nest, because the loop condition
+ // could depend on outer loops.
+ return Substitute(provide, n.main_vmap);
+ }
+}
+
+enum class ComputeType {
+ kNormal,
+ kCrossThreadReduction,
+ kTensorize
+};
+
+ComputeType DetectComputeType(const ComputeOpNode* self,
+ const Stage& stage) {
+ // Verify correctness of leaf nest.
+ int normal_red = 0, thread_red = 0, tensorize = 0;
+
+ for (IterVar iv : stage->leaf_iter_vars) {
+ IterVarAttr attr;
+ auto it = stage->iter_var_attrs.find(iv);
+ if (it != stage->iter_var_attrs.end()) {
+ attr = (*it).second;
+ }
+ if (attr.defined() && attr->iter_type == kTensorized) {
+ ++tensorize;
+ }
+ if (iv->iter_type == kCommReduce) {
+ if (attr.defined() && attr->bind_thread.defined()) {
+ ++thread_red;
+ } else {
+ ++normal_red;
+ }
+ } else {
+ CHECK_EQ(thread_red, 0)
+ << "Cross thread reduce cannot swap with normal data axis";
+ }
+ }
+ if (tensorize != 0) {
+ CHECK(thread_red == 0)
+ << "Cannot mix cross thread reduction with Tensorize";
+ return ComputeType::kTensorize;
+ }
+ CHECK(normal_red == 0 || thread_red == 0)
+ << "Cannot mix normal reduction with thread reduce";
+ if (thread_red != 0) {
+ return ComputeType::kCrossThreadReduction;
+ } else {
+ return ComputeType::kNormal;
+ }
+}
+
+// implement the provide utility.
+Stmt ComputeOpNode::BuildProvide(
+ const Stage& stage,
+ const std::unordered_map<IterVar, Range>& dom_map,
+ bool debug_keep_trivial_loop) const {
+ CHECK_EQ(stage->op.operator->(), this);
+ ComputeType ctype = DetectComputeType(this, stage);
+ if (ctype == ComputeType::kCrossThreadReduction) {
+ // specially handle cross thread reduction.
+ return MakeCrossThreadReduction(this, stage, dom_map, debug_keep_trivial_loop);
+ } else if (ctype == ComputeType::kTensorize) {
+ return MakeTensorize(this, stage, dom_map, debug_keep_trivial_loop);
+ } else {
+ return MakeComputeStmt(this, stage, dom_map, debug_keep_trivial_loop);
+ }
+}
+
+ComputeLoopNest ComputeLoopNest::make(
+ const BaseComputeOpNode* self,
+ const Stage& stage,
+ const std::unordered_map<IterVar, Range>& dom_map,
+ bool debug_keep_trivial_loop) {
+ CHECK_EQ(stage->op.operator->(), self);
+ ComputeLoopNest ret;
+ // make main loop nest
+ ret.main_nest = MakeLoopNest(
+ stage, dom_map, 0, false, std::unordered_set<IterVar>(), &ret.main_vmap,
+ debug_keep_trivial_loop);
+ ret.main_predicates = MakeBoundCheck(
+ stage, dom_map, ret.main_vmap, false,
+ std::unordered_set<IterVar>());
+ for (auto& e : ret.main_predicates) {
+ e = likely(e);
+ }
+ if (stage->store_predicate.defined()) {
+ ret.main_predicates.push_back(stage->store_predicate);
+ }
+ if (self->reduce_axis.size() != 0) {
+ // try to find the location to insert the initialization.
+ // Fuse the initialization and provide loop when possible.
+ std::unordered_map<IterVar, int> update_state;
+ for (IterVar iv : self->reduce_axis) {
+ update_state[iv] = 2;
+ }
+ for (size_t i = 0; i < self->num_schedulable_dims(); ++i) {
+ update_state[self->axis[i]] = 1;
+ }
+ // find which iter var is related to reduction and which is related to axis.
+ top::PassDownBitMaskOr(stage, &update_state);
+ auto leaf_iter_vars = stage->leaf_iter_vars;
+    // find the first loop that is related to reduction.
+ size_t begin_loop = leaf_iter_vars.size();
+ for (size_t i = 0; i < leaf_iter_vars.size(); ++i) {
+ auto iv = leaf_iter_vars[i];
+ int flag = update_state.at(iv);
+ if ((flag & 2) != 0) {
+ begin_loop = i; break;
+ }
+ ret.init_vmap[iv] = ret.main_vmap.at(iv);
+ }
+ ret.num_common_loop = begin_loop;
+ // skip loops that are related to reduction and are unrelated to axis.
+ std::unordered_set<IterVar> skip_iter;
+ for (auto kv : update_state) {
+ int flag = kv.second;
+ if (flag == 2) skip_iter.insert(kv.first);
+ }
+ ret.init_nest = MakeLoopNest(
+ stage, dom_map, begin_loop, true,
+ skip_iter, &(ret.init_vmap), debug_keep_trivial_loop);
+ ret.init_predicates = MakeBoundCheck(
+ stage, dom_map, ret.init_vmap, true, skip_iter);
+ for (auto& e : ret.init_predicates) {
+ e = likely(e);
+ }
+ } else {
+ CHECK_EQ(ret.main_nest.size(), stage->leaf_iter_vars.size() + 1);
+ ret.num_common_loop = stage->leaf_iter_vars.size();
+ }
+  // copy elision here.
+ return ret;
+}
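+
+// Example of the init placement logic above (illustrative): for
+// C(i) = sum(A(i, k), k) with leaf order (i, k), update_state maps i -> 1 and
+// k -> 2. After PassDownBitMaskOr the first leaf carrying bit 2 is k, so
+// begin_loop points at k and num_common_loop is 1; the initialization nest is
+// then emitted inside the i loop, fusing init with the main loop when possible.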
+
+namespace {
+/*!
+ * \brief Verify if ComputeOp is valid with respect to Reduce operations.
+ *
+ * The following two properties are verified:
+ * (1) All Reduce operations must exist at top level.
+ * (2) For a list of operations, if one is Reduce, then the others
+ * must be Reduce as well; and their inputs should have the
+ * same attribute except value_index.
+ */
+class ComputeVerifier final : protected ir::ExprVisitor {
+ public:
+ /// Special member functions
+ //@{
+ explicit ComputeVerifier(const ComputeOpNode* compute)
+ : compute_(compute), reduce_(compute->body[0].as<ir::ReduceNode>()) {}
+ virtual ~ComputeVerifier() = default;
+ ComputeVerifier(const ComputeVerifier&) = delete;
+ ComputeVerifier(ComputeVerifier&&) = delete;
+ ComputeVerifier& operator=(const ComputeVerifier&) = delete;
+ ComputeVerifier& operator=(ComputeVerifier&&) = delete;
+ //@}
+
+ /// Interface to perform compute verification
+ void Run() {
+ for (const PrimExpr e : compute_->body) {
+ // Check for consistency of top level reductions
+ const ir::ReduceNode* reduce = e.as<ir::ReduceNode>();
+ CHECK((reduce && reduce_) || (!reduce && !reduce_))
+ << "All ComputeOp should be consistent "
+ << "with being Reduce operation or not.";
+
+ if (reduce && reduce_) {
+ CHECK(ReduceEqual(reduce, reduce_))
+ << "The Reduce inputs of ComputeOp should "
+ << "have the same attribute except value_index";
+ }
+
+ level_ = 0;
+ ExprVisitor::VisitExpr(e);
+ }
+ }
+
+ protected:
+ /// Visitor implementation
+ //@{
+ void VisitExpr(const PrimExpr& n) final {
+ ++level_;
+ ExprVisitor::VisitExpr(n);
+ --level_;
+ }
+
+ void VisitExpr_(const ir::ReduceNode* op) final {
+    // Check for non-top-level reductions
+ CHECK(0 == level_)
+ << "Reductions are only allowed at the top level of compute. "
+ << "Please create another tensor for further composition.";
+ }
+ //@}
+
+ private:
+ const ComputeOpNode* compute_{nullptr}; ///< ComputeOpNode to verify
+ const ir::ReduceNode* reduce_{nullptr}; ///< Top level Reduce operation
+ int level_{0}; ///< Level of op being processed
+};
+} // namespace
+
+/// Verify if ComputeOp is valid with respect to Reduce operations.
+static void VerifyComputeOp(const ComputeOpNode* op) {
+ ComputeVerifier v(op);
+ v.Run();
+}
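+
+// Illustrative examples of the rule enforced above (assuming the tvm::sum
+// helper; these snippets are not part of the patch):
+//   allowed:    body = { sum(A(i, k), {k}) }            // reduce at the top level
+//   disallowed: body = { A(i) + sum(B(i, k), {k}) }     // reduce nested inside Add
+//   disallowed: body = { sum(A(i, k), {k}), A(i) + 1 }  // mixed reduce / non-reduce outputs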
+
+Stmt TransformUpdate(const Stage& stage,
+ const std::unordered_map<IterVar, Range>& dom_map,
+ const ComputeLoopNest& n,
+ Stmt body,
+ Stmt update) {
+ Array<PrimExpr> conds;
+ std::unordered_set<const VarNode*> banned;
+ for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) {
+ IterVar iv = stage->leaf_iter_vars[i];
+ auto iit = stage->iter_var_attrs.find(iv);
+ if (iit != stage->iter_var_attrs.end()) {
+ const IterVarAttr& attr = (*iit).second;
+ if (attr->iter_type == kTensorized) {
+ break;
+ }
+ }
+ if (iv->iter_type == kCommReduce) {
+ auto vit = dom_map.find(iv);
+ CHECK(vit != dom_map.end());
+ const Range& vrange = vit->second;
+ conds.push_back(likely(iv->var > vrange->min));
+ banned.insert(iv->var.get());
+ }
+ }
+ for (const PrimExpr& pred : n.main_predicates) {
+ if (ir::ExprUseVar(pred, banned)) {
+ LOG(FATAL) << "Tensorize update transform failed, the condition "
+ << pred << " has a conflict with the reset condition";
+ }
+ }
+
+ return IfThenElseNode::make(arith::ComputeReduce<ir::OrNode>(conds, const_true(1)),
+ update, body);
+}
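+
+// Sketch of the transformed result (illustration only): with a single reduce
+// axis k over [min, min + extent), the generated guard is roughly
+//   if (likely(k > min)) { update } else { body }
+// so the reset/body form runs only on the first reduction iteration.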
+
+} // namespace top
+} // namespace tvm
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \brief Helper utilities to implement compute_op.
+ * \file compute_op.h
+ */
+#ifndef TVM_TOP_OPERATION_COMPUTE_OP_H_
+#define TVM_TOP_OPERATION_COMPUTE_OP_H_
+
+#include <tvm/ir.h>
+#include <tvm/expr.h>
+#include <tvm/top/operation.h>
+#include <vector>
+#include <unordered_map>
+
+namespace tvm {
+namespace top {
+// Loop nest structure for general compute.
+// This is the loop nest structure used in compute.
+// It does not include the loop body.
+struct ComputeLoopNest {
+ // The common number of loops between init and main
+ size_t num_common_loop;
+ // predicates for the initialize loop
+ std::vector<PrimExpr> init_predicates;
+ // Initialization nest involved.
+ std::vector<std::vector<Stmt> > init_nest;
+ // Value map for the init code
+ std::unordered_map<IterVar, PrimExpr> init_vmap;
+ // Predicates for the main update loop
+ std::vector<PrimExpr> main_predicates;
+ // The general loop nest
+ std::vector<std::vector<Stmt> > main_nest;
+ // Value map for the IterVar.
+ std::unordered_map<IterVar, PrimExpr> main_vmap;
+
+ /*!
+   * \brief Constructor to build a ComputeLoopNest.
+   * \param self The pointer to the compute op.
+   * \param stage The schedule stage.
+   * \param dom_map The domain map.
+   * \param debug_keep_trivial_loop Whether to keep trivial loops with extent 1.
+ * \return The constructed loop nest
+ */
+ static ComputeLoopNest make(
+ const BaseComputeOpNode* self,
+ const Stage& stage,
+ const std::unordered_map<IterVar, Range>& dom_map,
+ bool debug_keep_trivial_loop);
+};
+
+/*!
+ * \brief Build body of compute for cross thread reduction pattern.
+ * \param self The pointer to ComputeOpNode
+ * \param stage The schedule stage.
+ * \param dom_map The domain map.
+ * \param debug_keep_trivial_loop Whether to keep trivial loops with extent 1.
+ * \return The created statement.
+ */
+Stmt MakeCrossThreadReduction(
+ const ComputeOpNode* self,
+ const Stage& stage,
+ const std::unordered_map<IterVar, Range>& dom_map,
+ bool debug_keep_trivial_loop);
+
+/*!
+ * \brief Build body of compute for tensorization.
+ * \param self The pointer to ComputeOpNode
+ * \param stage The schedule stage.
+ * \param dom_map The domain map.
+ * \param debug_keep_trivial_loop Whether to keep trivial loops with extent 1.
+ * \return The created statement.
+ */
+Stmt MakeTensorize(const ComputeOpNode* self,
+ const Stage& stage,
+ const std::unordered_map<IterVar, Range>& dom_map,
+ bool debug_keep_trivial_loop);
+
+/*!
+ * \brief Transform the update part when there is no init func in tensorizing
+ * \param stage The stage for tensorizing.
+ * \param dom_map The range of each iter var.
+ * \param n The loop nest structure used in compute.
+ * \param body The body function of the tensorize intrinsic.
+ * \param update The update function of the tensorize intrinsic.
+ * \return Transformed result.
+ */
+Stmt TransformUpdate(const Stage& stage,
+ const std::unordered_map<IterVar, Range>& dom_map,
+ const ComputeLoopNest& n,
+ Stmt body,
+ Stmt update);
+} // namespace top
+} // namespace tvm
+
+#endif // TVM_TOP_OPERATION_COMPUTE_OP_H_
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \brief Logics related to cross thread reduction, used by ComputeOpNode.
+ * \file cross_thread_reduction.cc
+ */
+#include <tvm/ir_pass.h>
+#include "compute_op.h"
+#include "op_util.h"
+
+namespace tvm {
+namespace top {
+using namespace ir;
+
+Stmt MakeCrossThreadReduction(
+ const ComputeOpNode* self,
+ const Stage& stage,
+ const std::unordered_map<IterVar, Range>& dom_map,
+ bool debug_keep_trivial_loop) {
+ Array<PrimExpr> args;
+ for (IterVar iv : self->axis) {
+ args.push_back(iv->var);
+ }
+ std::unordered_map<IterVar, PrimExpr> value_map;
+ auto nest = MakeLoopNest(
+ stage, dom_map, 0, false, std::unordered_set<IterVar>(), &value_map, debug_keep_trivial_loop);
+ auto conds = MakeBoundCheck(
+ stage, dom_map, value_map, false,
+ std::unordered_set<IterVar>());
+
+ size_t size = self->body.size();
+ CHECK_GT(size, 0);
+ std::vector<const ReduceNode*> reduces(size);
+ for (size_t i = 0; i < size; ++i) {
+ const ReduceNode* reduce = self->body[i].as<ReduceNode>();
+ CHECK(reduce);
+ reduces[i] = reduce;
+ }
+ PrimExpr cond = reduces[0]->condition;
+ for (PrimExpr v : conds) {
+ cond = cond && v;
+ }
+ Array<PrimExpr> freduce_args;
+ freduce_args.push_back(make_const(DataType::UInt(32), static_cast<uint32_t>(size)));
+ for (size_t i = 0; i < size; ++i) {
+ freduce_args.push_back(reduces[0]->source[i]);
+ }
+ freduce_args.push_back(cond);
+ std::vector<Var> res_handles(size);
+ for (size_t idx = 0; idx < size; ++idx) {
+ res_handles[idx] = Var("reduce_temp" + std::to_string(idx), DataType::Handle());
+ freduce_args.push_back(res_handles[idx]);
+ }
+
+ for (IterVar iv : stage->leaf_iter_vars) {
+ if (iv->iter_type == kCommReduce) {
+ auto it = stage->iter_var_attrs.find(iv);
+ if (it != stage->iter_var_attrs.end() &&
+ (*it).second->bind_thread.defined()) {
+ IterVar tv = (*it).second->bind_thread;
+ freduce_args.push_back(tv->var);
+ }
+ }
+ }
+  // Predicate selecting which thread performs the final store.
+ std::vector<PrimExpr> thread_head_check;
+ if (stage->store_predicate.defined()) {
+ thread_head_check.emplace_back(stage->store_predicate);
+ }
+
+ Stmt reduce_body = EvaluateNode::make(CallNode::make(
+ DataType::Handle(),
+ ir::intrinsic::tvm_thread_allreduce,
+ freduce_args, CallNode::Intrinsic));
+ reduce_body = AttrStmtNode::make(
+ reduces[0]->combiner,
+ attr::reduce_scope,
+ make_zero(DataType::Handle()),
+ reduce_body);
+ std::vector<Stmt> assigns(size);
+ for (size_t idx = 0; idx < size; ++idx) {
+ DataType t = reduces[idx]->dtype;
+ assigns[idx] = ProvideNode::make(
+ stage->op, idx,
+ LoadNode::make(t, res_handles[idx], 0, const_true(t.lanes())), args);
+ }
+ Stmt assign_body = SeqStmt::Flatten(assigns);
+ assign_body = MergeNest(MakeIfNest(thread_head_check), assign_body);
+ assign_body = MergeNest(MakeIfNest(conds), assign_body);
+ Stmt body = SeqStmt::Flatten(reduce_body, assign_body);
+ for (size_t idx = size; idx != 0; --idx) {
+ body = AllocateNode::make(
+ res_handles[idx - 1], reduces[idx - 1]->dtype, {1}, const_true(), body);
+ body = AttrStmtNode::make(
+ res_handles[idx - 1], attr::storage_scope, StringImmNode::make("local"), body);
+ }
+ body = Substitute(body, value_map);
+ return MergeNest(nest, body);
+}
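+
+// The lowered structure produced above is roughly (illustrative sketch for a
+// single-output reduction; names are not literal):
+//   // attr [reduce_temp0] storage_scope = "local"
+//   allocate reduce_temp0[1]
+//   // attr [combiner] reduce_scope
+//   tvm_thread_allreduce(1, source, cond, reduce_temp0, <bound thread vars>)
+//   if (store_predicate) C(...) = load(reduce_temp0)
+// all wrapped inside the loop nest returned by MakeLoopNest.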
+} // namespace top
+} // namespace tvm
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \brief External computation rule.
+ * \file extern_op.cc
+ */
+#include <tvm/top/operation.h>
+#include <tvm/arith/analyzer.h>
+#include <tvm/ir.h>
+#include <unordered_set>
+#include "op_util.h"
+
+namespace tvm {
+namespace top {
+using namespace ir;
+// ExternOpNode
+TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
+.set_dispatch<ExternOpNode>([](const ObjectRef& node, NodePrinter* p) {
+ auto* op = static_cast<const ExternOpNode*>(node.get());
+ p->stream << "extern(" << op->name << ", " << op << ")";
+ });
+
+TVM_REGISTER_NODE_TYPE(ExternOpNode);
+
+int ExternOpNode::num_outputs() const {
+ return static_cast<int>(output_placeholders.size());
+}
+
+Array<IterVar> ExternOpNode::root_iter_vars() const {
+ return {};
+}
+
+DataType ExternOpNode::output_dtype(size_t i) const {
+ return output_placeholders[i]->dtype;
+}
+
+Array<PrimExpr> ExternOpNode::output_shape(size_t i) const {
+ return output_placeholders[i]->shape;
+}
+
+
+Operation ExternOpNode::make(std::string name,
+ std::string tag,
+ Map<std::string, ObjectRef> attrs,
+ Array<Tensor> inputs,
+ Array<Buffer> input_placeholders,
+ Array<Buffer> output_placeholders,
+ Stmt body) {
+ if (!attrs.defined()) {
+ attrs = Map<std::string, ObjectRef>();
+ }
+ auto n = make_object<ExternOpNode>();
+ n->name = std::move(name);
+ n->tag = std::move(tag);
+ n->attrs = std::move(attrs);
+ CHECK_EQ(inputs.size(), input_placeholders.size());
+ for (size_t i = 0; i < inputs.size(); ++i) {
+ CHECK_EQ(inputs[i]->dtype, input_placeholders[i]->dtype);
+ CHECK_EQ(inputs[i]->shape.size(), input_placeholders[i]->shape.size());
+ for (size_t dim = 0; dim < inputs[i]->shape.size(); ++dim) {
+ CHECK(inputs[i]->shape[dim].same_as(input_placeholders[i]->shape[dim]));
+ }
+ CHECK_EQ(input_placeholders[i]->strides.size(), 0U);
+ }
+ n->inputs = std::move(inputs);
+ n->input_placeholders = std::move(input_placeholders);
+ n->output_placeholders = std::move(output_placeholders);
+ n->body = std::move(body);
+ return Operation(n);
+}
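+
+// Usage sketch (decl_buffer and all names here are assumptions for
+// illustration only): an extern op wraps an opaque statement that reads and
+// writes through explicitly declared placeholder buffers.
+//
+//   Buffer in_buf  = decl_buffer(A->shape, A->dtype, "Ab");
+//   Buffer out_buf = decl_buffer({n}, DataType::Float(32), "Cb");
+//   Operation op = ExternOpNode::make("my_extern", "", {}, {A},
+//                                     {in_buf}, {out_buf}, body_stmt);
+//   Tensor C = op.output(0);
+//
+// The checks above require every input tensor and its placeholder buffer to
+// agree in dtype and shape, and input placeholders to carry no explicit strides.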
+
+Array<Tensor> ExternOpNode::InputTensors() const {
+ return inputs;
+}
+
+Operation ExternOpNode::ReplaceInputs(
+ const Operation& self,
+ const std::unordered_map<Tensor, Tensor>& rmap) const {
+ CHECK_EQ(self.operator->(), this);
+ auto n = make_object<ExternOpNode>(*this);
+ n->body = ReplaceTensor(this->body, rmap);
+ for (size_t i = 0; i < n->inputs.size(); ++i) {
+ Tensor t = n->inputs[i];
+ if (rmap.count(t)) {
+ n->inputs.Set(i, rmap.at(t));
+ }
+ }
+
+ if (body.same_as(n->body) &&
+ inputs.same_as(n->inputs)) {
+ return self;
+ } else {
+ return Operation(n);
+ }
+}
+
+void ExternOpNode::PropBoundToInputs(
+ const Operation& self,
+ arith::Analyzer* analyzer,
+ const std::unordered_map<const VarNode*, IntSet>& dom_map,
+ std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
+ for (Tensor t : this->inputs) {
+ auto it = out_dom_map->find(t);
+ if (it == out_dom_map->end()) continue;
+ TensorDom& dom = it->second;
+ for (size_t i = 0; i < t->shape.size(); ++i) {
+ dom.data[i].emplace_back(IntSet::range(
+ Range::make_by_min_extent(
+ make_const(t->shape[i].dtype(), 0), t->shape[i])));
+ }
+ }
+}
+
+void ExternOpNode::GatherBound(
+ const Operation& self,
+ const std::unordered_map<Tensor, TensorDom>& tensor_dom,
+ std::unordered_map<IterVar, Range>* out_dom_map) const {
+}
+
+Stmt ExternOpNode::BuildRealize(
+ const Stage& stage,
+ const std::unordered_map<IterVar, Range>& realize_map,
+ const Stmt& body) const {
+ CHECK_EQ(stage->op.get(), this);
+ Stmt realize_body = body;
+ for (int k = 0; k < num_outputs(); ++k) {
+ Tensor t = stage->op.output(k);
+ Region bounds;
+ for (size_t i = 0; i < t->shape.size(); ++i) {
+ bounds.push_back(
+ Range::make_by_min_extent(
+ make_const(t->shape[i].dtype(), 0), t->shape[i]));
+ }
+ realize_body = ir::RealizeNode::make(
+ t->op, t->value_index, t->dtype,
+ bounds, const_true(), realize_body);
+ }
+ return realize_body;
+}
+
+Stmt ExternOpNode::BuildProvide(
+ const Stage& stage,
+ const std::unordered_map<IterVar, Range>& dom_map,
+ bool debug_keep_trivial_loop) const {
+ CHECK_EQ(stage->op.operator->(), this);
+ Stmt ret = AttrStmtNode::make(make_zero(DataType::Int(32)), attr::extern_scope, 0, this->body);
+ auto f_push_bind = [&ret](Buffer buffer, Tensor tensor) {
+ Array<ObjectRef> bind_spec;
+ Array<PrimExpr> tuple;
+ bind_spec.push_back(buffer);
+ bind_spec.push_back(tensor);
+ for (size_t k = 0; k < buffer->shape.size(); ++k) {
+ tuple.push_back(make_const(buffer->shape[k].dtype(), 0));
+ tuple.push_back(buffer->shape[k]);
+ }
+ ret = AttrStmtNode::make(
+ bind_spec, attr::buffer_bind_scope,
+ CallNode::make(DataType::Handle(), intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), ret);
+ };
+ for (size_t i = output_placeholders.size(); i != 0; --i) {
+ f_push_bind(output_placeholders[i - 1], stage->op.output(i - 1));
+ }
+ for (size_t i = inputs.size(); i != 0; --i) {
+ f_push_bind(input_placeholders[i - 1], inputs[i - 1]);
+ }
+ return ret;
+}
+} // namespace top
+} // namespace tvm
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \brief Hybrid computation rule.
+ * \file hybrid_op.cc
+ */
+#include <tvm/top/operation.h>
+#include <tvm/arith/analyzer.h>
+#include <tvm/ir.h>
+#include <tvm/ir_functor_ext.h>
+#include <tvm/ir_pass.h>
+#include <tvm/expr_operator.h>
+#include <unordered_set>
+#include <string>
+#include <utility>
+#include "op_util.h"
+#include "hybrid_op.h"
+
+namespace tvm {
+namespace top {
+using namespace ir;
+// HybridOpNode
+TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
+.set_dispatch<HybridOpNode>([](const ObjectRef& node, NodePrinter* p) {
+ auto* op = static_cast<const HybridOpNode*>(node.get());
+ p->stream << "hybrid(" << op->name << ", " << op << ")";
+ });
+
+TVM_REGISTER_NODE_TYPE(HybridOpNode);
+
+int HybridOpNode::num_outputs() const {
+ return static_cast<int>(outputs.size());
+}
+
+Array<IterVar> HybridOpNode::root_iter_vars() const {
+ return this->axis;
+}
+
+DataType HybridOpNode::output_dtype(size_t i) const {
+ return outputs[i]->dtype;
+}
+
+Array<PrimExpr> HybridOpNode::output_shape(size_t i) const {
+ return outputs[i]->shape;
+}
+
+
+Operation HybridOpNode::make(std::string name,
+ std::string tag,
+ Map<std::string, ObjectRef> attrs,
+ Array<Tensor> inputs,
+ Array<Tensor> outputs,
+ Stmt body) {
+ if (!attrs.defined()) {
+ attrs = Map<std::string, ObjectRef>();
+ }
+ auto n = make_object<HybridOpNode>();
+ n->name = std::move(name);
+ n->tag = std::move(tag);
+ n->attrs = std::move(attrs);
+ n->inputs = std::move(inputs);
+ n->outputs = std::move(outputs);
+ n->axis = top::GatherLoopVars(body);
+ n->body = std::move(body);
+ Operation res = Operation(n);
+ return res;
+}
+
+Array<Tensor> HybridOpNode::InputTensors() const {
+  // Because input tensors may be inlined into hybrid scripts,
+  // we need to check which input tensors are actually used in the body.
+ std::unordered_set<Tensor> orig_inputs;
+ for (auto t : inputs) {
+ orig_inputs.insert(t);
+ }
+ std::unordered_set<Tensor> visited;
+ Array<Tensor> curr_inputs;
+ ir::PostOrderVisit(body, [&curr_inputs, &orig_inputs, &visited](const ObjectRef& n) {
+ const ir::CallNode *call = n.as<ir::CallNode>();
+ if (call != nullptr && call->func.defined()) {
+ Tensor t = Downcast<Operation>(call->func).output(call->value_index);
+ if (orig_inputs.count(t) && !visited.count(t)) {
+ curr_inputs.push_back(t);
+ visited.insert(t);
+ }
+ }
+ });
+ return curr_inputs;
+}
+
+Operation HybridOpNode::ReplaceInputs(
+ const Operation &self,
+ const std::unordered_map<Tensor, Tensor> &rmap) const {
+ CHECK_EQ(self.operator->(), this);
+ auto n = make_object<HybridOpNode>(*this);
+ n->body = top::ReplaceTensor(this->body, rmap);
+ for (size_t i = 0; i < n->inputs.size(); ++i) {
+ Tensor t = n->inputs[i];
+ if (rmap.count(t)) {
+ n->inputs.Set(i, rmap.at(t));
+ }
+ }
+
+ if (body.same_as(n->body) &&
+ inputs.same_as(n->inputs)) {
+ return self;
+ } else {
+ return Operation(n);
+ }
+}
+
+void HybridOpNode::PropBoundToInputs(
+ const Operation &self,
+ arith::Analyzer* analyzer,
+ const std::unordered_map<const VarNode*, IntSet> &dom_map,
+ std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
+ auto curr_inputs = InputTensors();
+ for (Tensor t : curr_inputs) {
+ auto it = out_dom_map->find(t);
+ if (it == out_dom_map->end()) continue;
+ TensorDom &dom = it->second;
+ for (size_t i = 0; i < t->shape.size(); ++i) {
+ dom.data[i].emplace_back(IntSet::range(
+ Range::make_by_min_extent(
+ make_const(t->shape[i].dtype(), 0), t->shape[i])));
+ }
+ }
+}
+
+void HybridOpNode::GatherBound(
+ const Operation &self,
+ const std::unordered_map<Tensor, TensorDom> &tensor_dom,
+ std::unordered_map<IterVar, Range>* out_dom_map) const {
+ for (auto iter_var : axis) {
+ CHECK(!out_dom_map->count(iter_var));
+ out_dom_map->operator[](iter_var) = iter_var->dom;
+ }
+}
+
+Stmt HybridOpNode::BuildRealize(
+ const Stage &stage,
+ const std::unordered_map<IterVar, Range> &realize_map,
+ const Stmt &body) const {
+  // TODO(@were): Add attribute injection here and remove it from the hybrid parser.
+ CHECK_EQ(stage->op.get(), this);
+ Stmt realize_body = body;
+ for (int k = 0; k < num_outputs(); ++k) {
+ Tensor t = stage->op.output(k);
+ Region bounds;
+ for (size_t i = 0; i < t->shape.size(); ++i) {
+ bounds.push_back(
+ Range::make_by_min_extent(
+ make_const(t->shape[i].dtype(), 0), t->shape[i]));
+ }
+ realize_body = ir::RealizeNode::make(
+ t->op, t->value_index, t->dtype,
+ bounds, const_true(), realize_body);
+ }
+ return realize_body;
+}
+
+Stmt HybridOpNode::BuildProvide(
+ const Stage &stage,
+ const std::unordered_map<IterVar, Range> &dom_map,
+ bool debug_keep_trivial_loop) const {
+ CHECK_EQ(stage->op.operator->(), this);
+ Stmt ret = AttrStmtNode::make(make_zero(DataType::Int(32)), attr::extern_scope, 0, this->body);
+ std::unordered_map<Tensor, Tensor> rmap;
+ for (int i = 0; i < this->num_outputs(); ++i) {
+ rmap[outputs[i]] = stage->op.output(i);
+ }
+ auto n = make_object<HybridOpNode>(*this);
+  /* The story here is a little complicated.
+   * The following two lines replace the usage of the output tensors.
+   * This is the simplest way I (@were) could come up with to glue
+   * the hybrid operation node into the TVM op system.
+   * In a hybrid script, all tensors, and the output tensors in particular,
+   * have names defined by the user. In conventional TVM ops, however:
+   * 1. Output tensors refer to the corresponding op node, so the output
+   *    tensors take the name of the operation that produces them.
+   * 2. Once an OpNode is wrapped in an Operation node, it is finalized;
+   *    later access goes through a const OpNode*.
+   * This is a chicken-and-egg problem: the output tensors cannot be put
+   * into the function body before the op node is formed, and the function
+   * body is immutable once the node is formed.
+   *
+   * The issue is therefore resolved "lazily". Within the compilation
+   * pipeline this stage is very early, technically before Phase 0, so the
+   * actual tensors are substituted here.
+   * As a result, the operation body differs slightly from the Phase 0 body.
+   * This is a major way in which HybridOpNode is NOT the same as
+   * ExternOpNode.
+   */
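+  /* Illustration (hypothetical tensor name): a Provide into a user-named
+   * tensor such as "out" in the script body is rewritten by the two calls
+   * below so that it targets stage->op.output(i), the tensor owned by the
+   * finalized Operation node. */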
+ ret = top::ReplaceTensor(ret, rmap);
+ ret = top::ReplaceProvideTensor(ret, rmap);
+
+ ret = top::ApplySchedule(stage, dom_map, ret);
+ return ret;
+}
+
+Stmt ApplyLoopShapes(const Stage &stage,
+ const std::unordered_map<IterVar, Range> &dom_map, Stmt stmt) {
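+  // Rough sketch of what LoopSpliter does for a split with factor f:
+  //   for (parent, 0, N) { body }
+  // becomes
+  //   for (outer, 0, outer_extent)
+  //     for (inner, 0, inner_extent)
+  //       if (likely(outer * f < N - inner))
+  //         body, with parent replaced by inner + outer * f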
+ class LoopSpliter : public StmtExprMutator {
+ PrimExpr factor;
+ const VarNode *parent;
+ IterVar inner, outer;
+
+ public:
+ bool splitted;
+ LoopSpliter(const SplitNode *split,
+ const std::unordered_map<IterVar, Range> &dom_map) :
+ factor(split->factor), splitted(false) {
+ parent = split->parent->var.get();
+
+ auto &inner_ = split->inner;
+ CHECK(dom_map.count(inner_));
+ auto &inner_dom = dom_map.find(inner_)->second;
+ CHECK(is_const_int(inner_dom->min, 0));
+
+ auto &outer_ = split->outer;
+ CHECK(dom_map.count(outer_));
+ auto &outer_dom = dom_map.find(outer_)->second;
+ CHECK(is_const_int(outer_dom->min, 0));
+
+ inner = IterVarNode::make(inner_dom, inner_->var, inner_->iter_type);
+ outer = IterVarNode::make(outer_dom, outer_->var, outer_->iter_type);
+ }
+
+ Stmt VisitStmt_(const ForNode *op) final {
+ if (op->loop_var.get() == parent) {
+ std::unordered_map<const VarNode *, PrimExpr> rmap;
+ rmap[op->loop_var.get()] = inner + outer * factor;
+ Stmt ret = ir::Substitute(op->body, rmap);
+ PrimExpr cond = likely(outer * factor < (op->extent - inner));
+ ret = IfThenElseNode::make(cond, ret);
+ ret = ForNode::make(inner->var, PrimExpr(0), inner->dom->extent,
+ IterVarTypeToForType(inner->iter_type), op->device_api, ret);
+ ret = ForNode::make(outer->var, PrimExpr(0), outer->dom->extent,
+ IterVarTypeToForType(outer->iter_type), op->device_api, ret);
+ splitted = true;
+ return ret;
+ }
+ return StmtExprMutator::VisitStmt_(op);
+ }
+ };
+
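+  // Rough sketch of what LoopFuser does for a fuse of (outer, inner) with
+  // extents (M, N):
+  //   for (outer, 0, M) { for (inner, 0, N) { body } }
+  // becomes
+  //   for (fused, 0, M * N)
+  //     body, with inner replaced by fused % N and outer by fused / N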
+ class LoopFuser : public StmtExprMutator {
+ const IterVar &parent;
+ const VarNode *inner;
+ const VarNode *outer;
+ bool under_outer;
+ PrimExpr extent;
+
+ public:
+ bool fused;
+ explicit LoopFuser(const FuseNode *fuse_)
+ : parent(fuse_->fused), inner(fuse_->inner->var.get()),
+ outer(fuse_->outer->var.get()), under_outer(false),
+ extent(0), fused(false) {}
+
+ // TODO(@were): Handle imperfect loops
+ Stmt VisitStmt_(const ForNode* op) final {
+ if (op->loop_var.get() == inner) {
+ CHECK(under_outer);
+ std::unordered_map<const VarNode *, PrimExpr> rmap;
+ rmap[op->loop_var.get()] = indexmod(parent, op->extent);
+ extent = op->extent;
+ fused = true;
+ return ir::Substitute(op->body, rmap);
+ } else if (op->loop_var.get() == outer) {
+ under_outer = true;
+ Stmt body = this->VisitStmt(op->body);
+ std::unordered_map<const VarNode *, PrimExpr> rmap;
+ rmap[op->loop_var.get()] = indexdiv(parent, extent);
+ body = ir::Substitute(body, rmap);
+ under_outer = false;
+ return ForNode::make(parent->var, PrimExpr(0), extent * op->extent,
+ op->for_type, op->device_api, body);
+ } else if (under_outer) {
+ Stmt body = this->VisitStmt(op->body);
+ std::unordered_map<const VarNode *, PrimExpr> rmap;
+ rmap[op->loop_var.get()] = indexmod(indexdiv(parent, extent), op->extent);
+ body = ir::Substitute(body, rmap);
+ extent = extent * op->extent;
+ return body;
+ }
+ return StmtExprMutator::VisitStmt_(op);
+ }
+ };
+
+ for (auto &rel : stage->relations) {
+ if (const SplitNode *split = rel.as<SplitNode>()) {
+ LoopSpliter Spliter(split, dom_map);
+ stmt = Spliter(stmt);
+ CHECK(Spliter.splitted);
+ } else if (const FuseNode *fuse = rel.as<FuseNode>()) {
+ LoopFuser Fuser(fuse);
+ stmt = Fuser(stmt);
+ CHECK(Fuser.fused);
+ }
+ }
+
+ return stmt;
+}
+
+Stmt ApplyLoopAnnotations(const Stage &stage,
+ const std::unordered_map<IterVar, IterVar> &rebased, Stmt stmt) {
+ class LoopAnnotator : public StmtMutator {
+ const VarNode *var;
+ const IterVarAttr &attr;
+
+ public:
+ LoopAnnotator(const VarNode *var_, const IterVarAttr &attr_) : var(var_), attr(attr_) {}
+
+ Stmt VisitStmt_(const ForNode *op) final {
+ if (op->loop_var.get() == var) {
+ if (attr->bind_thread.defined()) {
+ const auto &iter_var = attr->bind_thread;
+ if (iter_var->dom.defined()) {
+ CHECK(is_const_int(iter_var->dom->min, 0));
+ CHECK(Equal(iter_var->dom->extent, op->extent))
+ << "Thread extent and loop extent mismatch!\n";
+ }
+ std::unordered_map<const VarNode *, PrimExpr> rmap;
+ rmap[op->loop_var.get()] = iter_var;
+ Stmt body = ir::Substitute(op->body, rmap);
+ return AttrStmtNode::make(iter_var, "thread_extent", op->extent, body);
+ } else {
+ return ForNode::make(op->loop_var, op->min, op->extent,
+ IterVarTypeToForType(attr->iter_type), op->device_api, op->body);
+ }
+ }
+ return StmtMutator::VisitStmt_(op);
+ }
+ };
+
+ for (auto &iter_var : stage->leaf_iter_vars) {
+ bool need_change = false;
+ int found = 0;
+
+ const IterVar &actual = rebased.count(iter_var) ? rebased.find(iter_var)->second : iter_var;
+ const VarNode *var = actual->var.get();
+ ForType expected = IterVarTypeToForType(iter_var->iter_type);
+ IterVarAttr attr;
+ if (stage->iter_var_attrs.count(iter_var)) {
+ attr = stage->iter_var_attrs[iter_var];
+ expected = IterVarTypeToForType(attr->iter_type);
+ }
+
+ PostOrderVisit(stmt,
+ [&found, &var, &attr, &expected, &need_change](const ObjectRef& node) {
+ if (const ForNode *op = node.as<ForNode>()) {
+ if (op->loop_var.get() == var) {
+ ++found;
+ need_change = expected != op->for_type || (attr.defined() && attr->bind_thread.defined());
+ }
+ }
+ });
+
+ CHECK_EQ(found, 1) << " iter var should be found exactly once!";
+ if (need_change) {
+ stmt = LoopAnnotator(var, attr)(std::move(stmt));
+ }
+ }
+ return stmt;
+}
+
+Stmt ApplyLoopOrder(const Stage &stage,
+ const std::unordered_map<IterVar, Range> &dom_map,
+ const std::unordered_map<IterVar, IterVar> &rebased, Stmt stmt) {
+ std::vector<const VarNode*> current_order;
+  PostOrderVisit(stmt, [&current_order](const ObjectRef& node) {
+ if (const ForNode *op = node.as<ForNode>())
+ current_order.push_back(op->loop_var.get());
+ });
+ std::reverse(current_order.begin(), current_order.end());
+ auto &required_ord = stage->leaf_iter_vars;
+ CHECK_EQ(current_order.size(), required_ord.size()) << "Cannot reorder the loops!";
+ std::unordered_map<const VarNode *, IterVar> reorder;
+ bool need_reorder = false;
+ for (size_t i = 0; i < current_order.size(); ++i) {
+    auto &current = current_order[i];
+ const IterVar &iter_var = required_ord[i];
+ const IterVar &required = rebased.count(iter_var) ? rebased.find(iter_var)->second : iter_var;
+ CHECK(required->dom.defined() || dom_map.count(required)) << required << "\n";
+ reorder[current] = required;
+ if (current != required->var.get()) {
+ need_reorder = true;
+ }
+ }
+
+ class LoopReorder : public StmtMutator {
+ const Stage &stage;
+ const std::unordered_map<IterVar, Range> &dom_map;
+ const std::unordered_map<const VarNode *, IterVar> &reorder;
+
+ public:
+ LoopReorder(const Stage &stage,
+ const std::unordered_map<IterVar, Range> &dom_map,
+ const std::unordered_map<const VarNode*, IterVar> &reorder)
+ : stage(stage), dom_map(dom_map), reorder(reorder) {}
+
+ Stmt VisitStmt_(const ForNode* op) final {
+      // Reorder from the innermost loop to the outermost
+ Stmt body_ = this->VisitStmt(op->body);
+ CHECK(reorder.count(op->loop_var.get()));
+ auto target = reorder.find(op->loop_var.get())->second;
+ if (body_.same_as(op->body) && op->loop_var.get() == target->var.get())
+ return GetRef<Stmt>(op);
+ const Stmt &body = op->body.same_as(body_) ? op->body : body_;
+ ForType for_type = IterVarTypeToForType(target->iter_type);
+ if (stage->iter_var_attrs.count(target)) {
+ for_type = IterVarTypeToForType(stage->iter_var_attrs[target]->iter_type);
+ }
+ const Range &range = target->dom.defined() ? target->dom : dom_map.find(target)->second;
+ return ForNode::make(target->var, range->min, range->extent,
+ for_type, DeviceAPI::None, body);
+ }
+ };
+
+ if (need_reorder)
+ return LoopReorder(stage, dom_map, reorder)(stmt);
+
+ return stmt;
+}
+
+Stmt ApplySchedule(const Stage &stage,
+ const std::unordered_map<IterVar, Range> &dom_map, Stmt stmt) {
+ // TODO(@were): Eliminate loop rebase in script parser and move the burden here
+ // Gather rebased variables
+ std::unordered_map<IterVar, IterVar> rebased;
+ for (auto rel : stage->relations) {
+ if (const auto* rebase = rel.as<RebaseNode>()) {
+ rebased[rebase->rebased] = rebase->parent;
+ CHECK(rebase->parent->dom.defined());
+ CHECK(dom_map.count(rebase->rebased));
+ }
+ }
+ stmt = ApplyLoopShapes(stage, dom_map, stmt);
+ stmt = ApplyLoopOrder(stage, dom_map, rebased, stmt);
+ stmt = ApplyLoopAnnotations(stage, rebased, stmt);
+ return stmt;
+}
+
+std::vector<IterVar> GatherLoopVars(Stmt stmt) {
+ // TODO(@were): Write a comprehensive pass to analyze iter var types
+ std::vector<IterVar> res_;
+ PostOrderVisit(stmt, [&res_](const ObjectRef& node) {
+ if (const ForNode *op = node.as<ForNode>()) {
+ Var loop_var(op->loop_var);
+ Range dom = Range::make_by_min_extent(op->min, op->extent);
+ res_.push_back(IterVarNode::make(dom, loop_var, ForTypeToIterVarType(op->for_type)));
+ }
+ });
+ std::reverse(res_.begin(), res_.end());
+ return res_;
+}
+
+// Replacer that rewrites tensors referenced in Provide statements
+class ProviderReplacer : public ir::StmtMutator {
+ public:
+ explicit ProviderReplacer(const std::unordered_map<Tensor, Tensor> &vmap)
+ : vmap_(vmap) {}
+
+ Stmt VisitStmt_(const ir::ProvideNode* op) final {
+ Tensor t = Downcast<Operation>(op->func).output(op->value_index);
+ auto it = vmap_.find(t);
+ if (it != vmap_.end()) {
+ Stmt ret = ir::ProvideNode::make(
+ it->second->op, it->second->value_index, op->value, op->args);
+ found = true;
+ return this->VisitStmt(ret);
+ }
+ return StmtMutator::VisitStmt_(op);
+ }
+
+ // whether it is found.
+ bool found{false};
+
+ private:
+ const std::unordered_map<Tensor, Tensor> &vmap_;
+};
+
+Stmt ReplaceProvideTensor(Stmt stmt,
+ const std::unordered_map<Tensor, Tensor> &replace) {
+ ProviderReplacer repl(replace);
+ Stmt ret = repl(stmt);
+ return repl.found ? ret : stmt;
+}
+} // namespace top
+} // namespace tvm
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \brief Helper utilities to implement hybrid_op.
+ * \file hybrid_op.h
+ */
+#ifndef TVM_TOP_OPERATION_HYBRID_OP_H_
+#define TVM_TOP_OPERATION_HYBRID_OP_H_
+
+#include <tvm/expr.h>
+#include <tvm/top/schedule.h>
+
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+#include "../schedule/message_passing.h"
+#include "../../pass/ir_util.h"
+#include "../../pass/arg_binder.h"
+
+namespace tvm {
+namespace top {
+
+/*!
+ * \brief Find all the iteration variables in the given statement body.
+ * \param stmt The body to be inspected.
+ * \return The iteration variables found, from the outermost to the innermost loop.
+ */
+std::vector<IterVar> GatherLoopVars(Stmt stmt);
+
+/*!
+ * \brief Replace the tensor references (especially in Provide statements) in stmt according to the replace map.
+ * \param stmt The statement to be processed.
+ * \param replace The replacement rule.
+ */
+Stmt ReplaceProvideTensor(Stmt stmt,
+ const std::unordered_map<Tensor, Tensor>& replace);
+
+/*!
+ * \brief Apply the schedule manipulation on the function body.
+ * \param stage The schedule information to be applied.
+ * \param dom_map The extents of the iteration variables that may be used.
+ * \param stmt The statement to be processed.
+ */
+Stmt ApplySchedule(const Stage& stage,
+ const std::unordered_map<IterVar, Range>& dom_map, Stmt stmt);
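+// Note: ApplySchedule applies ApplyLoopShapes, ApplyLoopOrder, and
+// ApplyLoopAnnotations, in that order (see the implementation in hybrid_op.cc).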
+
+/*!
+ * \brief Apply loop splits and fuses in the schedule on the function body.
+ * \param stage The schedule information to be applied.
+ * \param dom_map The extents of the iteration variables that may be used.
+ * \param stmt The statement to be processed.
+ */
+Stmt ApplyLoopShapes(const Stage &stage,
+ const std::unordered_map<IterVar, Range>& dom_map, Stmt stmt);
+
+
+/*!
+ * \brief Apply loop annotation in the schedule on the function body.
+ * \param stage The schedule information to be applied.
+ * \param rebased The map specifying the rebase (i.e. rename) relationship of these variables.
+ * \param stmt The statement to be processed.
+ */
+Stmt ApplyLoopAnnotations(const Stage &stage,
+ const std::unordered_map<IterVar, IterVar>& rebased, Stmt stmt);
+
+/*!
+ * \brief Apply loop order in the schedule on the function body.
+ * \param stage The schedule information to be applied.
+ * \param dom_map The extents of the iteration variables that may be used.
+ * \param rebased The map specifying the rebase (i.e. rename) relationship of these variables.
+ * \param stmt The statement to be processed.
+ */
+Stmt ApplyLoopOrder(const Stage &stage,
+ const std::unordered_map<IterVar, Range> &dom_map,
+ const std::unordered_map<IterVar, IterVar> &rebased, Stmt stmt);
+
+} // namespace top
+} // namespace tvm
+
+#endif // TVM_TOP_OPERATION_HYBRID_OP_H_
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \brief Utility to make loop nest.
+ * \file op_util.cc
+ */
+#include <tvm/ir.h>
+#include <tvm/ir_pass.h>
+#include <tvm/ir_functor_ext.h>
+#include <tvm/top/operation.h>
+#include <string>
+#include "op_util.h"
+#include "../schedule/message_passing.h"
+#include "../../arith/compute_expr.h"
+
+namespace tvm {
+namespace top {
+
+using namespace arith;
+using namespace ir;
+
+std::vector<std::vector<Stmt> >
+MakeLoopNest(const Stage& stage,
+ const std::unordered_map<IterVar, Range>& dom_map,
+ size_t begin_iter_pos,
+ bool new_loop_var,
+ const std::unordered_set<IterVar>& skip_iter,
+ std::unordered_map<IterVar, PrimExpr>* p_value_map,
+ bool debug_keep_trivial_loop) {
+ auto leaf_iter_vars = stage->leaf_iter_vars;
+ Stmt no_op = EvaluateNode::make(0);
+ // create the loop nest
+ std::vector<std::vector<Stmt> > nest;
+ nest.resize(leaf_iter_vars.size() + 1);
+ std::unordered_map<IterVar, PrimExpr>& value_map = *p_value_map;
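+  // nest[i + 1] collects the For/Let/AttrStmt statements generated for
+  // leaf_iter_vars[i]; callers typically flatten the levels with
+  // ir::MergeNest (see e.g. ScanOpNode::BuildProvide).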
+
+ for (size_t i = begin_iter_pos; i < leaf_iter_vars.size(); ++i) {
+ auto iv = leaf_iter_vars[i];
+ if (skip_iter.count(iv) || iv->iter_type == kOpaque) {
+ // skip this iteration.
+ value_map[iv] = iv->var;
+ continue;
+ }
+    // iv may be bound to a thread; if so, generate code for the thread iter var instead.
+ IterVar bind_iv = iv;
+ if (stage->iter_var_attrs.count(iv)) {
+ IterVar bind_thread = stage->iter_var_attrs[iv]->bind_thread;
+ if (bind_thread.defined()) bind_iv = bind_thread;
+ }
+
+ Range dom = dom_map.at(iv);
+
+ // initialize the offset and loop_level
+ Var var = bind_iv->var;
+
+ // Mark the iter var in the IR, to remember the point
+ if (bind_iv->thread_tag.length() == 0) {
+ // Only generate new loop if we're not bound to a thread.
+ if (new_loop_var) {
+ var = Var(iv->var->name_hint + ".init", bind_iv->var.dtype());
+ }
+
+ ForType for_type = ForType::Serial;
+ IterVarAttr it_attr;
+ if (stage->iter_var_attrs.count(iv)) {
+ it_attr = stage->iter_var_attrs[iv];
+ }
+ if (it_attr.defined()) {
+ switch (it_attr->iter_type) {
+ case kUnrolled: for_type = ForType::Unrolled; break;
+ case kVectorized: for_type = ForType::Vectorized; break;
+ case kParallelized: for_type = ForType::Parallel; break;
+ case kDataPar: break;
+ case kTensorized: break;
+ default: LOG(FATAL) << "Unknown iter type"
+ << it_attr->iter_type
+ << " in the iter_var_attrs";
+ }
+ CHECK_EQ(it_attr->pragma_keys.size(), it_attr->pragma_values.size());
+ for (size_t k = 0; k < it_attr->pragma_keys.size(); ++k) {
+ const std::string& pkey = it_attr->pragma_keys[k].as<StringImmNode>()->value;
+ PrimExpr pvalue = it_attr->pragma_values[k];
+ if (!pvalue.defined()) {
+ pvalue = make_const(DataType::Int(32), 1);
+ }
+ nest[i + 1].emplace_back(
+ AttrStmtNode::make(iv, ir::attr::pragma_scope_prefix + pkey, pvalue, no_op));
+ }
+ }
+ if (!debug_keep_trivial_loop && is_one(dom->extent)) {
+ nest[i + 1].emplace_back(
+ LetStmtNode::make(var, dom->min, no_op));
+ value_map[iv] = dom->min;
+ } else if (is_zero(dom->min)) {
+ nest[i + 1].emplace_back(
+ ForNode::make(var, 0, dom->extent,
+ for_type, DeviceAPI::None, no_op));
+ value_map[iv] = var;
+ } else {
+ Var idx(bind_iv->var->name_hint + ".idx", bind_iv->var.dtype());
+ nest[i + 1].emplace_back(
+ ForNode::make(idx, 0, dom->extent,
+ for_type, DeviceAPI::None, no_op));
+ PrimExpr new_value = dom->min + idx;
+ value_map[iv] = new_value;
+ nest[i + 1].emplace_back(
+ LetStmtNode::make(var, new_value, no_op));
+ }
+ if (it_attr.defined() && it_attr->prefetch_data.size() != 0) {
+ CHECK(!is_one(dom->extent))
+ << "Cannot prefetch on trivial loop with extent=1";
+ CHECK_EQ(it_attr->prefetch_data.size(),
+ it_attr->prefetch_offset.size());
+ for (size_t j = 0; j < it_attr->prefetch_data.size(); ++j) {
+ nest[i + 1].emplace_back(
+ AttrStmtNode::make(it_attr->prefetch_data[j],
+ ir::attr::prefetch_scope,
+ it_attr->prefetch_offset[j], no_op));
+ }
+ }
+ } else if (bind_iv->thread_tag == "vthread" ||
+ bind_iv->thread_tag == "cthread") {
+ // virtual thread
+      // Always restrict threaded IterVar to start from 0.
+ CHECK(is_zero(dom->min));
+ CHECK(is_positive_const(dom->extent));
+ // annotate the extent of the IterVar
+ nest[i + 1].emplace_back(
+ AttrStmtNode::make(bind_iv, ir::attr::virtual_thread, dom->extent, no_op));
+ value_map[iv] = var;
+ } else if (bind_iv->thread_tag == "pipeline") {
+ // pipeline marker.
+ CHECK(is_zero(dom->min));
+ CHECK(is_one(dom->extent));
+ // annotate the extent of the IterVar
+ nest[i + 1].emplace_back(
+ AttrStmtNode::make(bind_iv, ir::attr::pipeline_exec_scope, dom->extent, no_op));
+ value_map[iv] = dom->min;
+ } else {
+      // Always restrict threaded IterVar to start from 0.
+ CHECK(is_zero(dom->min));
+ // annotate the extent of the IterVar
+ nest[i + 1].emplace_back(
+ AttrStmtNode::make(bind_iv, ir::attr::thread_extent, dom->extent, no_op));
+ if (!debug_keep_trivial_loop && is_one(dom->extent)) {
+ value_map[iv] = dom->min;
+ } else {
+ value_map[iv] = var;
+ }
+ }
+ // annotate the extent of the IterVar
+ if (!new_loop_var) {
+ nest[i + 1].emplace_back(
+ AttrStmtNode::make(iv, attr::loop_scope, iv->var, no_op));
+ }
+ }
+ // message passing to get offset of root iter vars.
+ top::PassUpIndex(stage, dom_map, &value_map);
+ return nest;
+}
+
+std::vector<Stmt> MakeIfNest(const std::vector<PrimExpr>& predicates) {
+ Stmt no_op = EvaluateNode::make(0);
+ std::vector<Stmt> nest;
+ for (const PrimExpr& cond : predicates) {
+ nest.emplace_back(IfThenElseNode::make(cond, no_op));
+ }
+ return nest;
+}
+
+// Replacer that rewrites tensors referenced in Halide Call expressions
+class TensorReplacer : public ir::StmtExprMutator {
+ public:
+ explicit TensorReplacer(const std::unordered_map<Tensor, Tensor>& vmap)
+ : vmap_(vmap) {}
+
+ PrimExpr VisitExpr_(const ir::CallNode* op) final {
+ if (op->call_type == ir::CallNode::Halide) {
+ Tensor t = Downcast<Operation>(op->func).output(op->value_index);
+ auto it = vmap_.find(t);
+ if (it != vmap_.end()) {
+ PrimExpr ret = ir::CallNode::make(
+ op->dtype, it->second->op->name, op->args,
+ op->call_type, it->second->op, it->second->value_index);
+ found = true;
+ return this->VisitExpr(ret);
+ }
+ }
+ return StmtExprMutator::VisitExpr_(op);
+ }
+
+ // whether it is found.
+ bool found{false};
+
+ private:
+ const std::unordered_map<Tensor, Tensor>& vmap_;
+};
+
+Stmt ReplaceTensor(Stmt stmt,
+ const std::unordered_map<Tensor, Tensor>& replace) {
+ TensorReplacer repl(replace);
+ Stmt ret = repl(stmt);
+ return repl.found ? ret : stmt;
+}
+PrimExpr ReplaceTensor(PrimExpr expr,
+ const std::unordered_map<Tensor, Tensor>& replace) {
+ TensorReplacer repl(replace);
+ PrimExpr ret = repl(expr);
+ return repl.found ? ret : expr;
+}
+
+
+Stmt Substitute(Stmt s,
+ const std::unordered_map<IterVar, PrimExpr>& value_map) {
+ std::unordered_map<const VarNode*, PrimExpr> init;
+ for (const auto& kv : value_map) {
+ init[kv.first->var.get()] = kv.second;
+ }
+ return ir::Substitute(s, init);
+}
+
+IterVarType ForTypeToIterVarType(ir::ForType for_type) {
+ switch (for_type) {
+ case ForType::Serial:
+ return kDataPar;
+ case ForType::Parallel:
+ return kParallelized;
+ case ForType::Vectorized:
+ return kVectorized;
+ case ForType::Unrolled:
+ return kUnrolled;
+ default:
+ return kDataPar;
+ }
+}
+
+ir::ForType IterVarTypeToForType(IterVarType iter_type) {
+ switch (iter_type) {
+ case kDataPar:
+ return ForType::Serial;
+ case kParallelized:
+ return ForType::Parallel;
+ case kVectorized:
+ return ForType::Vectorized;
+ case kUnrolled:
+ return ForType::Unrolled;
+ default:
+ return ForType::Serial;
+ }
+}
+
+} // namespace top
+} // namespace tvm
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file op_util.h
+ * \brief Common utility used in operator construction.
+ */
+#ifndef TVM_TOP_OPERATION_OP_UTIL_H_
+#define TVM_TOP_OPERATION_OP_UTIL_H_
+
+#include <tvm/expr.h>
+#include <tvm/top/schedule.h>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+#include "../../pass/ir_util.h"
+#include "../../pass/arg_binder.h"
+#include "../schedule/message_passing.h"
+
+namespace tvm {
+namespace top {
+
+using ir::MergeNest;
+
+/*!
+ * \brief Build loop nest for stage.
+ *
+ * \param stage The stage to create a loop nest.
+ * \param dom_map The range of each iter var.
+ * \param begin_iter_pos The beginning position of leaf_iter_vars to generate loop.
+ * \param new_loop_var Whether to create a new loop variable.
+ * \param skip_iter The iteration variables to be skipped.
+ * \param p_value_map The resulting value of each IterVar.
+ * \param debug_keep_trivial_loop Whether to keep trivial loops with extent 1.
+ */
+std::vector<std::vector<Stmt> >
+MakeLoopNest(const Stage& stage,
+ const std::unordered_map<IterVar, Range>& dom_map,
+ size_t begin_iter_pos,
+ bool new_loop_var,
+ const std::unordered_set<IterVar>& skip_iter,
+ std::unordered_map<IterVar, PrimExpr>* p_value_map,
+ bool debug_keep_trivial_loop);
+
+/*!
+ * \brief Create a nest of if checking the predicates.
+ *
+ * \param predicates The predicates to be checked.
+ * \return List of If nest that checks the predicates.
+ */
+std::vector<Stmt> MakeIfNest(const std::vector<PrimExpr>& predicates);
+
+/*!
+ * \brief Replace the tensor references (especially in Call expressions) in stmt according to the replace map.
+ * \param stmt The statement to be processed.
+ * \param replace The replacement rule.
+ */
+Stmt ReplaceTensor(Stmt stmt,
+ const std::unordered_map<Tensor, Tensor>& replace);
+/*!
+ * \brief Replace the tensor references (especially in Call expressions) in expr according to the replace map.
+ * \param expr The expression to be processed.
+ * \param replace The replacement rule.
+ */
+PrimExpr ReplaceTensor(PrimExpr expr,
+ const std::unordered_map<Tensor, Tensor>& replace);
+
+/*!
+ * \brief Substitute the variables of stmt by value map.
+ * \param stmt The statement.
+ * \param value_map The value map.
+ * \return Substituted result.
+ */
+Stmt Substitute(Stmt stmt,
+ const std::unordered_map<IterVar, PrimExpr>& value_map);
+
+/*!
+ * \brief Converts Halide ForType to its corresponding IterVarType
+ * \param for_type The ForType to be converted
+ */
+IterVarType ForTypeToIterVarType(ir::ForType for_type);
+
+/*!
+ * \brief Converts IterVarType to its corresponding Halide ForType
+ * \param iter_type The IterVarType to be converted
+ */
+ir::ForType IterVarTypeToForType(IterVarType iter_type);
+
+} // namespace top
+} // namespace tvm
+#endif // TVM_TOP_OPERATION_OP_UTIL_H_
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \brief Placeholder op.
+ * \file placeholder_op.cc
+ */
+#include <tvm/top/operation.h>
+
+namespace tvm {
+namespace top {
+
+// PlaceholderOpNode
+TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
+.set_dispatch<PlaceholderOpNode>([](const ObjectRef& node, NodePrinter* p) {
+ auto* op = static_cast<const PlaceholderOpNode*>(node.get());
+ p->stream << "placeholder(" << op->name << ", " << op << ")";
+});
+
+TVM_REGISTER_NODE_TYPE(PlaceholderOpNode);
+
+int PlaceholderOpNode::num_outputs() const {
+ return 1;
+}
+
+Array<IterVar> PlaceholderOpNode::root_iter_vars() const {
+ return {};
+}
+
+DataType PlaceholderOpNode::output_dtype(size_t i) const {
+ CHECK_EQ(i, 0U);
+ return dtype;
+}
+
+Array<PrimExpr> PlaceholderOpNode::output_shape(size_t i) const {
+ CHECK_EQ(i, 0U);
+ return shape;
+}
+
+Operation PlaceholderOpNode::make(std::string name,
+ Array<PrimExpr> shape,
+ DataType dtype) {
+ auto n = make_object<PlaceholderOpNode>();
+ n->name = name;
+ n->shape = shape;
+ n->dtype = dtype;
+ return Operation(n);
+}
+
+Tensor placeholder(Array<PrimExpr> shape, DataType dtype, std::string name) {
+ return PlaceholderOpNode::make(name, shape, dtype).output(0);
+}
+
+Array<Tensor> PlaceholderOpNode::InputTensors() const {
+ return {};
+}
+
+Operation PlaceholderOpNode::ReplaceInputs(
+ const Operation& self,
+ const std::unordered_map<Tensor, Tensor>& rmap) const {
+ return self;
+}
+
+void PlaceholderOpNode::PropBoundToInputs(
+ const Operation& self,
+ arith::Analyzer* analyzer,
+ const std::unordered_map<const VarNode*, IntSet>& dom_map,
+ std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
+}
+
+void PlaceholderOpNode::GatherBound(
+ const Operation& self,
+ const std::unordered_map<Tensor, TensorDom>& tensor_dom,
+ std::unordered_map<IterVar, Range>* out_dom_map) const {
+}
+
+Stmt PlaceholderOpNode::BuildRealize(
+ const Stage& stage,
+ const std::unordered_map<IterVar, Range>& realize_map,
+ const Stmt& body) const {
+ return body;
+}
+
+Stmt PlaceholderOpNode::BuildProvide(
+ const Stage& stage,
+ const std::unordered_map<IterVar, Range>& dom_map,
+ bool debug_keep_trivial_loop) const {
+ return Stmt();
+}
+} // namespace top
+} // namespace tvm
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \brief Scan Operator.
+ * \file scan_op.cc
+ */
+#include <tvm/top/operation.h>
+#include <tvm/ir.h>
+#include <tvm/ir_pass.h>
+#include "op_util.h"
+#include "../schedule/graph.h"
+
+namespace tvm {
+namespace top {
+using namespace ir;
+
+TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
+.set_dispatch<ScanOpNode>([](const ObjectRef& node, NodePrinter* p) {
+ auto* op = static_cast<const ScanOpNode*>(node.get());
+ p->stream << "scan(" << op->name << ", " << op << ")";
+});
+TVM_REGISTER_NODE_TYPE(ScanOpNode);
+
+inline bool prove_equal(PrimExpr lhs, PrimExpr rhs) {
+ return is_zero(ir::Simplify(lhs - rhs));
+}
+
+int ScanOpNode::num_outputs() const {
+ return static_cast<int>(update.size());
+}
+Array<IterVar> ScanOpNode::root_iter_vars() const {
+ Array<IterVar> ret{scan_axis};
+ for (IterVar iv : spatial_axis_) {
+ ret.push_back(iv);
+ }
+ return ret;
+}
+
+DataType ScanOpNode::output_dtype(size_t i) const {
+ return update[i]->dtype;
+}
+
+Array<PrimExpr> ScanOpNode::output_shape(size_t i) const {
+ CHECK_LT(i, state_placeholder.size());
+ return state_placeholder[i]->shape;
+}
+
+Operation ScanOpNode::make(std::string name,
+ std::string tag,
+ Map<std::string, ObjectRef> attrs,
+ IterVar axis,
+ Array<Tensor> init,
+ Array<Tensor> update,
+ Array<Tensor> state_placeholder,
+ Array<Tensor> inputs) {
+ if (!attrs.defined()) {
+ attrs = Map<std::string, ObjectRef>();
+ }
+ auto n = make_object<ScanOpNode>();
+ CHECK_EQ(init.size(), update.size());
+ CHECK_EQ(init.size(), state_placeholder.size());
+
+ for (size_t i = 0; i < init.size(); ++i) {
+ CHECK_EQ(init[i]->dtype, state_placeholder[i]->dtype);
+ CHECK_EQ(init[i]->dtype, update[i]->dtype);
+ CHECK(prove_equal(init[i]->shape[0], axis->dom->min))
+        << "init.shape[0] needs to match scan_axis.dom.min";
+ CHECK(prove_equal(
+ state_placeholder[i]->shape[0], axis->dom->min + axis->dom->extent))
+        << "state_placeholder.shape[0] needs to match"
+ << " scan_axis.dom.min + scan_axis.dom.extent";
+ CHECK_EQ(state_placeholder[i].ndim(), init[i].ndim())
+        << "The dimension of init needs to match state_placeholder";
+ CHECK_EQ(update[i].ndim(), state_placeholder[i].ndim())
+        << "update.ndim needs to match state_placeholder.ndim";
+ for (size_t k = 0; k < update[i].ndim(); ++k) {
+ CHECK(prove_equal(
+ update[i]->shape[k], state_placeholder[i]->shape[k]));
+ if (k != 0) {
+ // setup spatial axis
+ std::ostringstream spatial_name;
+ spatial_name << name << ".out" << i << ".i" << k;
+ n->spatial_axis_.push_back(
+ IterVarNode::make(
+ Range::make_by_min_extent(0, update[i]->shape[k]),
+ Var(spatial_name.str()), kOpaque));
+ }
+ }
+
+ for (size_t k = 1; k < init[i].ndim(); ++k) {
+ CHECK(prove_equal(
+ init[i]->shape[k], state_placeholder[i]->shape[k]));
+ }
+ }
+ n->name = std::move(name);
+ n->tag = std::move(tag);
+ n->attrs = std::move(attrs);
+ n->scan_axis = std::move(axis);
+ n->init = std::move(init);
+ n->update = std::move(update);
+ n->state_placeholder = std::move(state_placeholder);
+ n->inputs = std::move(inputs);
+ return Operation(n);
+}
+
+Array<Tensor> scan(Array<Tensor> init,
+ Array<Tensor> update,
+ Array<Tensor> state_placeholder,
+ Array<Tensor> inputs,
+ std::string name,
+ std::string tag,
+ Map<std::string, ObjectRef> attrs) {
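+  // The scan axis covers the time steps produced by update but not by init:
+  // it starts at init[0]->shape[0] and has extent
+  // update[0]->shape[0] - init[0]->shape[0].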
+ IterVar scan_axis =
+ IterVarNode::make(
+ Range::make_by_min_extent(
+ init[0]->shape[0], update[0]->shape[0] - init[0]->shape[0]),
+ Var(name + ".idx"), kOrdered);
+ Operation op = ScanOpNode::make(
+ name, tag, attrs, scan_axis,
+ init, update, state_placeholder, inputs);
+ Array<Tensor> res;
+ for (int i = 0; i < op->num_outputs(); ++i) {
+ res.push_back(op.output(i));
+ }
+ return res;
+}
+
+Array<Tensor> ScanOpNode::InputTensors() const {
+ Array<Tensor> ret;
+ for (Tensor t : init) {
+ ret.push_back(t);
+ }
+ for (Tensor t : update) {
+ ret.push_back(t);
+ }
+ return ret;
+}
+
+Operation ScanOpNode::ReplaceInputs(
+ const Operation& self,
+ const std::unordered_map<Tensor, Tensor>& rmap) const {
+ CHECK_EQ(self.operator->(), this);
+ auto n = make_object<ScanOpNode>(*this);
+ for (size_t i = 0; i < n->init.size(); ++i) {
+ if (rmap.count(n->init[i])) {
+ n->init.Set(i, rmap.at(n->init[i]));
+ }
+ if (rmap.count(n->update[i])) {
+ n->update.Set(i, rmap.at(n->update[i]));
+ }
+ }
+ if (!n->init.same_as(init) ||
+ !n->update.same_as(update)) {
+ return Operation(n);
+ } else {
+ return self;
+ }
+}
+
+void ScanOpNode::PropBoundToInputs(
+ const Operation& self,
+ arith::Analyzer* analyzer,
+ const std::unordered_map<const VarNode*, IntSet>& dom_map,
+ std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
+ CHECK_EQ(self.operator->(), this);
+ for (size_t i = 0, sp_idx = 0; i < this->init.size(); ++i) {
+ TensorDom* init_dom = nullptr;
+ TensorDom* update_dom = nullptr;
+ if (out_dom_map->count(this->init[i])) {
+ init_dom = &out_dom_map->at(this->init[i]);
+ }
+ if (out_dom_map->count(this->update[i])) {
+ update_dom = &out_dom_map->at(this->update[i]);
+ }
+ // first dimension, always needed.
+ if (init_dom) {
+ init_dom->data[0].push_back(IntSet::range(
+ Range::make_by_min_extent(0, this->init[i]->shape[0])));
+ }
+ if (update_dom) {
+ update_dom->data[0].push_back(dom_map.at(this->scan_axis->var.get()));
+ }
+ // The update dimensions
+ for (size_t k = 1; k < this->update[i]->shape.size(); ++k, ++sp_idx) {
+ IterVar sp_ax = this->spatial_axis_[sp_idx];
+ if (init_dom) {
+ init_dom->data[k].push_back(dom_map.at(sp_ax->var.get()));
+ }
+ if (update_dom) {
+ update_dom->data[k].push_back(dom_map.at(sp_ax->var.get()));
+ }
+ }
+ }
+}
+
+void ScanOpNode::GatherBound(
+ const Operation& self,
+ const std::unordered_map<Tensor, TensorDom>& tensor_dom,
+ std::unordered_map<IterVar, Range>* out_dom_map) const {
+ CHECK_EQ(self.operator->(), this);
+ CHECK(!out_dom_map->count(this->scan_axis));
+ std::vector<Tensor> output(this->num_outputs());
+ for (size_t i = 0; i < output.size(); ++i) {
+ output[i] = self.output(i);
+ }
+ // Update for time axis.
+ std::vector<IntSet> time_dom;
+ for (size_t i = 0; i < output.size(); ++i) {
+ const TensorDom& d = tensor_dom.at(output[i]);
+ time_dom.insert(time_dom.end(), d.data[0].begin(), d.data[0].end());
+ }
+ CHECK(!out_dom_map->count(this->scan_axis));
+ Range sdom = this->scan_axis->dom;
+ Range r = arith::Union(time_dom).cover_range(sdom);
+ (*out_dom_map)[this->scan_axis] = Range::make_by_min_extent(
+ sdom->min, ir::Simplify(r->extent + r->min - sdom->min));
+ Map<IterVar, PrimExpr> fix_pt = ScanFixPointAnalysis(self);
+ // Update for spatial axis.
+ size_t sp_idx = 0;
+ for (size_t i = 0; i < output.size(); ++i) {
+ const TensorDom& d = tensor_dom.at(output[i]);
+ for (size_t k = 1; k < this->update[i]->shape.size(); ++k, ++sp_idx) {
+ IterVar sp_ax = this->spatial_axis_[sp_idx];
+ CHECK(!out_dom_map->count(sp_ax));
+ CHECK(fix_pt.count(sp_ax));
+ if (fix_pt[sp_ax].as<ir::IntImmNode>()->value) {
+ // fix point, we can slice it.
+ (*out_dom_map)[sp_ax] = arith::Union(d.data[k]).cover_range(sp_ax->dom);
+ } else {
+ // not a fix point, need to include everything.
+ (*out_dom_map)[sp_ax] = sp_ax->dom;
+ }
+ }
+ }
+}
+
+Stmt ScanOpNode::BuildRealize(
+ const Stage& stage,
+ const std::unordered_map<IterVar, Range>& dom_map,
+ const Stmt& body) const {
+ CHECK_EQ(stage->op.get(), this);
+ Range sdom = dom_map.at(this->scan_axis);
+ Range tdom = Range::make_by_min_extent(
+ 0, ir::Simplify(sdom->extent + sdom->min));
+ Stmt ret = body;
+ size_t sp_idx = 0;
+ for (size_t i = 0; i < update.size(); ++i) {
+ Tensor t = stage->op.output(i);
+ CHECK_EQ(static_cast<size_t>(t->value_index), i);
+ Region bounds;
+ bounds.push_back(tdom);
+ for (size_t k = 1; k < this->update[i]->shape.size(); ++k, ++sp_idx) {
+ IterVar sp_ax = this->spatial_axis_[sp_idx];
+ bounds.push_back(dom_map.at(sp_ax));
+ }
+ ret = ir::RealizeNode::make(t->op, t->value_index, t->dtype,
+ bounds, const_true(), ret);
+ }
+ return ret;
+}
+
+Stmt ScanOpNode::BuildProvide(
+ const Stage& stage,
+ const std::unordered_map<IterVar, Range>& dom_map,
+ bool debug_keep_trivial_loop) const {
+ CHECK_EQ(stage->op.operator->(), this);
+ Stmt provide = AttrStmtNode::make(
+ stage->op, attr::scan_update_scope, this->scan_axis->var,
+ EvaluateNode::make(0));
+ Stmt init = AttrStmtNode::make(
+ stage->op, attr::scan_init_scope, 0,
+ EvaluateNode::make(0));
+ size_t begin_scan = 0;
+ for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) {
+ if (stage->leaf_iter_vars[i]->iter_type == kThreadIndex) {
+ CHECK_EQ(begin_scan, i);
+ begin_scan = i + 1;
+ }
+ }
+ std::unordered_map<IterVar, PrimExpr> vmap;
+ std::unordered_set<IterVar> empty;
+ auto nest = MakeLoopNest(
+ stage, dom_map, 0, false, empty, &vmap, debug_keep_trivial_loop);
+ nest[begin_scan].push_back(init);
+ nest.push_back(
+ MakeIfNest(
+ MakeBoundCheck(stage, dom_map, vmap, false, empty)));
+ return MergeNest(nest, provide);
+}
+} // namespace top
+} // namespace tvm
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \brief Tensor Compute Op.
+ * \file tensor_compute_op.cc
+ */
+#include <tvm/top/operation.h>
+#include <tvm/arith/analyzer.h>
+#include <tvm/ir.h>
+#include <tvm/ir_pass.h>
+#include <unordered_set>
+#include "./op_util.h"
+#include "./compute_op.h"
+#include "../../arith/compute_expr.h"
+
+namespace tvm {
+namespace top {
+using namespace ir;
+// TensorComputeOpNode
+TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
+.set_dispatch<TensorComputeOpNode>([](const ObjectRef& node, NodePrinter* p) {
+ auto* op = static_cast<const TensorComputeOpNode*>(node.get());
+ p->stream << "tensor_compute_op(" << op->name << ", " << op << ")";
+ });
+
+TVM_REGISTER_NODE_TYPE(TensorComputeOpNode);
+
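+// intrin->buffers lists the input buffers first, followed by the output
+// buffers, so the number of outputs is the total buffer count minus the
+// number of inputs (see also output_dtype and BuildProvide below).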
+int TensorComputeOpNode::num_outputs() const {
+ return static_cast<int>(this->intrin->buffers.size() - this->inputs.size());
+}
+
+DataType TensorComputeOpNode::output_dtype(size_t i) const {
+ return this->intrin->buffers[this->inputs.size() + i]->dtype;
+}
+
+Operation TensorComputeOpNode::make(std::string name,
+ std::string tag,
+ Array<IterVar> axis,
+ Array<IterVar> reduce_axis,
+ int schedulable_ndim,
+ TensorIntrin intrin,
+ Array<Tensor> tensors,
+ Array<Region> regions,
+ Array<PrimExpr> scalar_inputs) {
+ auto n = make_object<TensorComputeOpNode>();
+ n->name = std::move(name);
+ n->tag = std::move(tag);
+ n->axis = std::move(axis);
+ n->reduce_axis = std::move(reduce_axis);
+ n->schedulable_ndim = std::move(schedulable_ndim);
+ n->intrin = std::move(intrin);
+ n->inputs = std::move(tensors);
+ n->input_regions = std::move(regions);
+ n->scalar_inputs = std::move(scalar_inputs);
+ return Operation(n);
+}
+
+Array<Tensor> TensorComputeOpNode::InputTensors() const {
+ return inputs;
+}
+
+Operation TensorComputeOpNode::ReplaceInputs(
+ const Operation& self,
+ const std::unordered_map<Tensor, Tensor>& rmap) const {
+ CHECK_EQ(self.operator->(), this);
+ auto n = make_object<TensorComputeOpNode>(*this);
+ auto intrin = make_object<TensorIntrinNode>(*(this->intrin.operator->()));
+ intrin->body = ReplaceTensor(this->intrin->body, rmap);
+ if (intrin->reduce_init.defined()) {
+ intrin->reduce_init = ReplaceTensor(this->intrin->reduce_init, rmap);
+ }
+ if (intrin->reduce_update.defined()) {
+ intrin->reduce_update = ReplaceTensor(this->intrin->reduce_update, rmap);
+ }
+ for (size_t i = 0; i < n->inputs.size(); ++i) {
+ Tensor t = n->inputs[i];
+ if (rmap.count(t)) {
+ n->inputs.Set(i, rmap.at(t));
+ }
+ }
+
+ if (intrin->body.same_as(n->intrin->body) &&
+ intrin->reduce_init.same_as(n->intrin->reduce_init) &&
+ intrin->reduce_update.same_as(n->intrin->reduce_update) &&
+ inputs.same_as(n->inputs)) {
+ return self;
+ } else {
+ n->intrin = TensorIntrin(intrin);
+ return Operation(n);
+ }
+}
+
+void TensorComputeOpNode::PropBoundToInputs(
+ const Operation& self,
+ arith::Analyzer* analyzer,
+ const std::unordered_map<const VarNode*, IntSet>& dom_map,
+ std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
+ for (size_t i = 0; i < this->inputs.size(); ++i) {
+ Tensor t = this->inputs[i];
+ Region region = input_regions[i];
+
+ auto it = out_dom_map->find(t);
+ if (it == out_dom_map->end()) continue;
+ TensorDom& dom = it->second;
+ for (size_t j = 0; j < t.ndim(); ++j) {
+ dom.data[j].emplace_back(EvalSet(region[j], dom_map));
+ }
+ }
+}
+
+size_t TensorComputeOpNode::num_schedulable_dims() const {
+ return schedulable_ndim;
+}
+
+Stmt TensorComputeOpNode::BuildProvide(
+ const Stage& stage,
+ const std::unordered_map<IterVar, Range>& dom_map,
+ bool debug_keep_trivial_loop) const {
+ CHECK_EQ(stage->op.operator->(), this);
+
+ // Start bind data.
+ Stmt nop = EvaluateNode::make(0);
+ std::vector<Stmt> input_bind_nest, output_bind_nest;
+ Array<Tensor> inputs = this->InputTensors();
+
+ // input binding
+ size_t num_inputs = inputs.size();
+ for (size_t i = 0; i < num_inputs; ++i) {
+ Tensor tensor = inputs[i];
+ Region region = this->input_regions[i];
+ Buffer buffer = this->intrin->buffers[i];
+ Array<ObjectRef> bind_spec{buffer, tensor};
+
+ Array<PrimExpr> tuple;
+ for (size_t i = 0; i < region.size(); ++i) {
+ tuple.push_back(region[i]->min);
+ tuple.push_back(region[i]->extent);
+ }
+ input_bind_nest.emplace_back(AttrStmtNode::make(
+ bind_spec, ir::attr::buffer_bind_scope,
+ CallNode::make(DataType::Handle(),
+ ir::intrinsic::tvm_tuple,
+ tuple, CallNode::Intrinsic), nop));
+ }
+
+ // output binding
+ for (int i = 0; i < this->num_outputs(); ++i) {
+ Tensor tensor = stage->op.output(i);
+ Buffer buffer = this->intrin->buffers[num_inputs + i];
+ Array<ObjectRef> bind_spec{buffer, tensor};
+
+ Array<PrimExpr> tuple;
+ for (size_t i = 0; i < this->axis.size(); ++i) {
+ auto ivar = this->axis[i];
+ if (i < static_cast<size_t>(this->schedulable_ndim)) {
+ tuple.push_back(ivar->var);
+ tuple.push_back(1);
+ } else {
+ Range dom = ivar->dom;
+ tuple.push_back(dom->min);
+ tuple.push_back(dom->extent);
+ }
+ }
+
+ output_bind_nest.emplace_back(AttrStmtNode::make(
+ bind_spec, ir::attr::buffer_bind_scope,
+ CallNode::make(DataType::Handle(),
+ ir::intrinsic::tvm_tuple,
+ tuple, CallNode::Intrinsic), nop));
+ }
+
+ // Check variable remap
+ std::unordered_map<const VarNode*, PrimExpr> vmap;
+ ir::ArgBinder binder(&vmap);
+
+ // Map the expressions passed in the call to the TensorIntrin, to the placeholder
+ // variables
+ Array<PrimExpr> user_expr = this->scalar_inputs;
+ Array<Var> scalar_params = this->intrin->scalar_params;
+ Array<PrimExpr> sp_expr;
+ for (auto sp : scalar_params) {
+ PrimExpr esp = sp;
+ sp_expr.push_back(esp);
+ }
+ CHECK_EQ(sp_expr.size(), user_expr.size());
+ // TODO(jdavies-huawei): what name should be used here?
+ binder.BindArray(sp_expr, user_expr, this->name);
+
+ size_t tloc = stage->leaf_iter_vars.size();
+ ComputeLoopNest n = ComputeLoopNest::make(this, stage, dom_map, debug_keep_trivial_loop);
+
+ if (this->reduce_axis.size() == 0) {
+ std::vector<std::vector<Stmt> > nest(
+ n.main_nest.begin(), n.main_nest.begin() + tloc + 1);
+ nest.emplace_back(MakeIfNest(n.main_predicates));
+ CHECK_EQ(n.init_predicates.size(), 0U);
+ CHECK(this->intrin->body.defined())
+ << "Normal store op for intrin " << this << " is not defined";
+ Stmt body = MergeNest(output_bind_nest, this->intrin->body);
+ body = MergeNest(input_bind_nest, body);
+ body = ir::Substitute(body, vmap);
+ body = MergeNest(binder.asserts(), body);
+ body = top::Substitute(body, n.main_vmap);
+ Stmt ret = MergeNest(nest, body);
+ return ret;
+ } else {
+ // Need to split reduction
+ CHECK(this->intrin->reduce_update.defined())
+ << "Reduction update op is not defined";
+ // Need init and update steps
+ CHECK_NE(this->reduce_axis.size(), 0U);
+ std::vector<std::vector<Stmt> > common(
+ n.main_nest.begin(), n.main_nest.begin() + n.num_common_loop + 1);
+ std::vector<std::vector<Stmt> > update_nest(
+ n.main_nest.begin() + n.num_common_loop + 1, n.main_nest.begin() + tloc + 1);
+ update_nest.emplace_back(MakeIfNest(n.main_predicates));
+
+ if (this->intrin->reduce_init.defined()) {
+ // init nest
+ std::vector<std::vector<Stmt> > init_nest(
+ n.init_nest.begin(), n.init_nest.begin() + tloc + 1);
+ init_nest.emplace_back(MakeIfNest(n.init_predicates));
+ Stmt init = MergeNest(output_bind_nest, this->intrin->reduce_init);
+ init = top::Substitute(init, n.init_vmap);
+ init = MergeNest(init_nest, init);
+ // The update
+ Stmt update = MergeNest(output_bind_nest, this->intrin->reduce_update);
+ update = MergeNest(input_bind_nest, update);
+ update = ir::Substitute(update, vmap);
+ update = MergeNest(binder.asserts(), update);
+ update = top::Substitute(update, n.main_vmap);
+ update = MergeNest(update_nest, update);
+ return MergeNest(common, SeqStmt::Flatten(init, update));
+ } else {
+ // When init op is not available, use body op for reset in the first iter.
+ CHECK(this->intrin->body.defined())
+ << "Normal body op is not defined";
+ Stmt update = TransformUpdate(stage, dom_map, n,
+ this->intrin->body,
+ this->intrin->reduce_update);
+ update = MergeNest(output_bind_nest, update);
+ update = MergeNest(input_bind_nest, update);
+ update = ir::Substitute(update, vmap);
+ update = MergeNest(binder.asserts(), update);
+ update = top::Substitute(update, n.main_vmap);
+ update = MergeNest(update_nest, update);
+ return MergeNest(common, update);
+ }
+ }
+}
+} // namespace top
+} // namespace tvm
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \brief Logics related to tensorize, used by ComputeOpNode.
+ * \file tensorize.cc
+ */
+#include <tvm/ir.h>
+#include <tvm/ir_functor_ext.h>
+#include <tvm/ir_pass.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/packed_func_ext.h>
+
+#include "op_util.h"
+#include "compute_op.h"
+#include "../schedule/message_passing.h"
+
+namespace tvm {
+namespace top {
+
+using namespace ir;
+
+// Detect the region of inputs and output to be tensorized.
+// out_dom: the domain of root iter vars in the output op.
+// in_region: the region of each input tensor.
+// Returns the location where the tensorized scope starts.
+size_t InferTensorizeRegion(
+ const ComputeOpNode* self,
+ const Stage& stage,
+ const std::unordered_map<IterVar, Range>& dom_map,
+ std::unordered_map<IterVar, Range>* out_dom,
+ std::unordered_map<Tensor, Array<Range> >* in_region) {
+ // Get the bound of the tensorized scope.
+ bool found_point = false;
+ size_t loc_scope = 0;
+ std::unordered_map<IterVar, IntSet> up_state;
+  // Loop over the leaves
+ for (size_t i = stage->leaf_iter_vars.size(); i != 0; --i) {
+ IterVar iv = stage->leaf_iter_vars[i - 1];
+ CHECK(iv->iter_type == kDataPar ||
+ iv->iter_type == kCommReduce);
+ auto vit = dom_map.find(iv);
+ CHECK(vit != dom_map.end());
+ const Range& vrange = vit->second;
+ if (is_one(vrange->extent)) {
+ up_state[iv] = IntSet::single_point(vrange->min);
+ } else if (found_point) {
+ CHECK(is_zero(vrange->min));
+ up_state[iv] = IntSet::single_point(iv->var);
+ } else {
+ up_state[iv] = IntSet::range(vrange);
+ }
+ auto iit = stage->iter_var_attrs.find(iv);
+ if (iit != stage->iter_var_attrs.end()) {
+ const IterVarAttr& attr = (*iit).second;
+ if (!found_point) {
+ CHECK(!attr->bind_thread.defined())
+ << "Do not allow thread in tensorize scope";
+ }
+ if (attr->iter_type == kTensorized) {
+        CHECK(!found_point) << "Do not allow two tensorized points";
+ found_point = true;
+ loc_scope = i - 1;
+ }
+ }
+ }
+ CHECK(found_point);
+ // Get domain of the tensorized scope.
+ top::PassUpDomain(stage, dom_map, &up_state);
+ // Get domains of inputs
+ std::unordered_map<Tensor, TensorDom> in_dom;
+ std::unordered_map<const VarNode*, IntSet> temp_dmap;
+ arith::Analyzer analyzer;
+ Array<Tensor> inputs = self->InputTensors();
+ for (Tensor t : inputs) {
+ in_dom.emplace(t, TensorDom(t.ndim()));
+ }
+ for (IterVar iv : self->root_iter_vars()) {
+ IntSet iset = up_state.at(iv);
+ Range iv_range = iset.cover_range(dom_map.at(iv));
+ (*out_dom)[iv] = iv_range;
+ analyzer.Bind(iv->var, iv_range);
+ temp_dmap[iv->var.get()] = iset;
+ }
+ // Input domains
+ self->PropBoundToInputs(stage->op, &analyzer, temp_dmap, &in_dom);
+ Range none;
+ for (const auto& kv : in_dom) {
+ Array<Range> vec;
+ const Tensor& t = kv.first;
+ for (size_t i = 0; i < t.ndim(); ++i) {
+ Range r = arith::Union(kv.second.data.at(i)).cover_range(none);
+ CHECK(r.defined()) << "cannot deduce region of tensorized scope for input " << t;
+ vec.push_back(std::move(r));
+ }
+ (*in_region)[t] = std::move(vec);
+ }
+ return loc_scope;
+}
+
+void VerifyTensorizeLoopNest(const ComputeOpNode* self,
+ const Stage& stage,
+ const ComputeLoopNest& n,
+ size_t tloc) {
+ // Verification step.
+ std::unordered_set<const VarNode*> banned;
+ CHECK_EQ(n.main_nest.size(), stage->leaf_iter_vars.size() + 1);
+ CHECK(n.init_nest.size() == stage->leaf_iter_vars.size() + 1 ||
+ n.init_nest.size() == 0);
+ auto f_push_banned = [&banned](const Stmt& s) {
+ if (const ForNode* op = s.as<ForNode>()) {
+ banned.insert(op->loop_var.get());
+ } else if (const AttrStmtNode* op = s.as<AttrStmtNode>()) {
+ if (const IterVarNode* iv = op->node.as<IterVarNode>()) {
+ banned.insert(iv->var.get());
+ }
+ } else if (const LetStmtNode* op = s.as<LetStmtNode>()) {
+ banned.insert(op->var.get());
+ }
+ };
+ for (size_t i = tloc; i < stage->leaf_iter_vars.size(); ++i) {
+ for (const Stmt& s : n.main_nest[i + 1]) {
+ f_push_banned(s);
+ }
+ if (n.init_nest.size() != 0) {
+ for (const Stmt& s : n.init_nest[i + 1]) {
+ f_push_banned(s);
+ }
+ }
+ }
+ for (const PrimExpr& pred : n.main_predicates) {
+ if (ir::ExprUseVar(pred, banned)) {
+ LOG(FATAL) << "Tensorize failed, split condition "
+ << pred << " relies on var defined inside tensorize scope";
+ }
+ }
+ for (const PrimExpr& pred : n.init_predicates) {
+ if (ir::ExprUseVar(pred, banned)) {
+ LOG(FATAL) << "Tensorize failed, split condition "
+ << pred << " relies on var defined inside tensorize scope";
+ }
+ }
+}
+
+// Remap the tensor placeholders and indices onto those of the tensor intrinsic.
+class TensorIntrinMatcher final : public StmtExprMutator {
+ public:
+ PrimExpr VisitExpr_(const CallNode* op) final {
+ PrimExpr expr = StmtExprMutator::VisitExpr_(op);
+ op = expr.as<CallNode>();
+ if (op->call_type == CallNode::Halide) {
+ Tensor t = Downcast<Operation>(op->func).output(op->value_index);
+ auto it = in_remap_.find(t);
+ if (it != in_remap_.end()) {
+ const InputEntry& e = it->second;
+ CHECK_EQ(op->args.size(), e.region.size());
+ Array<PrimExpr> args;
+ for (size_t i = e.start; i < e.region.size(); ++i) {
+ args.push_back(op->args[i] - e.region[i]->min);
+ }
+ return CallNode::make(
+ op->dtype, e.tensor->op->name, args,
+ op->call_type, e.tensor->op, e.tensor->value_index);
+ }
+ }
+ return expr;
+ }
+
+ PrimExpr VisitExpr_(const VarNode* op) final {
+ auto it = var_remap_.find(op);
+ if (it != var_remap_.end()) {
+ return it->second;
+ } else {
+ return GetRef<PrimExpr>(op);
+ }
+ }
+
+ PrimExpr VisitExpr_(const ReduceNode* op) final {
+ PrimExpr expr = StmtExprMutator::VisitExpr_(op);
+ op = expr.as<ReduceNode>();
+ Array<IterVar> axis;
+ for (size_t i = 0; i < op->axis.size(); ++i) {
+ auto it = axis_remap_.find(op->axis[i]);
+ if (it != axis_remap_.end()) {
+ axis.push_back(it->second);
+ }
+ }
+ return ReduceNode::make(
+ op->combiner, op->source, axis, op->condition, op->value_index);
+ }
+
+ void Init(const ComputeOpNode* self,
+ const Stage& stage,
+ const std::unordered_map<IterVar, Range>& dom_map,
+ const std::unordered_map<IterVar, Range>& out_dom,
+ const std::unordered_map<Tensor, Array<Range> >& in_region,
+ const TensorIntrin& intrin,
+ Map<Var, Range>* compute_intrin_iter_space) {
+ CHECK(self == stage->op.get());
+
+ for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) {
+ IterVar iv = stage->leaf_iter_vars[i];
+ auto vit = dom_map.find(iv);
+ if (vit != dom_map.end()) {
+ const Range vrange = vit->second;
+ compute_intrin_iter_space->Set(iv->var, vrange);
+ }
+ }
+
+ // input remap.
+ Array<Tensor> inputs = self->InputTensors();
+ CHECK_EQ(inputs.size(), intrin->inputs.size());
+ for (size_t i = 0; i < inputs.size(); ++i) {
+ InputEntry e;
+ e.tensor = intrin->inputs[i];
+ e.region = Array<Range>(in_region.at(inputs[i]));
+ CHECK_GE(e.region.size(), e.tensor.ndim());
+ // Enable fuzzy matching, to match [1, n, m] to [n, m]
+ e.start = e.region.size() - e.tensor.ndim();
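+ // Worked example: an input region of shape [1, 16, 16] matched against a
+ // 2-D intrinsic tensor gives start = 1; the loop below then checks that the
+ // skipped leading dimension really has extent 1.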
+ for (size_t j = 0; j < e.start; ++j) {
+ auto canonical_extent = Simplify(e.region[j]->extent, *compute_intrin_iter_space);
+ CHECK(is_one(canonical_extent))
+ << "Tensorize " << intrin->name << ":"
+ << " Input dimension mismatch with tensor intrin "
+ << " expected shape=" << e.tensor->shape
+ << ", given region=" << e.region;
+ }
+ in_remap_[inputs[i]] = e;
+ }
+ // output remap
+ const ComputeOpNode* intrin_compute = intrin->op.as<ComputeOpNode>();
+ CHECK(intrin_compute) << "Only support compute intrinsic for now";
+ CHECK_GE(self->axis.size(), intrin_compute->axis.size())
+ << "Tensorize: Output mismatch with tensor intrin ";
+ // Enable fuzzy matching, to match [1, n, m] to [n, m]
+ size_t axis_start = self->axis.size() - intrin_compute->axis.size();
+ for (size_t i = 0; i < axis_start; ++i) {
+ Range r = out_dom.at(self->axis[i]);
+ CHECK(is_one(r->extent))
+ << "Tensorize: Output mismatch with tensor intrin "
+ << " intrin-dim=" << intrin_compute->axis.size()
+ << ", tensorize-dim=" << self->axis.size();
+ var_remap_[self->axis[i]->var.get()] = r->min;
+ }
+ // Assume we tensorize at region axis i [min, min + extent)
+ // The corresponding intrinsic axis is j [0, extent)
+ // Remap index i to j + min
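+ // Worked example: if out_dom gives axis i the range [4, 4 + 8) and the
+ // intrinsic axis j ranges over [0, 8), the access index i is rewritten
+ // as j + 4 below.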
+ for (size_t i = axis_start; i < self->axis.size(); ++i) {
+ IterVar iv = self->axis[i];
+ IterVar target_iv = intrin_compute->axis[i - axis_start];
+ Range r = out_dom.at(iv);
+ var_remap_[iv->var.get()] = target_iv->var + r->min;
+ axis_remap_[iv] = target_iv;
+ compute_intrin_iter_space->Set(target_iv->var, target_iv->dom);
+ }
+ // Remap reduction axis
+ CHECK_GE(self->reduce_axis.size(), intrin_compute->reduce_axis.size())
+ << "Tensorize: Reduction dimension mismatch with tensor intrin";
+ axis_start = self->reduce_axis.size() - intrin_compute->reduce_axis.size();
+ for (size_t i = 0; i < axis_start; ++i) {
+ Range r = out_dom.at(self->reduce_axis[i]);
+ CHECK(is_one(r->extent))
+ << "Tensorize: Reduction mismatch with tensor intrin "
+ << " intrin-dim=" << intrin_compute->reduce_axis.size()
+ << ", tensorize-dim=" << self->reduce_axis.size();
+ var_remap_[self->reduce_axis[i]->var.get()] = r->min;
+ }
+ for (size_t i = axis_start; i < self->reduce_axis.size(); ++i) {
+ IterVar iv = self->reduce_axis[i];
+ IterVar target_iv = intrin_compute->reduce_axis[i - axis_start];
+ Range r = out_dom.at(iv);
+ var_remap_[iv->var.get()] = target_iv->var + r->min;
+ axis_remap_[iv] = target_iv;
+ compute_intrin_iter_space->Set(target_iv->var, target_iv->dom);
+ }
+ }
+
+ private:
+ // Input entry
+ struct InputEntry {
+ Tensor tensor;
+ size_t start;
+ Array<Range> region;
+ };
+ // input data remap
+ std::unordered_map<Tensor, InputEntry> in_remap_;
+ // variable remap.
+ std::unordered_map<const VarNode*, PrimExpr> var_remap_;
+ // IterVar remap.
+ std::unordered_map<IterVar, IterVar> axis_remap_;
+};
+
+// Try to match tensor dataflow of the stage with the intrinsic
+Array<PrimExpr> MatchTensorizeBody(
+ const ComputeOpNode* self,
+ const Stage& stage,
+ const std::unordered_map<IterVar, Range>& dom_map,
+ const std::unordered_map<IterVar, Range>& out_dom,
+ const std::unordered_map<Tensor, Array<Range> >& in_region,
+ const TensorIntrin& intrin,
+ Map<Var, Range>* compute_intrin_iter_space) {
+ TensorIntrinMatcher matcher;
+ matcher.Init(self, stage, dom_map, out_dom, in_region, intrin, compute_intrin_iter_space);
+ Array<PrimExpr> ret;
+ for (PrimExpr expr : self->body) {
+ ret.push_back(matcher(expr));
+ }
+ return ret;
+}
+
+void VerifyTensorizeBody(
+ const ComputeOpNode* self,
+ const Stage& stage,
+ const std::unordered_map<IterVar, Range>& dom_map,
+ const std::unordered_map<IterVar, Range>& out_dom,
+ const std::unordered_map<Tensor, Array<Range> >& in_region,
+ const TensorIntrin& intrin) {
+ Map<Var, Range> compute_intrin_iter_space;
+ Array<PrimExpr> body = MatchTensorizeBody(self, stage, dom_map, out_dom, in_region, intrin,
+ &compute_intrin_iter_space);
+ const ComputeOpNode* intrin_compute = intrin->op.as<ComputeOpNode>();
+ CHECK(intrin_compute) << "Only support compute intrinsic for now";
+ CHECK_EQ(body.size(), intrin_compute->body.size())
+ << "Tensorize failed: body size mismatch";
+ for (size_t i = 0; i < body.size(); ++i) {
+ PrimExpr lhs = Simplify(body[i], compute_intrin_iter_space);
+ lhs = CanonicalSimplify(lhs, compute_intrin_iter_space);
+ PrimExpr rhs = Simplify(intrin_compute->body[i], compute_intrin_iter_space);
+ rhs = CanonicalSimplify(rhs, compute_intrin_iter_space);
+ if (lhs.dtype() != rhs.dtype()) {
+ LOG(FATAL)
+ << "Failed to match the data type with TensorIntrin "
+ << intrin->name << "'s declaration "
+ << " provided=" << lhs.dtype()
+ << ", intrin=" << rhs.dtype();
+ }
+ CHECK(Equal(lhs, rhs))
+ << "Failed to match the compute with TensorIntrin "
+ << intrin->name << "'s declaration "
+ << " provided= " << lhs
+ << ", intrin= " << rhs;
+ }
+}
+
+Stmt MakeTensorize(const ComputeOpNode* self,
+ const Stage& stage,
+ const std::unordered_map<IterVar, Range>& dom_map,
+ bool debug_keep_trivial_loop) {
+ std::unordered_map<IterVar, Range> out_dom;
+ std::unordered_map<Tensor, Array<Range> > in_region;
+ size_t tloc = InferTensorizeRegion(self, stage, dom_map, &out_dom, &in_region);
+ TensorIntrin intrin = stage->iter_var_attrs.at(
+ stage->leaf_iter_vars[tloc])->tensor_intrin;
+ CHECK(intrin.defined());
+ ComputeLoopNest n = ComputeLoopNest::make(self, stage, dom_map, debug_keep_trivial_loop);
+ VerifyTensorizeLoopNest(self, stage, n, tloc);
+ VerifyTensorizeBody(self, stage, dom_map, out_dom, in_region, intrin);
+ // Start bind data.
+ Stmt nop = EvaluateNode::make(0);
+ std::vector<Stmt> input_bind_nest, output_bind_nest;
+ Array<Tensor> inputs = self->InputTensors();
+ CHECK_EQ(inputs.size(), intrin->inputs.size())
+ << "Tensorize failed: input size mismatch ";
+ // input binding
+ for (size_t i = 0; i < intrin->inputs.size(); ++i) {
+ Tensor tensor = inputs[i];
+ Buffer buffer = intrin->buffers[i];
+ Array<ObjectRef> bind_spec{buffer, tensor};
+ auto it = in_region.find(tensor);
+ CHECK(it != in_region.end());
+ const Array<Range>& region = it->second;
+ Array<PrimExpr> tuple;
+ for (const Range r : region) {
+ tuple.push_back(r->min);
+ tuple.push_back(r->extent);
+ }
+ input_bind_nest.emplace_back(AttrStmtNode::make(
+ bind_spec, ir::attr::buffer_bind_scope,
+ CallNode::make(DataType::Handle(),
+ ir::intrinsic::tvm_tuple,
+ tuple, CallNode::Intrinsic), nop));
+ }
+ // output binding
+ const ComputeOpNode* intrin_compute = intrin->op.as<ComputeOpNode>();
+ CHECK(intrin_compute) << "Only support compute intrinsic for now";
+ CHECK_EQ(intrin->inputs.size() + intrin_compute->body.size(), intrin->buffers.size());
+ CHECK_EQ(intrin_compute->body.size(), self->body.size());
+ Array<PrimExpr> tuple;
+ for (IterVar iv : self->axis) {
+ auto it = out_dom.find(iv);
+ CHECK(it != out_dom.end());
+ tuple.push_back(it->second->min);
+ tuple.push_back(it->second->extent);
+ }
+ for (size_t i = intrin->inputs.size(); i < intrin->buffers.size(); ++i) {
+ Tensor tensor = stage->op.output(i - intrin->inputs.size());
+ Buffer buffer = intrin->buffers[i];
+ Array<ObjectRef> bind_spec{buffer, tensor};
+ output_bind_nest.emplace_back(AttrStmtNode::make(
+ bind_spec, ir::attr::buffer_bind_scope,
+ CallNode::make(DataType::Handle(),
+ ir::intrinsic::tvm_tuple,
+ tuple, CallNode::Intrinsic), nop));
+ }
+ // Check variable remap
+ std::unordered_map<const VarNode*, PrimExpr> vmap;
+ ir::ArgBinder binder(&vmap);
+ CHECK_GE(self->reduce_axis.size(), intrin_compute->reduce_axis.size())
+ << "Tensorization fail: reduction axis size do not match";
+ size_t start = self->reduce_axis.size() - intrin_compute->reduce_axis.size();
+ for (size_t i = 0; i < start; ++i) {
+ IterVar iv = self->reduce_axis[i];
+ auto it = out_dom.find(iv);
+ CHECK(it != out_dom.end());
+ CHECK(is_one(it->second->extent))
+ << "Tensorization fail: reduction axis size do not match";
+ }
+ for (size_t i = start; i < self->reduce_axis.size(); ++i) {
+ IterVar iv = self->reduce_axis[i];
+ IterVar target = intrin_compute->reduce_axis[i - start];
+ auto it = out_dom.find(iv);
+ CHECK(it != out_dom.end());
+ binder.Bind(target->dom->min, make_const(iv->dom->min.dtype(), 0),
+ "tensir_intrin.reduction.min");
+ binder.Bind(target->dom->extent, it->second->extent,
+ "tensir_intrin.reduction.extent");
+ }
+ if (tloc <= n.num_common_loop) {
+ // No need to split the reduction
+ std::vector<std::vector<Stmt> > nest(
+ n.main_nest.begin(), n.main_nest.begin() + tloc + 1);
+ nest.emplace_back(MakeIfNest(n.main_predicates));
+ CHECK_EQ(n.init_predicates.size(), 0U);
+ CHECK(intrin->body.defined())
+ << "Normal store op for intrin " << intrin << " is not defined";
+ Stmt body = MergeNest(output_bind_nest, intrin->body);
+ body = MergeNest(input_bind_nest, body);
+ body = ir::Substitute(body, vmap);
+ body = MergeNest(binder.asserts(), body);
+ body = top::Substitute(body, n.main_vmap);
+ return MergeNest(nest, body);
+ } else {
+ // Need to split reduction
+ CHECK(intrin->reduce_update.defined())
+ << "Reduction update op for intrin " << intrin << " is not defined";
+ // Need init and update steps
+ CHECK_NE(self->reduce_axis.size(), 0U);
+ std::vector<std::vector<Stmt> > common(
+ n.main_nest.begin(), n.main_nest.begin() + n.num_common_loop + 1);
+ std::vector<std::vector<Stmt> > update_nest(
+ n.main_nest.begin() + n.num_common_loop + 1, n.main_nest.begin() + tloc + 1);
+ update_nest.emplace_back(MakeIfNest(n.main_predicates));
+
+ if (intrin->reduce_init.defined()) {
+ // init nest
+ std::vector<std::vector<Stmt> > init_nest(
+ n.init_nest.begin(), n.init_nest.begin() + tloc + 1);
+ init_nest.emplace_back(MakeIfNest(n.init_predicates));
+ Stmt init = MergeNest(output_bind_nest, intrin->reduce_init);
+ init = top::Substitute(init, n.init_vmap);
+ init = MergeNest(init_nest, init);
+ // The update
+ Stmt update = MergeNest(output_bind_nest, intrin->reduce_update);
+ update = MergeNest(input_bind_nest, update);
+ update = ir::Substitute(update, vmap);
+ update = MergeNest(binder.asserts(), update);
+ update = top::Substitute(update, n.main_vmap);
+ update = MergeNest(update_nest, update);
+ return MergeNest(common, SeqStmt::Flatten(init, update));
+ } else {
+ // When init op is not available, use body op for reset in the first iter.
+ CHECK(intrin->body.defined())
+ << "Normal body op for intrin " << intrin << " is not defined";
+ Stmt update = TransformUpdate(stage, dom_map, n,
+ intrin->body,
+ intrin->reduce_update);
+ update = MergeNest(output_bind_nest, update);
+ update = MergeNest(input_bind_nest, update);
+ update = ir::Substitute(update, vmap);
+ update = MergeNest(binder.asserts(), update);
+ update = top::Substitute(update, n.main_vmap);
+ update = MergeNest(update_nest, update);
+ return MergeNest(common, update);
+ }
+ }
+}
+
+// Register functions for unittests
+TVM_REGISTER_GLOBAL("test.op.InferTensorizeRegion")
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+ Stage stage = args[0];
+ Map<IterVar, Range> dmap = args[1];
+ std::unordered_map<IterVar, Range> out_dom;
+ std::unordered_map<Tensor, Array<Range> > in_region;
+ CHECK(stage->op.as<ComputeOpNode>());
+ InferTensorizeRegion(stage->op.as<ComputeOpNode>(),
+ stage,
+ as_unordered_map(dmap),
+ &out_dom, &in_region);
+ *ret = Array<ObjectRef>{Map<IterVar, Range>(out_dom),
+ Map<Tensor, Array<Range> >(in_region)};
+ });
+
+TVM_REGISTER_GLOBAL("test.op.MatchTensorizeBody")
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+ Stage stage = args[0];
+ Map<IterVar, Range> out_dom = args[1];
+ Map<Tensor, Array<Range> > in_region = args[2];
+ TensorIntrin intrin = args[3];
+ Map<Var, Range> vrange;
+ CHECK(stage->op.as<ComputeOpNode>());
+ *ret = MatchTensorizeBody(stage->op.as<ComputeOpNode>(),
+ stage,
+ {{}},
+ as_unordered_map(out_dom),
+ as_unordered_map(in_region),
+ intrin,
+ &vrange);
+ });
+} // namespace top
+} // namespace tvm
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file auto_inline_elem_wise.cc
+ */
+#include <tvm/top/schedule_pass.h>
+#include <tvm/top/operation.h>
+#include <tvm/ir_functor_ext.h>
+
+namespace tvm {
+namespace top {
+
+using namespace ir;
+
+class ElemWiseDetector : public ir::ExprVisitor {
+ public:
+ explicit ElemWiseDetector(Array<IterVar> axis) : axis_(axis) {}
+
+ void VisitExpr(const PrimExpr& e) final {
+ if (!is_elem_wise_) return;
+ ExprVisitor::VisitExpr(e);
+ }
+
+ void VisitExpr_(const CallNode* op) final {
+ Array<PrimExpr> axis = op->args;
+ if (axis_.size() != axis.size()) {
+ is_elem_wise_ = false;
+ return;
+ }
+
+ for (size_t i = 0; i < axis_.size(); ++i) {
+ if (!axis[i].same_as(axis_[i]->var)) {
+ is_elem_wise_ = false;
+ return;
+ }
+ }
+ ExprVisitor::VisitExpr_(op);
+ }
+
+ bool is_elem_wise_{true};
+
+ private:
+ Array<IterVar> axis_;
+};
+
+
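+// Check whether a compute op is element-wise: every access in its body must use
+// exactly the op's own axis variables, in order.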
+bool IsElemWise(const Operation& op) {
+ if (const ComputeOpNode* compute = op.as<ComputeOpNode>()) {
+ ElemWiseDetector v = ElemWiseDetector(compute->axis);
+ for (auto& e : compute->body) v(e);
+ return v.is_elem_wise_;
+ }
+ return false;
+}
+
+void AutoInlineElemWise(Schedule sch) {
+ for (Stage s : sch->stages) {
+ if (!s.is_scheduled() && IsElemWise(s->op) && !s->is_output) {
+ s.compute_inline();
+ }
+ }
+}
+
+bool IsBroadcast(const Operation& op) {
+ if (const ComputeOpNode* compute = op.as<ComputeOpNode>()) {
+ if (compute->reduce_axis.size()) {
+ return false;
+ }
+ // TODO(nicolasvasilache): Implement Me
+ }
+ return false;
+}
+
+void AutoInlineBroadcast(Schedule sch) {
+ for (Stage s : sch->stages) {
+ if (!s.is_scheduled() && IsBroadcast(s->op) && !s->is_output) {
+ s.compute_inline();
+ }
+ }
+}
+
+bool IsInjective(const Operation& op) {
+ if (const ComputeOpNode* compute = op.as<ComputeOpNode>()) {
+ return compute->reduce_axis.size() == 0;
+ }
+ return false;
+}
+
+void AutoInlineInjective(Schedule sch) {
+ for (Stage s : sch->stages) {
+ if (!s.is_scheduled() && IsInjective(s->op) && !s->is_output) {
+ s.compute_inline();
+ }
+ }
+}
+
+} // namespace top
+} // namespace tvm
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file bound.cc
+ * \brief The bound inference logic.
+ */
+#include <tvm/top/schedule_pass.h>
+#include <tvm/top/operation.h>
+#include <tvm/ir_pass.h>
+#include <unordered_map>
+#include <unordered_set>
+#include "graph.h"
+#include "message_passing.h"
+#include "../../runtime/thread_storage_scope.h"
+
+namespace tvm {
+namespace top {
+
+using runtime::StorageRank;
+using runtime::StorageScope;
+using runtime::ThreadScope;
+
+/*! \brief The graph context used during bound inference. */
+struct GraphContext {
+ /*! \brief The feed graph */
+ FeedGraph feed_graph;
+ /*! \brief Attachment path */
+ AttachPath attach_path;
+ /*! \brief The bind map */
+ std::unordered_map<IterVar, IterVar> bind_map;
+ /*! \brief map from op to stage */
+ std::unordered_map<const Object*, Stage> op2stage_;
+};
+
+bool NeedRelax(const IterVar& iv,
+ bool found_attach,
+ const std::unordered_map<IterVar, IterVar>& bind_map,
+ const runtime::StorageScope& scope) {
+ auto it = bind_map.find(iv);
+ const std::string& tag = (
+ it != bind_map.end() ? it->second->thread_tag : iv->thread_tag);
+ if (tag.length() == 0 || tag == "pipeline") {
+ return !found_attach;
+ }
+ ThreadScope ts = ThreadScope::make(tag);
+
+ // When there is warp memory
+ // threadIdx.x must be set to be warp index.
+ if (scope.rank == StorageRank::kWarp &&
+ ts.rank == 1 &&
+ ts.dim_index == 0) {
+ return true;
+ }
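+ // Otherwise relax iteration over any thread axis whose rank is at least the
+ // storage scope's rank, e.g. a buffer in shared memory must be relaxed across
+ // threadIdx.* but not across blockIdx.*.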
+ return static_cast<int>(scope.rank) <= ts.rank;
+}
+
+// infer storage scope, if not given
+StorageScope InferStorageScope(
+ const Stage& stage, const GraphContext& ctx) {
+ if (stage->scope.length() != 0) {
+ return StorageScope::make(stage->scope);
+ }
+ int max_rank = -1;
+ for (IterVar iv : ctx.attach_path.at(stage->op)) {
+ auto it = ctx.bind_map.find(iv);
+ const std::string& tag = (
+ it != ctx.bind_map.end() ? it->second->thread_tag : iv->thread_tag);
+ if (tag != "pipeline" && tag.length() != 0) {
+ max_rank = std::max(max_rank, ThreadScope::make(tag).rank);
+ }
+ }
+ StorageScope s;
+ s.rank = runtime::DefaultStorageRank(max_rank);
+ return s;
+}
+
+
+void InferRootBound(const Stage& stage,
+ const GraphContext& ctx,
+ std::unordered_map<IterVar, Range>* rmap) {
+ CHECK_NE(stage->attach_type, kInline)
+ << "call schedule.normalize before scheduleops";
+ if (stage->attach_type == kInlinedAlready) return;
+ if (stage->is_output) {
+ // verify correctness.
+ CHECK_EQ(stage.GetAttachSpec()->attach_type, kGroupRoot)
+ << "Output must be attached at root";
+ }
+ if (stage->is_output || stage->op.as<PlaceholderOpNode>()) {
+ for (auto iv : stage->op->root_iter_vars()) {
+ CHECK(iv->dom.defined());
+ CHECK(!rmap->count(iv));
+ (*rmap)[iv] = iv->dom;
+ }
+ return;
+ }
+ // The tensor domain.
+ std::unordered_map<Tensor, TensorDom> tmap;
+ // The consumers of the op.
+ std::unordered_set<Operation> consumers;
+ for (int i = 0; i < stage->op->num_outputs(); ++i) {
+ Tensor t = stage->op.output(i);
+ tmap.emplace(t, TensorDom(static_cast<int>(t.ndim())));
+ auto it = ctx.feed_graph.find(t);
+ if (it != ctx.feed_graph.end()) {
+ for (const Operation& op : it->second) {
+ consumers.insert(op);
+ }
+ } else {
+ LOG(INFO) << "not in feed graph consumer = " << stage->op;
+ }
+ }
+ // storage scope.
+ runtime::StorageScope scope = InferStorageScope(stage, ctx);
+ // Bound prop by other consumers.
+ // - Compute bound by relaxation rules: NeedRelax
+ // - For normal index, use the relative location in the loop nest.
+ // - For thread index, use the thread scope.
+ //
+ Array<IterVar> stage_attach = ctx.attach_path.at(stage->op);
+ // The parent set.
+ for (const Operation& op : consumers) {
+ std::unordered_map<const VarNode*, IntSet> relax_set;
+ std::unordered_map<IterVar, IntSet> up_state;
+ bool found_attach = false;
+ CHECK(ctx.op2stage_.count(op.get()));
+ const Stage& op_stage = ctx.op2stage_.at(op.get());
+ // Consumer nest
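+ // Walk the consumer's leaf loops from innermost to outermost: unit-extent
+ // loops pin to their min, loops that need no relaxation stay symbolic
+ // (a single point at the loop or bound thread var), and everything else is
+ // relaxed to its full range.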
+ for (size_t i = op_stage->leaf_iter_vars.size(); i != 0; --i) {
+ IterVar iv = op_stage->leaf_iter_vars[i - 1];
+ if (stage_attach.size() != 0 && iv == stage_attach[0]) {
+ found_attach = true;
+ }
+ auto it = rmap->find(iv);
+ CHECK(it != rmap->end());
+ const Range& vrange = it->second;
+ if (is_one(vrange->extent)) {
+ up_state[iv] = IntSet::single_point(vrange->min);
+ } else if (!NeedRelax(iv, found_attach, ctx.bind_map, scope)) {
+ CHECK(is_zero(vrange->min))
+ << "InferBound requires every leaf iter var's min equals 0, "
+ << " call schedule.normalize to achieve this. ";
+ if (ctx.bind_map.count(iv)) {
+ up_state[iv] = IntSet::single_point(ctx.bind_map.at(iv)->var);
+ } else {
+ up_state[iv] = IntSet::single_point(iv->var);
+ }
+ } else {
+ up_state[iv] = IntSet::range(vrange);
+ }
+ }
+ // Consumer's attach nest
+ for (IterVar iv : ctx.attach_path.at(op)) {
+ if (stage_attach.size() != 0 && iv == stage_attach[0]) {
+ found_attach = true;
+ }
+ Range vrange = rmap->at(iv);
+ CHECK(is_zero(vrange->min))
+ << "InferBound requires every leaf iter var's min equals 0, "
+ << "call schedule.normalize to achieve this.";
+ if (NeedRelax(iv, found_attach, ctx.bind_map, scope)) {
+ relax_set[iv->var.get()] = IntSet::range(vrange);
+ if (ctx.bind_map.count(iv)) {
+ relax_set[ctx.bind_map.at(iv)->var.get()] = IntSet::range(vrange);
+ }
+ }
+ }
+ CHECK(found_attach || stage_attach.size() == 0)
+ << "Invalid Schedule, cannot find the producer " << stage->op
+ << " along the loop nest specified by compute_at of consumer " << op;
+ // Get the domain of the consumer
+ PassUpDomain(op_stage, *rmap, &up_state);
+ // Relax if needed.
+ std::unordered_map<const VarNode*, IntSet> dom_map;
+ arith::Analyzer analyzer;
+ for (auto iv : op->root_iter_vars()) {
+ Range r;
+ if (up_state.count(iv)) {
+ r = up_state.at(iv).cover_range(iv->dom);
+ } else {
+ r = iv->dom;
+ }
+ if (relax_set.size() != 0) {
+ dom_map[iv->var.get()] = EvalSet(r, relax_set);
+ } else {
+ dom_map[iv->var.get()] = IntSet::range(r);
+ }
+ analyzer.Bind(iv->var, r);
+ }
+ op->PropBoundToInputs(op, &analyzer, dom_map, &tmap);
+ }
+ stage->op->GatherBound(stage->op, tmap, rmap);
+}
+
+Map<IterVar, Range> InferBound(const Schedule& sch) {
+ // Prepare context
+ GraphContext ctx;
+ Array<Operation> roots;
+ arith::Analyzer analyzer;
+
+ for (Operation op : sch->outputs) {
+ roots.push_back(sch->stage_map[op]->op);
+ }
+ ctx.feed_graph = CreateFeedGraph(CreateReadGraph(roots));
+
+ for (Stage stage : sch->stages) {
+ for (auto kv : stage->iter_var_attrs) {
+ if (kv.second->bind_thread.defined()) {
+ CHECK(!ctx.bind_map.count(kv.first));
+ ctx.bind_map[kv.first] = kv.second->bind_thread;
+ }
+ }
+ ctx.op2stage_[stage->op.get()] = stage;
+ }
+ ctx.attach_path = CreateAttachPath(sch);
+ // Run inference.
+ std::unordered_map<IterVar, Range> ret;
+ for (size_t i = sch->stages.size(); i != 0; --i) {
+ const Stage& stage = sch->stages[i - 1];
+ InferRootBound(stage, ctx, &ret);
+
+ // bind bound of root iter vars.
+ for (auto iv : stage->op->root_iter_vars()) {
+ auto it = ret.find(iv);
+ if (it != ret.end()) {
+ analyzer.Bind(iv->var, it->second);
+ }
+ }
+
+ // pass down to get bound of all iter vars.
+ PassDownDomain(stage, &ret, &analyzer);
+ for (IterVar iv : stage->env_threads) {
+ CHECK(iv->dom.defined());
+ ret[iv] = iv->dom;
+ }
+ }
+ for (auto& p : ret) {
+ ret[p.first] = Range::make_by_min_extent(
+ analyzer.Simplify(p.second->min),
+ analyzer.Simplify(p.second->extent));
+ }
+ return Map<IterVar, Range>(ret.begin(), ret.end());
+}
+
+} // namespace top
+} // namespace tvm
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file graph.cc
+ * \brief Utilities to get information about schedule graph.
+ */
+#include <tvm/ir.h>
+#include <tvm/ir_functor_ext.h>
+#include <tvm/top/operation.h>
+#include <utility>
+#include <unordered_set>
+#include <unordered_map>
+#include "graph.h"
+
+namespace tvm {
+namespace top {
+// key to specific tensor dimension.
+struct TensorDimKey {
+ ir::FunctionRef f;
+ int value_index;
+ int dim;
+ TensorDimKey() {}
+ TensorDimKey(const ir::CallNode* op, int dim)
+ : f(op->func), value_index(op->value_index), dim(dim) {
+ }
+ TensorDimKey(const Tensor& t, int dim)
+ : f(t->op), value_index(t->value_index), dim(dim) {
+ }
+ TensorDimKey(const Tensor& t, size_t dim)
+ : f(t->op), value_index(t->value_index), dim(static_cast<int>(dim)) {
+ }
+ inline bool operator==(const TensorDimKey& other) const {
+ return f == other.f &&
+ value_index == other.value_index &&
+ dim == other.dim;
+ }
+ inline bool operator!=(const TensorDimKey& other) const {
+ return !operator==(other);
+ }
+};
+} // namespace top
+} // namespace tvm
+
+namespace std {
+template <>
+struct hash<::tvm::top::TensorDimKey> {
+ std::size_t operator()(const ::tvm::top::TensorDimKey& k) const {
+ size_t lhs = ::tvm::ObjectHash()(k.f);
+ size_t rhs = static_cast<size_t>(k.value_index) << 16UL |
+ static_cast<size_t>(k.dim);
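+ // Mix the two hashes in the boost::hash_combine style; 0x9e3779b9 is the
+ // 32-bit golden-ratio constant commonly used to spread the bits.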
+ lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2);
+ return lhs;
+ }
+};
+} // namespace std
+
+
+namespace tvm {
+namespace top {
+
+// construct a read graph that gives readers of each operation
+// that the roots depend on
+ReadGraph CreateReadGraph(const Array<Operation>& roots) {
+ ReadGraph rmap;
+ std::vector<Operation> stack;
+ std::unordered_set<const Object*> visited;
+ // initialize the roots
+ for (Operation op : roots) {
+ stack.push_back(op);
+ visited.insert(op.get());
+ }
+
+ while (!stack.empty()) {
+ Operation op = stack.back();
+ stack.pop_back();
+ Array<Tensor> deps = op->InputTensors();
+ rmap.Set(op, deps);
+ for (Tensor t : deps) {
+ if (t->op.defined() && visited.count(t->op.get()) == 0) {
+ visited.insert(t->op.get());
+ stack.push_back(t->op);
+ }
+ }
+ }
+ return rmap;
+}
+
+// Do DFS visit to get the subgraph.
+// Returns whether op is inside the subgraph.
+bool GetSubGraphByPostDFS_(
+ const Operation& op,
+ const std::unordered_set<const Object*>& boundary,
+ bool include_boundary,
+ std::unordered_map<const Object*, bool>* visited,
+ Array<Operation>* result) {
+ if (visited->count(op.get())) {
+ return visited->at(op.get());
+ }
+ if (boundary.count(op.get())) {
+ (*visited)[op.get()] = true;
+ if (include_boundary) {
+ result->push_back(op);
+ }
+ return true;
+ }
+ // mark to avoid loop
+ // Not necessary for DAG.
+ (*visited)[op.get()] = false;
+ // check if we can reach boundary.
+ bool reach_boundary = false;
+ for (Tensor t : op->InputTensors()) {
+ if (GetSubGraphByPostDFS_(t->op, boundary,
+ include_boundary,
+ visited, result)) {
+ reach_boundary = true;
+ }
+ }
+ (*visited)[op.get()] = reach_boundary;
+ if (reach_boundary) {
+ result->push_back(op);
+ }
+ return reach_boundary;
+}
+
+Array<Operation> GetSubGraph(const Array<Tensor>& outputs,
+ const Array<Tensor>& inputs,
+ bool include_inputs) {
+ Array<Operation> result;
+ std::unordered_set<const Object*> boundary;
+ for (Tensor t : inputs) {
+ boundary.insert(t->op.get());
+ }
+ std::unordered_map<const Object*, bool> visited;
+ for (Tensor t : outputs) {
+ GetSubGraphByPostDFS_(t->op, boundary, include_inputs,
+ &visited, &result);
+ }
+ return result;
+}
+
+
+void PostDFSOrder(const Operation& op,
+ const ReadGraph& g,
+ std::unordered_set<Operation>* visited,
+ Array<Operation>* post_order) {
+ if (visited->count(op)) return;
+ visited->insert(op);
+ for (const auto& t : g.at(op)) {
+ PostDFSOrder(t->op, g, visited, post_order);
+ }
+ post_order->push_back(op);
+}
+
+Array<Operation> PostDFSOrder(
+ const Array<Operation>& roots,
+ const ReadGraph& g) {
+ std::unordered_set<Operation> visited;
+ Array<Operation> post_order;
+ for (Operation op : roots) {
+ PostDFSOrder(op, g, &visited, &post_order);
+ }
+ return post_order;
+}
+
+FeedGraph CreateFeedGraph(const ReadGraph& g) {
+ FeedGraph fg;
+ for (auto kv : g) {
+ for (Tensor t : kv.second) {
+ fg[t].push_back(kv.first);
+ }
+ }
+ return fg;
+}
+
+AttachPath CreateAttachPath(Schedule sch) {
+ AttachPath ret;
+ for (Stage stage : sch->stages) {
+ std::unordered_set<const Object*> visited;
+ Array<IterVar> path;
+ for (Stage s = stage; s.defined();) {
+ CHECK(!visited.count(s.get()))
+ << "Find loop in compute_at attach group";
+ visited.insert(s.get());
+ Stage spec = s.GetAttachSpec();
+ bool start_attach;
+ IterVar attach_ivar;
+ if (spec->attach_type == kScope) {
+ attach_ivar = spec->attach_ivar;
+ s = spec->attach_stage;
+ start_attach = false;
+ CHECK(attach_ivar.defined());
+ } else if (spec->attach_type == kScanUpdate) {
+ s = spec->attach_stage;
+ start_attach = true;
+ } else {
+ break;
+ }
+ CHECK(s.defined());
+ for (size_t i = s->leaf_iter_vars.size(); i != 0; --i) {
+ IterVar iv = s->leaf_iter_vars[i - 1];
+ if (!start_attach && iv.same_as(attach_ivar)) {
+ start_attach = true;
+ }
+ if (start_attach) path.push_back(iv);
+ }
+ CHECK(start_attach)
+ << "Invalid Schedule: cannot find attach point " << attach_ivar
+ << " in the schedule of " << s->op;
+ }
+ if (!ret.count(stage->op)) {
+ ret.Set(stage->op, path);
+ }
+ }
+ return ret;
+}
+
+// graph of push reach relation of tensor dimensions
+using ReachGraph = std::unordered_map<TensorDimKey, std::vector<TensorDimKey> >;
+
+ReachGraph GetReachGraph(const Array<Operation>& ops) {
+ ReachGraph reach;
+ std::unordered_set<const Object*> bset;
+ for (size_t i = 0; i < ops.size(); ++i) {
+ bset.insert(ops[i].get());
+ }
+
+ for (Operation op : ops) {
+ if (const auto* scan_op = op.as<ScanOpNode>()) {
+ const auto& update = scan_op->update;
+ const auto& init = scan_op->init;
+ for (size_t i = 0; i < update.size(); ++i) {
+ Tensor t = op.output(i);
+ for (int k = 1; k < static_cast<int>(update[i]->shape.size()); ++k) {
+ reach[TensorDimKey(t, k)].emplace_back(
+ TensorDimKey(update[i], k));
+ reach[TensorDimKey(t, k)].emplace_back(
+ TensorDimKey(init[i], k));
+ }
+ }
+ } else if (const auto* compute_op = op.as<ComputeOpNode>()) {
+ std::unordered_map<const Object*, TensorDimKey> vmap;
+ const auto& axis = compute_op->axis;
+ Tensor t = op.output(0);
+ for (size_t i = 0; i < axis.size(); ++i) {
+ vmap[axis[i]->var.get()] = TensorDimKey(t, i);
+ reach[TensorDimKey(t, i)] = {};
+ }
+ auto fvisit = [&vmap, &reach, &bset](const ObjectRef& n) {
+ const ir::CallNode *call = n.as<ir::CallNode>();
+ if (call != nullptr && call->func.defined()) {
+ if (!bset.count(call->func.get())) return;
+ for (size_t i = 0; i < call->args.size(); ++i) {
+ TensorDimKey dkey(call, static_cast<int>(i));
+ auto fpush = [&dkey, &vmap, &reach](const ObjectRef& node) {
+ const VarNode *v = node.as<VarNode>();
+ auto it = vmap.find(v);
+ if (it != vmap.end()) {
+ reach[it->second].push_back(dkey);
+ }
+ };
+ ir::PostOrderVisit(call->args[i], fpush);
+ }
+ }
+ };
+ for (auto& e : compute_op->body) {
+ ir::PostOrderVisit(e, fvisit);
+ }
+ }
+ }
+ return reach;
+}
+
+Array<Operation> ScanGetBody(const Operation& scan_op) {
+ const ScanOpNode* scan = scan_op.as<ScanOpNode>();
+ // Get the body.
+ Array<Tensor> inputs;
+ for (Tensor t : scan->state_placeholder) {
+ inputs.push_back(t);
+ }
+ for (Tensor t : scan->inputs) {
+ inputs.push_back(t);
+ }
+ return GetSubGraph(scan->update, inputs, false);
+}
+
+Map<IterVar, PrimExpr> ScanFixPointAnalysis(const Operation& scan_op) {
+ const ScanOpNode* scan = scan_op.as<ScanOpNode>();
+ Array<Operation> body = ScanGetBody(scan_op);
+
+ std::unordered_map<TensorDimKey, const Object*> exact_reach;
+ std::unordered_set<const Object*> fail_set;
+
+ for (size_t i = 0, sp_idx = 0; i < scan->update.size(); ++i) {
+ for (size_t k = 1; k < scan->update[i]->shape.size(); ++k, ++sp_idx) {
+ TensorDimKey key(scan->state_placeholder[i], k);
+ exact_reach[key] = scan->spatial_axis_[sp_idx].get();
+ }
+ }
+ // merge exact reach
+ auto f_merge_key = [&exact_reach, &fail_set](
+ const TensorDimKey& dst, const TensorDimKey& src) {
+ auto sit = exact_reach.find(src);
+ if (sit == exact_reach.end()) return;
+ auto dit = exact_reach.find(dst);
+ if (dit == exact_reach.end()) {
+ exact_reach[dst] = sit->second;
+ } else {
+ if (dit->second != sit->second) {
+ fail_set.insert(dit->second);
+ fail_set.insert(sit->second);
+ }
+ }
+ };
+ // prop exact reach back.
+ for (size_t i = 0; i < body.size(); ++i) {
+ const Operation& op = body[i];
+ if (const auto* scan_op = op.as<ScanOpNode>()) {
+ const auto& update = scan_op->update;
+ const auto& init = scan_op->init;
+ for (size_t i = 0; i < update.size(); ++i) {
+ Tensor t = op.output(i);
+ for (size_t k = 1; k < update[i]->shape.size(); ++k) {
+ f_merge_key(TensorDimKey(t, k), TensorDimKey(update[i], k));
+ f_merge_key(TensorDimKey(t, k), TensorDimKey(init[i], k));
+ }
+ }
+ } else if (const auto* compute_op = op.as<ComputeOpNode>()) {
+ std::unordered_map<const Object*, std::vector<TensorDimKey> > vmap;
+ const auto& axis = compute_op->axis;
+ for (size_t i = 0; i < axis.size(); ++i) {
+ std::vector<TensorDimKey> keys;
+ for (int j = 0; j < op->num_outputs(); ++j) {
+ keys.emplace_back(op.output(j), i);
+ }
+ vmap[axis[i]->var.get()] = std::move(keys);
+ }
+ auto fvisit = [&vmap, &f_merge_key, &exact_reach, &fail_set](
+ const ObjectRef& n) {
+ const ir::CallNode *call = n.as<ir::CallNode>();
+ if (call != nullptr && call->func.defined()) {
+ for (size_t i = 0; i < call->args.size(); ++i) {
+ auto it = vmap.find(call->args[i].get());
+ TensorDimKey src(call, static_cast<int>(i));
+ if (it != vmap.end()) {
+ const std::vector<TensorDimKey>& keys = it->second;
+ for (const auto& key : keys) {
+ f_merge_key(key, src);
+ }
+ } else {
+ if (exact_reach.count(src)) {
+ fail_set.insert(exact_reach.at(src));
+ }
+ }
+ }
+ }
+ };
+ for (auto& e : compute_op->body) {
+ ir::PostOrderVisit(e, fvisit);
+ }
+ }
+ }
+ ReachGraph reach;
+ Map<IterVar, PrimExpr> ret;
+ std::unordered_set<TensorDimKey> place_holder_ref;
+ for (size_t i = 0; i < scan->state_placeholder.size(); ++i) {
+ for (size_t k = 0; k < scan->state_placeholder[i]->shape.size(); ++k) {
+ place_holder_ref.insert(TensorDimKey(scan->state_placeholder[i], k));
+ }
+ }
+
+ for (size_t i = 0, sp_idx = 0; i < scan->update.size(); ++i) {
+ for (size_t k = 1; k < scan->update[i]->shape.size(); ++k, ++sp_idx) {
+ TensorDimKey key(scan->update[i], k);
+ TensorDimKey target(scan->state_placeholder[i], k);
+ IterVar sp_iv = scan->spatial_axis_[sp_idx];
+ if (fail_set.count(sp_iv.get()) ||
+ !exact_reach.count(key) ||
+ exact_reach.at(key) != sp_iv.get()) {
+ ret.Set(sp_iv, make_const(DataType::Int(32), 0));
+ } else {
+ // now we proved exact match, need to prove no interference with other graph.
+ if (reach.size() == 0) reach = GetReachGraph(body);
+ // do a DFS
+ std::unordered_set<TensorDimKey> visited;
+ std::vector<TensorDimKey> stack{key};
+ visited.insert(key);
+ while (!stack.empty()) {
+ TensorDimKey k = stack.back();
+ if (k != target && place_holder_ref.count(k)) break;
+ stack.pop_back();
+ if (!reach.count(k)) {
+ LOG(FATAL) << "cannot find reach of " << k.f << "-" << k.dim;
+ }
+
+ for (TensorDimKey kk : reach.at(k)) {
+ if (visited.count(kk)) {
+ continue;
+ }
+ visited.insert(kk);
+ stack.push_back(kk);
+ }
+ }
+ if (!stack.empty()) {
+ // failed the proof.
+ ret.Set(sp_iv, make_const(DataType::Int(32), 0));
+ } else {
+ ret.Set(sp_iv, make_const(DataType::Int(32), 1));
+ }
+ }
+ }
+ }
+ return ret;
+}
+
+} // namespace top
+} // namespace tvm
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file graph.h
+ * \brief Utilities to get information about schedule graph.
+ */
+#ifndef TVM_TOP_SCHEDULE_GRAPH_H_
+#define TVM_TOP_SCHEDULE_GRAPH_H_
+
+#include <tvm/expr.h>
+#include <tvm/top/schedule.h>
+#include <tvm/top/operation.h>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+namespace tvm {
+namespace top {
+
+/*!
+ * \brief Map from an Operation to the Tensors it reads.
+ */
+using ReadGraph = Map<Operation, Array<Tensor> >;
+
+/*!
+ * \brief AttachPath maps op-> a list of IterVar
+ */
+using AttachPath = Map<Operation, Array<IterVar> >;
+
+/*!
+ * \brief Map from a Tensor to the Operations it feeds.
+ */
+using FeedGraph = std::unordered_map<Tensor, std::vector<Operation> >;
+
+/*!
+ * \brief Get read graph of each operation to all the
+ * Tensors that it directly depends on.
+ *
+ * The result map contains Operations needed to finish root Operation.
+ * \param roots The root operation.
+ * \return The result map.
+ */
+ReadGraph CreateReadGraph(const Array<Operation>& roots);
+
+/*!
+ * \brief Get minimum subgraph between outputs and inputs.
+ *  The result contains the operations that are reachable from the inputs and
+ *  from which the outputs are reachable.
+ *
+ *  The inputs themselves are included in the subgraph only when include_inputs
+ *  is true; the outputs are included.
+ *
+ * \param outputs The outputs of the subgraph
+ * \param inputs The inputs to the subgraph.
+ * \param include_inputs Whether to include inputs
+ *
+ * \return The subgraph.
+ */
+Array<Operation> GetSubGraph(const Array<Tensor>& outputs,
+ const Array<Tensor>& inputs,
+ bool include_inputs);
+
+/*!
+ * \brief Get a post-DFS ordering of the operations in the graph.
+ * \param roots The roots of the graph.
+ * \param g The read graph.
+ * \return The Operations in post-DFS order.
+ *
+ * \note Post-DFS order is a special case of topological order,
+ *       and can be used whenever a topological order is needed.
+ */
+Array<Operation> PostDFSOrder(
+ const Array<Operation>& roots, const ReadGraph& g);
+
+/*!
+ * \brief Create a feed graph from the given read graph.
+ * \param g The read graph.
+ * \return The created feed graph.
+ */
+FeedGraph CreateFeedGraph(const ReadGraph& g);
+
+/*!
+ * \brief Create an AttachPath that maps each op to the list of IterVars
+ *        representing the loop nest it sits in, from innermost to outermost.
+ *        Also inserts the attach stage for scan updates when needed.
+ *
+ * \param sch The schedule.
+ * \return The attach path.
+ */
+AttachPath CreateAttachPath(Schedule sch);
+
+/*!
+ * \brief Get all operations inside the recursion of scan.
+ * \param scan_op The scan operation.
+ * \return The body operations, in read dependency order.
+ */
+Array<Operation> ScanGetBody(const Operation& scan_op);
+
+/*!
+ * \brief Analyze each spatial dimension of scan's result.
+ *        Check whether each dimension is a fixed point.
+ *        An axis is a fixed point if it only refers back to itself in the
+ *        recursion and is not used in the axes of other recursion fields, i.e.
+ *
+ *        next_state[t, ..., axis, ...] = f(prev_state[t-1, ..., axis, ...])
+ *
+ * \param scan The scan node.
+ * \return Map of spatial_axis -> IntImm, 1 if the axis is a fixed point, 0 otherwise.
+ */
+Map<IterVar, PrimExpr> ScanFixPointAnalysis(const Operation& scan);
+
+} // namespace top
+} // namespace tvm
+
+#endif // TVM_TOP_SCHEDULE_GRAPH_H_
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file message_passing.cc
+ * \brief Message passing of domains, indices and flags along the stage relations (split/fuse/rebase).
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/ir.h>
+#include <tvm/ir_pass.h>
+#include "message_passing.h"
+#include "../../arith/compute_expr.h"
+
+namespace tvm {
+namespace top {
+
+using namespace ir;
+
+void Update(std::unordered_map<IterVar, Range>* p_state,
+ const IterVar& iv,
+ Range r,
+ arith::Analyzer* analyzer) {
+ auto it = p_state->find(iv);
+ if (it == p_state->end()) {
+ (*p_state)[iv] = r;
+ analyzer->Bind(iv->var, r);
+ } else {
+ bool match = is_zero(it->second->min) &&
+ analyzer->CanProve(r->extent - it->second->extent == 0);
+ CHECK(match)
+ << iv
+ << " domain already inferred,"
+ << " cannot prove their extents are the same "
+ << it->second->extent << " vs " << r->extent;
+ }
+}
+
+void PassDownDomain(const Stage& stage,
+ std::unordered_map<IterVar, Range>* p_state,
+ arith::Analyzer* actx,
+ bool allow_missing) {
+ auto ceil_div = [actx](PrimExpr a, PrimExpr b) {
+ if (actx->CanProve(indexmod(a, b) == 0)) {
+ return actx->Simplify(indexdiv(a, b));
+ }
+ return actx->Simplify(indexdiv(a + (b - 1), b));
+ };
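+ // e.g. ceil_div(10, 4) yields (10 + 3) / 4 = 3, while ceil_div(12, 4) can be
+ // proven divisible and simplifies directly to 12 / 4 = 3.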
+
+ auto& state = *p_state;
+ // forward iteration over the relations
+ for (IterVarRelation rel : stage->relations) {
+ if (const SplitNode* r = rel.as<SplitNode>()) {
+ if (!state.count(r->parent)) {
+ CHECK(allow_missing);
+ continue;
+ }
+ CHECK(!state.count(r->inner));
+ const Range& range_parent = state.at(r->parent);
+ if (r->factor.defined()) {
+ Update(p_state, r->inner,
+ Range::make_by_min_extent(0, r->factor), actx);
+ Update(p_state, r->outer,
+ Range::make_by_min_extent(
+ 0, ceil_div(range_parent->extent, r->factor)), actx);
+ } else {
+ Update(p_state, r->outer, Range::make_by_min_extent(0, r->nparts), actx);
+ Update(p_state, r->inner,
+ Range::make_by_min_extent(
+ 0, ceil_div(range_parent->extent, r->nparts)), actx);
+ }
+ } else if (const FuseNode* r = rel.as<FuseNode>()) {
+ if (!state.count(r->outer) || !state.count(r->inner)) {
+ CHECK(allow_missing);
+ continue;
+ }
+ const Range& range_outer = state.at(r->outer);
+ const Range& range_inner = state.at(r->inner);
+ state[r->fused] = Range::make_by_min_extent(
+ 0, range_outer->extent * range_inner->extent);
+ } else if (const RebaseNode* r = rel.as<RebaseNode>()) {
+ if (!state.count(r->parent)) {
+ CHECK(allow_missing);
+ continue;
+ }
+ Update(p_state, r->rebased,
+ Range::make_by_min_extent(
+ 0, state.at(r->parent)->extent), actx);
+ } else if (const SingletonNode* s = rel.as<SingletonNode>()) {
+ Update(p_state, s->iter, Range::make_by_min_extent(0, 1), actx);
+ } else {
+ LOG(FATAL) << "unknown relation type";
+ }
+ }
+ // update the extents of bound threads.
+ for (auto kv : stage->iter_var_attrs) {
+ if (kv.second->bind_thread.defined()) {
+ CHECK(state.count(kv.first));
+ Update(p_state, kv.second->bind_thread, state.at(kv.first), actx);
+ }
+ }
+}
+
+void PassUpIndex(const Stage& stage,
+ const Map<IterVar, Range>& dom_map,
+ std::unordered_map<IterVar, PrimExpr>* p_state,
+ bool allow_missing) {
+ auto& state = *p_state;
+ for (size_t i = stage->relations.size(); i != 0; --i) {
+ IterVarRelation rel = stage->relations[i - 1];
+ if (const SplitNode* s = rel.as<SplitNode>()) {
+ if (!state.count(s->outer) || !state.count(s->inner)) {
+ CHECK(allow_missing);
+ continue;
+ }
+ PrimExpr outer = state.at(s->outer);
+ PrimExpr inner = state.at(s->inner);
+ PrimExpr factor = dom_map.at(s->inner)->extent;
+ PrimExpr parent_min = dom_map.at(s->parent)->min;
+ state[s->parent] = inner + outer * factor;
+ // add min if they exist
+ if (!is_zero(parent_min)) {
+ state[s->parent] = state[s->parent] + parent_min;
+ }
+ } else if (const FuseNode* s = rel.as<FuseNode>()) {
+ if (!state.count(s->fused)) {
+ CHECK(allow_missing);
+ continue;
+ }
+ PrimExpr value = state.at(s->fused);
+ PrimExpr factor = dom_map.at(s->inner)->extent;
+ PrimExpr outer_min = dom_map.at(s->outer)->min;
+ PrimExpr inner_min = dom_map.at(s->inner)->min;
+ state[s->outer] = indexdiv(value, factor);
+ state[s->inner] = indexmod(value, factor);
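+ // e.g. with factor (inner extent) = 4, a fused index of 7 recovers
+ // outer = 7 / 4 = 1 and inner = 7 % 4 = 3, before any non-zero mins are
+ // added back below.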
+ // add min if they exist
+ if (!is_zero(outer_min)) {
+ state[s->outer] = state[s->outer] + outer_min;
+ }
+ if (!is_zero(inner_min)) {
+ state[s->inner] = state[s->inner] + inner_min;
+ }
+ } else if (const RebaseNode* s = rel.as<RebaseNode>()) {
+ if (!state.count(s->rebased)) {
+ CHECK(allow_missing);
+ continue;
+ }
+ PrimExpr value = state.at(s->rebased);
+ PrimExpr parent_min = dom_map.at(s->parent)->min;
+ // add min if they exist
+ if (!is_zero(parent_min)) {
+ state[s->parent] = value + parent_min;
+ } else {
+ state[s->parent] = value;
+ }
+ } else if (rel.as<SingletonNode>()) {
+ } else {
+ LOG(FATAL) << "unknown relation type";
+ }
+ }
+}
+
+void PassDownIndex(const Stage& stage,
+ const Map<IterVar, Range>& dom_map,
+ std::unordered_map<IterVar, PrimExpr>* p_state,
+ bool allow_missing) {
+ auto& state = *p_state;
+ for (IterVarRelation rel : stage->relations) {
+ if (const SplitNode* s = rel.as<SplitNode>()) {
+ if (!state.count(s->parent)) {
+ CHECK(allow_missing);
+ continue;
+ }
+ Range r = dom_map.at(s->inner);
+ CHECK(is_zero(r->min));
+ PrimExpr parent = state.at(s->parent);
+ PrimExpr factor = r->extent;
+ state[s->outer] = indexdiv(parent, factor);
+ state[s->inner] = indexmod(parent, factor);
+ } else if (const FuseNode* s = rel.as<FuseNode>()) {
+ if (!state.count(s->inner) && !state.count(s->outer)) {
+ CHECK(allow_missing);
+ continue;
+ }
+ PrimExpr factor = dom_map.at(s->inner)->extent;
+ PrimExpr outer_min = dom_map.at(s->outer)->min;
+ PrimExpr inner_min = dom_map.at(s->inner)->min;
+ PrimExpr inner = state.at(s->inner);
+ PrimExpr outer = state.at(s->outer);
+ CHECK(is_zero(outer_min));
+ CHECK(is_zero(inner_min));
+ state[s->fused] = outer * factor + inner;
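+ // e.g. outer = 1, inner = 3, factor = 4 give fused = 1 * 4 + 3 = 7,
+ // the inverse of the decomposition done in PassUpIndex.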
+ } else if (const RebaseNode* s = rel.as<RebaseNode>()) {
+ if (!state.count(s->rebased)) {
+ CHECK(allow_missing);
+ continue;
+ }
+ PrimExpr value = state.at(s->parent);
+ PrimExpr parent_min = dom_map.at(s->parent)->min;
+ CHECK(is_zero(parent_min));
+ state[s->rebased] = value;
+ } else if (const SingletonNode* s = rel.as<SingletonNode>()) {
+ state[s->iter] = make_zero(s->iter->var.dtype());
+ } else {
+ LOG(FATAL) << "unknown relation type";
+ }
+ }
+}
+
+// Domain message passing.
+void PassUpDomain(const SplitNode* s,
+ const std::unordered_map<IterVar, Range>& dom_map,
+ const IntSet& outer,
+ const IntSet& inner,
+ IntSet* parent) {
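+ // If the outer and inner sets cover their full ranges, the parent simply
+ // covers its full range; otherwise evaluate outer * factor + inner + parent_min
+ // over the given sets, e.g. outer = {1}, inner = [0, 3], factor = 4 and
+ // parent_min = 0 give parent = [4, 7].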
+ if (dom_map.count(s->outer) &&
+ dom_map.count(s->inner) &&
+ dom_map.count(s->parent) &&
+ outer.match_range(dom_map.at(s->outer)) &&
+ inner.match_range(dom_map.at(s->inner))) {
+ *parent = IntSet::range(dom_map.at(s->parent));
+ return;
+ }
+ PrimExpr factor = dom_map.at(s->inner)->extent;
+ PrimExpr parent_min = dom_map.at(s->parent)->min;
+ CHECK(outer.defined());
+ CHECK(inner.defined());
+ CHECK(factor.defined());
+ *parent = arith::EvalSet(
+ s->outer->var * factor + s->inner->var + parent_min,
+ {{s->outer, outer}, {s->inner, inner}});
+}
+
+void PassUpDomain(const FuseNode* s,
+ const std::unordered_map<IterVar, Range>& dom_map,
+ const IntSet& fused,
+ IntSet* outer,
+ IntSet* inner) {
+ CHECK(dom_map.count(s->outer));
+ CHECK(dom_map.count(s->inner));
+ CHECK(dom_map.count(s->fused));
+
+ if (fused.match_range(dom_map.at(s->fused))) {
+ *outer = IntSet::range(dom_map.at(s->outer));
+ *inner = IntSet::range(dom_map.at(s->inner));
+ return;
+ }
+ PrimExpr outer_min = dom_map.at(s->outer)->min;
+ PrimExpr inner_min = dom_map.at(s->inner)->min;
+
+ if (fused.is_single_point()) {
+ PrimExpr value = fused.point_value();
+ PrimExpr factor = dom_map.at(s->inner)->extent;
+ PrimExpr v_outer = indexdiv(value, factor);
+ PrimExpr v_inner = indexmod(value, factor);
+ if (!is_zero(outer_min)) v_outer = v_outer + outer_min;
+ if (!is_zero(inner_min)) v_inner = v_inner + inner_min;
+ *outer = IntSet::single_point(v_outer);
+ *inner = IntSet::single_point(v_inner);
+ } else {
+ PrimExpr fused_extent = (fused.max() - fused.min() + 1);
+ PrimExpr inner_extent = dom_map.at(s->inner)->extent;
+ *outer = IntSet::interval(
+ outer_min + indexdiv(fused.min(), inner_extent),
+ outer_min + indexdiv(fused.max(), inner_extent));
+ if (is_zero(Simplify(indexmod(inner_extent, fused_extent))) &&
+ is_zero(Simplify(indexmod(fused.min(), fused_extent)))) {
+ // fused never spans multiple rows, make a tight bounding box
+ // there may be other cases when bounding box could be tightened
+ *inner = IntSet::interval(inner_min + indexmod(fused.min(), inner_extent),
+ inner_min + indexmod(fused.max(), inner_extent));
+ } else { // fused may span multiple rows, use full row widths
+ if (!is_zero(Simplify(indexmod(fused_extent, inner_extent))) ||
+ !is_zero(Simplify(indexmod(fused.min(), inner_extent)))) {
+ LOG(WARNING) <<
+ "fused and original axes are not aligned, this may cause redundant computations";
+ }
+ *inner = IntSet::range(dom_map.at(s->inner));
+ }
+ return;
+ }
+}
+
+void PassUpDomain(const RebaseNode* s,
+ const std::unordered_map<IterVar, Range>& dom_map,
+ const IntSet& rebased,
+ IntSet* parent) {
+ CHECK(dom_map.count(s->parent));
+ if (rebased.match_range(dom_map.at(s->rebased))) {
+ *parent = IntSet::range(dom_map.at(s->parent));
+ return;
+ }
+ PrimExpr parent_min = dom_map.at(s->parent)->min;
+ *parent = arith::EvalSet(s->rebased->var + parent_min,
+ {{s->rebased, rebased}});
+}
+
+void PassUpDomain(const Stage& stage,
+ const std::unordered_map<IterVar, Range>& dom_map,
+ std::unordered_map<IterVar, IntSet>* p_state) {
+ auto& state = *p_state;
+ for (size_t i = stage->relations.size(); i != 0; --i) {
+ IterVarRelation rel = stage->relations[i - 1];
+ if (const SplitNode* r = rel.as<SplitNode>()) {
+ IntSet parent;
+ PassUpDomain(r, dom_map,
+ state.at(r->outer), state.at(r->inner),
+ &parent);
+ state[r->parent] = parent;
+ } else if (const FuseNode* r = rel.as<FuseNode>()) {
+ IntSet outer, inner;
+ PassUpDomain(r, dom_map,
+ state.at(r->fused),
+ &outer, &inner);
+ state[r->outer] = outer;
+ state[r->inner] = inner;
+ } else if (const RebaseNode* r = rel.as<RebaseNode>()) {
+ IntSet parent;
+ PassUpDomain(r, dom_map,
+ state.at(r->rebased),
+ &parent);
+ state[r->parent] = parent;
+ } else if (rel.as<SingletonNode>()) {
+ } else {
+ LOG(FATAL) << "unknown relation type";
+ }
+ }
+}
+
+// Pass up bit mask with or relation.
+void PassUpBitMaskOr(const Stage& stage,
+ std::unordered_map<IterVar, int>* p_state,
+ bool allow_missing) {
+ auto& state = *p_state;
+ for (size_t i = stage->relations.size(); i != 0; --i) {
+ IterVarRelation rel = stage->relations[i - 1];
+ if (const SplitNode* s = rel.as<SplitNode>()) {
+ if (!state.count(s->inner) && !state.count(s->outer)) {
+ CHECK(allow_missing);
+ continue;
+ }
+ int res = 0;
+ if (state.count(s->parent)) res |= state[s->parent];
+ if (state.count(s->inner)) res |= state[s->inner];
+ if (state.count(s->outer)) res |= state[s->outer];
+ state[s->parent] = res;
+ } else if (const FuseNode* s = rel.as<FuseNode>()) {
+ if (!state.count(s->fused)) {
+ CHECK(allow_missing);
+ continue;
+ }
+ if (!state.count(s->outer)) {
+ state[s->outer] = state[s->fused];
+ } else {
+ state[s->outer] |= state[s->fused];
+ }
+ if (!state.count(s->inner)) {
+ state[s->inner] = state[s->fused];
+ } else {
+ state[s->inner] |= state[s->fused];
+ }
+ } else if (const RebaseNode* s = rel.as<RebaseNode>()) {
+ if (!state.count(s->rebased)) {
+ CHECK(allow_missing);
+ continue;
+ }
+ if (!state.count(s->parent)) {
+ state[s->parent] = state[s->rebased];
+ } else {
+ state[s->parent] |= state[s->rebased];
+ }
+ } else if (rel.as<SingletonNode>()) {
+ } else {
+ LOG(FATAL) << "unknown relation type";
+ }
+ }
+}
+
+void PassDownBitMaskOr(const Stage& stage,
+ std::unordered_map<IterVar, int>* p_state,
+ bool allow_missing) {
+ auto& state = *p_state;
+ for (IterVarRelation rel : stage->relations) {
+ if (const SplitNode* s = rel.as<SplitNode>()) {
+ if (!state.count(s->parent)) {
+ CHECK(allow_missing);
+ continue;
+ }
+ if (!state.count(s->outer)) {
+ state[s->outer] = state.at(s->parent);
+ } else {
+ state[s->outer] |= state.at(s->parent);
+ }
+ if (!state.count(s->inner)) {
+ state[s->inner] = state.at(s->parent);
+ } else {
+ state[s->inner] |= state.at(s->parent);
+ }
+ } else if (const FuseNode* s = rel.as<FuseNode>()) {
+ if (!state.count(s->outer) && !state.count(s->inner)) {
+ CHECK(allow_missing);
+ continue;
+ }
+ int res = 0;
+ if (state.count(s->outer)) res |= state.at(s->outer);
+ if (state.count(s->inner)) res |= state.at(s->inner);
+ if (state.count(s->fused)) res |= state.at(s->fused);
+ state[s->fused] = res;
+ } else if (const RebaseNode* s = rel.as<RebaseNode>()) {
+ if (!state.count(s->parent)) {
+ CHECK(allow_missing);
+ continue;
+ }
+ if (!state.count(s->rebased)) {
+ state[s->rebased] = state.at(s->parent);
+ } else {
+ state[s->rebased] |= state.at(s->parent);
+ }
+ } else if (const SingletonNode* s = rel.as<SingletonNode>()) {
+ state[s->iter] = 0;
+ } else {
+ LOG(FATAL) << "unknown relation type";
+ }
+ }
+}
+
+
+/*!
+ * \brief Message passing to find if boundary checking on an IterVar is needed.
+ * \param s The stage to be used.
+ * \param dom_map The domain map of each iteration variable's domain.
+ * \param p_state The message passing state, mapping IterVar->flag.
+ * \param analyzer The analyzer used to prove bound conditions.
+ */
+void PassUpBoundCheck(const Stage& s,
+ const Map<IterVar, Range>& dom_map,
+ std::unordered_map<IterVar, bool>* p_state,
+ arith::Analyzer* analyzer) {
+ auto& state = *p_state;
+ for (size_t i = s->relations.size(); i != 0; --i) {
+ IterVarRelation rel = s->relations[i - 1];
+ if (const SplitNode* s = rel.as<SplitNode>()) {
+ bool outer = state.at(s->outer);
+ bool inner = state.at(s->inner);
+
+ if (dom_map.count(s->inner) && dom_map.count(s->outer)) {
+ PrimExpr factor = dom_map.at(s->inner)->extent;
+ PrimExpr step = dom_map.at(s->outer)->extent;
+ if (outer || inner) {
+ state[s->parent] = true;
+ } else {
+ if (analyzer->CanProve(dom_map.at(s->parent)->extent == factor * step)) {
+ state[s->parent] = false;
+ } else {
+ state[s->parent] = true;
+ }
+ }
+ } else {
+ state[s->parent] = true;
+ }
+ } else if (const FuseNode* s = rel.as<FuseNode>()) {
+ bool fused = state.at(s->fused);
+ state[s->outer] = fused;
+ state[s->inner] = fused;
+ } else if (const RebaseNode* s = rel.as<RebaseNode>()) {
+ state[s->parent] = state.at(s->rebased);
+ } else if (rel.as<SingletonNode>()) {
+ // nop
+ } else {
+ LOG(FATAL) << "unknown relation type";
+ }
+ }
+}
+
+std::vector<PrimExpr> MakeBoundCheck(
+ const Stage& stage,
+ const Map<IterVar, Range>& dom_map,
+ const std::unordered_map<IterVar, PrimExpr>& value_map,
+ bool skip_ivar_domain,
+ const std::unordered_set<IterVar>& skip_iter) {
+ arith::Analyzer analyzer;
+
+ std::unordered_map<IterVar, bool> bound_state;
+ for (IterVar iv : stage->leaf_iter_vars) {
+ bound_state[iv] = false;
+ }
+ PassUpBoundCheck(stage, dom_map, &bound_state, &analyzer);
+
+ std::vector<PrimExpr> preds;
+ std::unordered_map<const VarNode*, IntSet> iset_dmap;
+
+ // setup domain map for set analysis
+ for (const auto& kv : dom_map) {
+ iset_dmap[kv.first->var.get()] = IntSet::range(kv.second);
+ }
+
+ for (const IterVar& iv : stage->all_iter_vars) {
+ if (skip_iter.count(iv) || iv->iter_type == kOpaque) continue;
+ if (bound_state.at(iv)) {
+ Range dom = dom_map.at(iv);
+ PrimExpr value = value_map.at(iv) - dom->min;
+ PrimExpr vmax = EvalSet(value, iset_dmap).max();
+ if (vmax.dtype() != value.dtype() || !analyzer.CanProve(vmax < dom->extent)) {
+ preds.emplace_back(value < dom->extent);
+ }
+ }
+ }
+ for (const IterVar& iv : stage->op->root_iter_vars()) {
+ if (skip_iter.count(iv) || iv->iter_type == kOpaque) continue;
+ Range dom = dom_map.at(iv);
+ CHECK(iv->dom.defined());
+ if (!skip_ivar_domain && !iv->dom.same_as(dom)) {
+ PrimExpr value = value_map.at(iv) - iv->dom->min;
+ IntSet s = EvalSet(value, iset_dmap);
+ PrimExpr vmin = s.min();
+ PrimExpr vmax = s.max();
+ // The range of `value` resides in [vmin, vmax]
+ if (vmin.dtype() != value.dtype() || !analyzer.CanProve(vmin >= 0)) {
+ preds.emplace_back(value >= 0);
+ }
+ if (vmax.dtype() != value.dtype() || !analyzer.CanProve(vmax < iv->dom->extent)) {
+ preds.emplace_back(value < iv->dom->extent);
+ }
+ }
+ }
+ return preds;
+}
+} // namespace top
+} // namespace tvm
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file message_passing.h
+ * \brief Common utilities to do message passing
+ * on the schedule hyper graph.
+ */
+#ifndef TVM_TOP_SCHEDULE_MESSAGE_PASSING_H_
+#define TVM_TOP_SCHEDULE_MESSAGE_PASSING_H_
+
+#include <tvm/expr.h>
+#include <tvm/top/schedule.h>
+#include <tvm/top/operation.h>
+#include <tvm/arith/analyzer.h>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+namespace tvm {
+namespace top {
+/*!
+ * \brief Downward inference of domain of each IterVar.
+ *  The caller sets the range of the root; the function then
+ *  propagates it towards the leaves.
+ *
+ * \param stage The stage to operate on.
+ * \param p_state The state of the message passing.
+ * \param analyzer Analyzer context, storing information about bounds in p_state.
+ * \param allow_missing Whether allow missing value.
+ */
+void PassDownDomain(
+ const Stage& stage,
+ std::unordered_map<IterVar, Range>* p_state,
+ arith::Analyzer* analyzer,
+ bool allow_missing = false);
+
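+// Illustrative usage sketch (`stage` and `dom_map` are hypothetical names;
+// this mirrors how PrepareAxisMapping in schedule_dataflow_rewrite.cc uses
+// the pass): seed the map with the known root ranges, then let the pass
+// derive ranges for the split/fused leaves.
+//
+//   std::unordered_map<IterVar, Range> dom_map;
+//   for (IterVar iv : stage->op->root_iter_vars()) dom_map[iv] = iv->dom;
+//   arith::Analyzer analyzer;
+//   PassDownDomain(stage, &dom_map, &analyzer);
+//   // dom_map now also contains a Range for every leaf IterVar of the stage.
+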
+/*!
+ * \brief Upward inference of index of each IterVar,
+ *  given the index assignment of the leaves.
+ *
+ * \param stage The stage to operate on.
+ * \param dom_map The domain map of each iteration variable's domain.
+ * \param p_state The index state of each IterVar.
+ * \param allow_missing Whether allow missing value.
+ */
+void PassUpIndex(const Stage& stage,
+ const Map<IterVar, Range>& dom_map,
+ std::unordered_map<IterVar, PrimExpr>* p_state,
+ bool allow_missing = false);
+
+/*!
+ * \brief Downward inference of index of each IterVar,
+ *  given the index assignment of the roots.
+ *
+ * \param stage The stage to operate on.
+ * \param dom_map The domain map of each iteration variable's domain.
+ * \param p_state The index state of each IterVar.
+ * \param allow_missing Whether allow missing value.
+ */
+void PassDownIndex(const Stage& stage,
+ const Map<IterVar, Range>& dom_map,
+ std::unordered_map<IterVar, PrimExpr>* p_state,
+ bool allow_missing = false);
+
+/*!
+ * \brief Upward inference of the domain set of each IterVar,
+ *  given the domain assignment of the leaves.
+ *
+ * \param stage The stage to operate on.
+ * \param dom_map The domain map of each iteration variable's maximum domain.
+ * \param p_state The index state of each IterVar.
+ */
+void PassUpDomain(const Stage& stage,
+ const std::unordered_map<IterVar, Range>& dom_map,
+ std::unordered_map<IterVar, IntSet>* p_state);
+
+/*!
+ * \brief Upward message passing of bitmask with or relation.
+ * \param stage The stage to operate on.
+ * \param p_state The index state of each IterVar.
+ * \param allow_missing Whether allow missing value.
+ */
+void PassUpBitMaskOr(const Stage& stage,
+ std::unordered_map<IterVar, int>* p_state,
+ bool allow_missing = false);
+
+/*!
+ * \brief Downward message passing of bitmask with or relation.
+ * \param stage The stage to operate on.
+ * \param p_state The index state of each IterVar.
+ * \param allow_missing Whether allow missing value.
+ */
+void PassDownBitMaskOr(const Stage& stage,
+ std::unordered_map<IterVar, int>* p_state,
+ bool allow_missing = false);
+
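+// Illustrative sketch (`reduce_stage` and `axis` are hypothetical names;
+// this mirrors the use in Schedule::rfactor): mark one IterVar and let the
+// two bitmask passes flag every IterVar connected to it through
+// split/fuse/rebase relations.
+//
+//   std::unordered_map<IterVar, int> touch_map;
+//   touch_map[axis] = 1;
+//   PassUpBitMaskOr(reduce_stage, &touch_map, /*allow_missing=*/true);
+//   PassDownBitMaskOr(reduce_stage, &touch_map, /*allow_missing=*/true);
+//   // touch_map now has bit 1 set for every IterVar related to `axis`.
+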
+/*!
+ * \brief Create boundary check predicates given remapped value of root
+ * \param stage The stage we operate on
+ * \param dom_map The domain map of each value.
+ * \param value_map The value map of the root iter var.
+ * \param skip_ivar_domain Whether we skip check for IterVar's original domain.
+ * \param skip_iter The set of variables to skip bound condition.
+ * \return List of predicates that we need to check.
+ */
+std::vector<PrimExpr>
+MakeBoundCheck(
+ const Stage& stage,
+ const Map<IterVar, Range>& dom_map,
+ const std::unordered_map<IterVar, PrimExpr>& value_map,
+ bool skip_ivar_domain,
+ const std::unordered_set<IterVar>& skip_iter);
+
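+// Illustrative sketch (`stage`, `dom_map`, `value_map` and `skip_iters` are
+// hypothetical names; this mirrors PrepareAxisMapping in
+// schedule_dataflow_rewrite.cc): after leaf indices are passed up to the
+// roots, MakeBoundCheck yields the guard predicates for the remapped values.
+//
+//   PassUpIndex(stage, dom_map, &value_map, true);
+//   std::vector<PrimExpr> preds =
+//       MakeBoundCheck(stage, dom_map, value_map, true, skip_iters);
+//   // The predicates can be folded into the body, e.g. with a Select
+//   // (see InjectPredicate in schedule_dataflow_rewrite.cc).
+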
+} // namespace top
+} // namespace tvm
+#endif // TVM_TOP_SCHEDULE_MESSAGE_PASSING_H_
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file schedule_dataflow_rewrite.cc
+ */
+#include <tvm/top/schedule.h>
+#include <tvm/top/operation.h>
+#include <tvm/ir_functor_ext.h>
+#include <tvm/ir_pass.h>
+#include <unordered_set>
+#include "message_passing.h"
+#include "../../pass/ir_util.h"
+#include "../../arith/compute_expr.h"
+
+namespace tvm {
+namespace top {
+// find first occurrence location in leaf
+template<typename T>
+size_t FindNodeRef(ArrayNode* array_node, const T& v) {
+ const Object* n = v.get();
+ for (size_t i = 0; i < array_node->data.size(); ++i) {
+ if (array_node->data[i].get() == n) return i;
+ }
+ return array_node->data.size();
+}
+
+// The replacer of cache.
+class VarReplacer : public ir::StmtExprMutator {
+ public:
+ explicit VarReplacer(
+ const std::unordered_map<const VarNode*, PrimExpr>& vsub)
+ : vsub_(vsub) {}
+ PrimExpr VisitExpr_(const VarNode* op) final {
+ auto it = vsub_.find(op);
+ if (it != vsub_.end()) return it->second;
+ return GetRef<PrimExpr>(op);
+ }
+
+ ir::CommReducer MutateCommReducer(ir::CommReducer combiner) {
+ // Replace free variables in combiner
+ auto new_identity = ir::UpdateArray(combiner->identity_element, [this] (const PrimExpr& e) {
+ return this->VisitExpr(e);
+ });
+ auto new_result = ir::UpdateArray(combiner->result, [this] (const PrimExpr& e) {
+ return this->VisitExpr(e);
+ });
+
+ if (combiner->identity_element.same_as(new_identity) &&
+        combiner->result.same_as(new_result)) {
+ return combiner;
+ } else {
+ return ir::CommReducerNode::make(
+ combiner->lhs, combiner->rhs, new_result, new_identity);
+ }
+ }
+
+ PrimExpr VisitExpr_(const ir::ReduceNode* op) final {
+ PrimExpr new_e = StmtExprMutator::VisitExpr_(op);
+ const ir::ReduceNode* new_reduce = new_e.as<ir::ReduceNode>();
+ ir::CommReducer new_combiner = MutateCommReducer(op->combiner);
+ if (op->combiner.same_as(new_combiner)) {
+ return new_e;
+ } else {
+ return ir::ReduceNode::make(
+ new_combiner,
+ new_reduce->source,
+ new_reduce->axis,
+ new_reduce->condition,
+ new_reduce->value_index);
+ }
+ }
+
+ private:
+ const std::unordered_map<const VarNode*, PrimExpr>& vsub_;
+};
+
+PrimExpr InjectPredicate(const Array<PrimExpr>& predicates,
+ PrimExpr body) {
+ using ir::ReduceNode;
+ using ir::SelectNode;
+ if (predicates.size() == 0) return body;
+ const ReduceNode* reduce = body.as<ReduceNode>();
+ if (reduce) {
+ auto n = make_object<ReduceNode>(*reduce);
+ n->condition = n->condition && arith::ComputeReduce<ir::AndNode>(predicates, PrimExpr());
+ return PrimExpr(n);
+ }
+ return SelectNode::make(arith::ComputeReduce<ir::AndNode>(predicates, PrimExpr()),
+ body,
+ make_zero(body.dtype()));
+}
+
+// Replace the dataflow that appears in all stages given the tensor change.
+// Also update vmap if subsequent dataflow needs to be replaced.
+// A reverse map is kept so that the transitive-closure property of vmap stays up to date.
+void ReplaceDataFlow(const Array<Stage>& stages,
+ std::unordered_map<Tensor, Tensor>* vmap,
+ std::unordered_map<Tensor, Tensor>* rvmap) {
+ for (Stage s : stages) {
+ Operation op = s->op->ReplaceInputs(s->op, *vmap);
+ if (!op.same_as(s->op)) {
+ for (int i = 0; i < op->num_outputs(); ++i) {
+ auto it = rvmap->find(s->op.output(i));
+ if (it != rvmap->end()) {
+ (*vmap)[it->second] = op.output(i);
+ } else {
+ (*vmap)[s->op.output(i)] = op.output(i);
+ (*rvmap)[op.output(i)] = s->op.output(i);
+ }
+ }
+ s->op = op;
+ }
+ }
+}
+
+inline bool ReduceEqual(const ir::ReduceNode* a, const ir::ReduceNode* b) {
+ return (a->combiner.same_as(b->combiner)) &&
+ (a->source.same_as(b->source)) &&
+ (a->axis.same_as(b->axis)) &&
+ (a->condition.same_as(b->condition));
+}
+
+Tensor Schedule::cache_read(const Tensor& tensor,
+ const std::string& scope,
+ const Array<Operation>& readers) {
+ (*this)->InvalidateCache();
+ // create identity mapping.
+ std::ostringstream os;
+ os << tensor->op->name;
+ if (tensor->op->num_outputs() != 1) {
+ os << ".v" << tensor->value_index;
+ }
+ os << "." << scope;
+
+ std::unordered_map<Tensor, Tensor> vsub;
+ Stage s = operator[](tensor->op);
+ Tensor sugar_tensor = s->op.output(tensor->value_index);
+ Tensor cache = compute(sugar_tensor->shape, [&sugar_tensor](const Array<Var>& i) {
+ return sugar_tensor(Array<PrimExpr>(i.begin(), i.end()));
+ }, os.str());
+ vsub[sugar_tensor] = cache;
+
+ std::unordered_map<Tensor, Tensor> vmap;
+ std::unordered_map<Tensor, Tensor> rvmap;
+ for (Operation op : readers) {
+ Stage s = operator[](op);
+ Operation repl_op = s->op->ReplaceInputs(s->op, vsub);
+ CHECK(!repl_op.same_as(s->op))
+ << "Cannot find " << tensor
+ << " in the inputs of " << s->op;
+ vmap[s->op.output(0)] = repl_op.output(0);
+ rvmap[repl_op.output(0)] = s->op.output(0);
+ s->op = repl_op;
+ }
+ ReplaceDataFlow((*this)->stages, &vmap, &rvmap);
+ ArrayNode* stages = (*this)->stages.CopyOnWrite();
+ Stage op_stage = operator[](tensor->op);
+ size_t pos = FindNodeRef(stages, op_stage);
+ Stage cache_stage = Stage(cache->op);
+ cache_stage.set_scope(scope);
+ CHECK_LT(pos, stages->data.size());
+ stages->data.insert(stages->data.begin() + pos + 1,
+ cache_stage);
+ (*this)->stage_map.Set(cache->op, cache_stage);
+ // Update group
+ cache_stage->group = op_stage->group;
+ if (cache_stage->group.defined()) {
+ ++cache_stage->group->num_child_stages;
+ }
+ return cache;
+}
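+
+// Rough usage sketch of cache_read (tensor names A, B and the schedule sch
+// are hypothetical): stage the reads of A made by B's op into shared memory.
+//
+//   Tensor AA = sch.cache_read(A, "shared", {B->op});
+//   // B's op now reads from AA, and a new stage for AA is inserted right
+//   // after the stage of A.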
+
+template<typename OpType>
+void PrepareAxisMapping(Stage orig_stage,
+ OpType* op,
+ std::unordered_set<IterVar>* p_red_axis,
+ Array<IterVar>* p_new_axis,
+ std::unordered_map<IterVar, Range>* p_dom_map,
+ std::unordered_map<const VarNode*, PrimExpr>* p_vsub,
+ std::unordered_map<const VarNode*, PrimExpr>* p_vsub2newvar,
+ std::vector<PrimExpr>* p_predicates) {
+ auto& red_axis = *p_red_axis;
+ auto& new_axis = *p_new_axis;
+ auto& dom_map = *p_dom_map;
+ auto& vsub = *p_vsub;
+ auto& vsub2newvar = *p_vsub2newvar;
+ auto& predicates = *p_predicates;
+ arith::Analyzer analyzer;
+
+ for (IterVar iv : op->reduce_axis) {
+ red_axis.insert(iv);
+ }
+ for (IterVar iv : op->axis) {
+ dom_map[iv] = iv->dom;
+ analyzer.Bind(iv->var, iv->dom);
+ }
+ top::PassDownDomain(orig_stage, &dom_map, &analyzer, true);
+ {
+ // The source->cache
+ std::unordered_map<IterVar, PrimExpr> value_map;
+ for (IterVar iv : orig_stage->leaf_iter_vars) {
+ if (red_axis.count(iv)) continue;
+ CHECK_EQ(iv->iter_type, kDataPar)
+ << "Can only relayout with in data parallel dimensions";
+ Range dom = dom_map.at(iv);
+ IterVar new_iv = IterVarNode::make(
+ dom, iv->var.copy_with_suffix(".c"), iv->iter_type);
+ new_axis.push_back(new_iv);
+      if (is_one(dom->extent)) {
+ value_map[iv] = dom->min;
+ } else {
+ value_map[iv] = iv->var;
+ vsub2newvar[iv->var.get()] = new_iv->var;
+ }
+ }
+ // skip reduction iteration.
+ std::unordered_set<IterVar> skip_bound_check;
+ for (IterVar iv : op->reduce_axis) {
+ skip_bound_check.insert(iv);
+ }
+ PassUpIndex(orig_stage, dom_map, &value_map, true);
+ predicates = MakeBoundCheck(
+ orig_stage, dom_map, value_map, true, skip_bound_check);
+ // The root axis
+ for (IterVar iv : op->axis) {
+ if (value_map.count(iv)) {
+ vsub[iv->var.get()] = value_map.at(iv);
+ } // to handle tensor axis
+ }
+ }
+}
+
+Array<Tensor> ReplaceOriginalOp(Schedule sch,
+ Stage orig_stage,
+ const std::string& scope,
+ Operation cache_op,
+ Operation orig_new_op,
+ size_t tensor_size) {
+ Array<Tensor> cache_tensor_list;
+ for (size_t i = 0; i < tensor_size; i++) {
+ Tensor cache_tensor = cache_op.output(i);
+ cache_tensor_list.push_back(cache_tensor);
+ }
+  // Replace the dataflow: map every output of the original op to the new op.
+  std::unordered_map<Tensor, Tensor> vmap;
+  std::unordered_map<Tensor, Tensor> rvmap;
+  for (size_t i = 0; i < tensor_size; i++) {
+    vmap[orig_stage->op.output(i)] = orig_new_op.output(i);
+    rvmap[orig_new_op.output(i)] = orig_stage->op.output(i);
+  }
+ ReplaceDataFlow(sch->stages, &vmap, &rvmap);
+ // mutate orig stage
+ orig_stage->op = orig_new_op;
+ orig_stage->all_iter_vars = orig_stage->op->root_iter_vars();
+ orig_stage->leaf_iter_vars = orig_stage->all_iter_vars;
+ orig_stage->relations = Array<IterVarRelation>();
+ // create schedule for new cached stage.
+ ArrayNode* stages = sch->stages.CopyOnWrite();
+ size_t pos = FindNodeRef(stages, orig_stage);
+ Stage cache_stage = Stage(cache_op);
+ cache_stage.set_scope(scope);
+ CHECK_LT(pos, stages->data.size());
+ stages->data.insert(stages->data.begin() + pos,
+ cache_stage);
+ sch->stage_map.Set(cache_op, cache_stage);
+ // Update group
+ cache_stage->group = orig_stage->group;
+ if (cache_stage->group.defined()) {
+ ++cache_stage->group->num_child_stages;
+ }
+ return cache_tensor_list;
+}
+
+
+// Cache write and relayout the data according to loop pattern
+Array<Tensor> CacheWriteWithReLayout(Schedule sch,
+ const Array<Tensor>& tensor_array,
+ const std::string& scope) {
+ size_t tensor_size = tensor_array.size();
+ sch->InvalidateCache();
+ Tensor tensor = tensor_array[0];
+ Stage orig_stage = sch[tensor->op];
+ const ComputeOpNode* compute = orig_stage->op.as<ComputeOpNode>();
+
+ std::unordered_set<IterVar> red_axis;
+ Array<IterVar> new_axis;
+ std::unordered_map<IterVar, Range> dom_map;
+
+ std::unordered_map<const VarNode*, PrimExpr> vsub;
+ std::unordered_map<const VarNode*, PrimExpr> vsub2newvar;
+ std::vector<PrimExpr> predicates;
+
+ PrepareAxisMapping(orig_stage, compute,
+ &red_axis, &new_axis, &dom_map, &vsub, &vsub2newvar, &predicates);
+
+ PrimExpr body;
+ Array<PrimExpr> body_list;
+ const ir::ReduceNode* first_reduce = nullptr;
+ for (auto cbody : compute->body) {
+ body = VarReplacer(vsub)(cbody);
+ body = InjectPredicate(predicates, body);
+ body = VarReplacer(vsub2newvar)(body);
+ // Reduce nodes in ONE computeOp must be the same except value_index
+ // This is right only if the original body ensures Reduce nodes are the same
+ if (body->IsInstance<ir::ReduceNode>()) {
+ const ir::ReduceNode* reduce_body = body.as<ir::ReduceNode>();
+ if (first_reduce != nullptr) {
+ CHECK(ReduceEqual(reduce_body, first_reduce));
+ body = ir::ReduceNode::make(first_reduce->combiner,
+ first_reduce->source,
+ first_reduce->axis,
+ first_reduce->condition,
+ reduce_body->value_index);
+ } else {
+ first_reduce = reduce_body;
+ }
+ } else {
+ CHECK(first_reduce == nullptr)
+ << "cannot mix reduce and other node in ONE compute bodys";
+ }
+ body_list.push_back(body);
+ }
+ // The reader args
+ Array<PrimExpr> args;
+ {
+ // cache->compute
+ std::unordered_map<IterVar, PrimExpr> value_map;
+ for (IterVar iv : compute->axis) {
+ value_map[iv] = iv->var;
+ }
+ top::PassDownIndex(orig_stage, dom_map, &value_map, true);
+ for (IterVar iv : orig_stage->leaf_iter_vars) {
+ if (red_axis.count(iv)) continue;
+ args.push_back(value_map.at(iv));
+ }
+ }
+ Operation cache_op = ComputeOpNode::make(
+ compute->name + "." + scope, compute->tag, compute->attrs,
+ new_axis, body_list);
+
+ Array<PrimExpr> cache_expr_list;
+ for (size_t i = 0; i < tensor_size; i++) {
+ Tensor cache_tensor = cache_op.output(i);
+ cache_expr_list.push_back(cache_tensor(args));
+ }
+ Operation orig_new_op = ComputeOpNode::make(
+ compute->name, compute->tag, compute->attrs,
+ compute->axis, cache_expr_list);
+ return ReplaceOriginalOp(sch, orig_stage, scope,
+ cache_op, orig_new_op, tensor_size);
+}
+
+
+// for tensor compute op
+Array<Tensor> CacheWriteWithReLayoutTensor(Schedule sch,
+ const Array<Tensor>& tensor_array,
+ const std::string& scope) {
+ size_t tensor_size = tensor_array.size();
+ sch->InvalidateCache();
+ Tensor tensor = tensor_array[0];
+ Stage orig_stage = sch[tensor->op];
+ const TensorComputeOpNode* tensor_op = orig_stage->op.as<TensorComputeOpNode>();
+ CHECK_EQ(tensor_op->num_outputs(), 1)
+ << "cache write only support single output tensor_compute_op";
+
+ std::unordered_set<IterVar> red_axis;
+ Array<IterVar> new_axis;
+ std::unordered_map<IterVar, Range> dom_map;
+
+ std::unordered_map<const VarNode*, PrimExpr> vsub;
+ std::unordered_map<const VarNode*, PrimExpr> vsub2newvar;
+ std::vector<PrimExpr> predicates;
+
+ PrepareAxisMapping(orig_stage, tensor_op,
+ &red_axis, &new_axis, &dom_map, &vsub, &vsub2newvar, &predicates);
+
+
+ for (int i = tensor_op->schedulable_ndim; i < static_cast<int>(tensor_op->axis.size()); ++i) {
+ IterVar iv = tensor_op->axis[i];
+ IterVar new_iv = IterVarNode::make(
+ iv->dom, iv->var.copy_with_suffix(".c"), iv->iter_type);
+ new_axis.push_back(new_iv);
+ }
+ Array<Region> new_regions;
+ for (Region old_region : tensor_op->input_regions) {
+ Region region;
+ for (Range r : old_region) {
+ PrimExpr min = VarReplacer(vsub2newvar)(r->min);
+ PrimExpr extent = VarReplacer(vsub2newvar)(r->extent);
+ region.push_back(Range::make_by_min_extent(min, extent));
+ }
+ new_regions.push_back(region);
+ }
+
+ Array<PrimExpr> new_scalar_inputs;
+ for (PrimExpr old_input : tensor_op->scalar_inputs) {
+ new_scalar_inputs.push_back(VarReplacer(vsub2newvar)(old_input));
+ }
+
+ Operation cache_op = TensorComputeOpNode::make(
+ tensor_op->name + "." + scope, tensor_op->tag, new_axis,
+ tensor_op->reduce_axis, tensor_op->schedulable_ndim,
+ tensor_op->intrin, tensor_op->inputs, new_regions, new_scalar_inputs);
+
+ // axis will be used in generating compute op
+ Array<IterVar> compute_axis = tensor_op->axis;
+ for (size_t i = tensor_op->schedulable_ndim; i < tensor_op->axis.size(); ++i) {
+ IterVar iv = tensor_op->axis[i];
+ IterVar aiv = IterVarNode::make(iv->dom, iv->var, kDataPar);
+ compute_axis.Set(i, aiv);
+ }
+
+ // The reader args
+ Array<PrimExpr> args;
+ {
+ // cache->compute
+ std::unordered_map<IterVar, PrimExpr> value_map;
+ for (IterVar iv : compute_axis) {
+ value_map[iv] = iv->var;
+ }
+ PassDownIndex(orig_stage, dom_map, &value_map, true);
+ for (IterVar iv : orig_stage->leaf_iter_vars) {
+ if (red_axis.count(iv)) continue;
+ args.push_back(value_map.at(iv));
+ }
+ // tensorized region axis
+ for (size_t i = tensor_op->schedulable_ndim; i < tensor_op->axis.size(); ++i) {
+ IterVar iv = compute_axis[i];
+ args.push_back(value_map.at(iv));
+ }
+ }
+
+ Array<PrimExpr> cache_expr_list;
+ for (size_t i = 0; i < tensor_size; i++) {
+ Tensor cache_tensor = cache_op.output(i);
+ cache_expr_list.push_back(cache_tensor(args));
+ }
+ Operation orig_new_op = ComputeOpNode::make(
+ tensor_op->name, tensor_op->tag, {},
+ compute_axis, cache_expr_list);
+ return ReplaceOriginalOp(sch, orig_stage, scope,
+ cache_op, orig_new_op, tensor_size);
+}
+
+
+Array<Tensor> Schedule::cache_write(const Array<Tensor>& tensor_array,
+ const std::string& scope) {
+ (*this)->InvalidateCache();
+ CHECK(tensor_array.size() > 0)
+ << "size of tensor_array must be greater than 0";
+ Tensor tensor = tensor_array[0];
+ Stage orig_stage = operator[](tensor->op);
+ const ComputeOpNode* compute = tensor->op.as<ComputeOpNode>();
+ CHECK(static_cast<size_t>(compute->num_outputs()) == tensor_array.size())
+ << "size of input tensor list must be same as number of stage outputs";
+ for (size_t i = 1; i < tensor_array.size(); i++) {
+ Stage tmp_stage = operator[](tensor_array[i]->op);
+ CHECK(orig_stage.same_as(tmp_stage))
+ << "Input tensor list must be generated by ONE computeOp";
+ }
+ return CacheWriteWithReLayout(*this, tensor_array, scope);
+}
+
+
+Tensor Schedule::cache_write(const Tensor& tensor,
+ const std::string& scope) {
+ // support original compute and tensor compute both
+ (*this)->InvalidateCache();
+ if (tensor->op.as<ComputeOpNode>()) {
+ return (CacheWriteWithReLayout(*this, {tensor}, scope))[0];
+ } else if (tensor->op.as<TensorComputeOpNode>()) {
+ return (CacheWriteWithReLayoutTensor(*this, {tensor}, scope))[0];
+ } else {
+ LOG(FATAL) << "cache write only take ComputeOp or TensorComputeOp as writers";
+ return Tensor();
+ }
+}
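+
+// Rough usage sketch of cache_write (C and sch are hypothetical names):
+// compute C into a local buffer first, then copy the result out.
+//
+//   Tensor CC = sch.cache_write(C, "local");
+//   // CC now carries the original body; C's op is rewritten to copy CC.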
+
+
+void RebaseNonZeroMinLoop(const Schedule& sch) {
+ std::unordered_map<IterVar, IterVar> rebase_map;
+ for (Stage s : sch->stages) {
+ if (s->attach_type == kInlinedAlready) continue;
+
+ auto root_iter_vars = s->op->root_iter_vars();
+ ArrayNode* leaf_vars = s->leaf_iter_vars.CopyOnWrite();
+ for (IterVar iv : root_iter_vars) {
+ size_t idx = FindNodeRef(leaf_vars, iv);
+ auto it = s->iter_var_attrs.find(iv);
+      // don't need to rebase paths that are bound.
+ if (it != s->iter_var_attrs.end() &&
+ (*it).second->bind_thread.defined()) {
+ continue;
+ }
+ if (idx < leaf_vars->data.size()) {
+ // insert rebase
+ IterVar rebased = IterVarNode::make(
+ Range(), iv->var.copy_with_suffix(""), iv->iter_type);
+ s->relations.push_back(RebaseNode::make(iv, rebased));
+ if (s->iter_var_attrs.count(iv)) {
+ s->iter_var_attrs.Set(rebased, s->iter_var_attrs.at(iv));
+ }
+ leaf_vars->data[idx] = rebased;
+ rebase_map[iv] = rebased;
+ }
+ }
+ }
+ // remap the parent relation
+ for (Stage s : sch->stages) {
+ if (s->attach_type != kScope) continue;
+ if (rebase_map.count(s->attach_ivar)) {
+ s->attach_ivar = rebase_map.at(s->attach_ivar);
+ }
+ }
+ for (Stage s : sch->groups) {
+ if (s->attach_type != kScope) continue;
+ if (rebase_map.count(s->attach_ivar)) {
+ s->attach_ivar = rebase_map.at(s->attach_ivar);
+ }
+ }
+}
+
+void InjectInline(ScheduleNode* sch) {
+ sch->InvalidateCache();
+
+ std::vector<Array<PrimExpr> > new_body(sch->stages.size());
+ std::vector<bool> changed(sch->stages.size(), false);
+ std::vector<Stmt> new_hybrid_body(sch->stages.size());
+ std::vector<bool> hybrid_changed(sch->stages.size(), false);
+ // inline all the ops
+ for (size_t i = sch->stages.size(); i != 0; --i) {
+ Stage stage = sch->stages[i - 1];
+ if (stage->attach_type == kInline) {
+ stage->attach_type = kInlinedAlready;
+ Array<Var> args;
+ PrimExpr body;
+ {
+ // setup args
+ const ComputeOpNode* compute = stage->op.as<ComputeOpNode>();
+ CHECK(compute)
+ << "can only inline compute op";
+ for (auto iv : compute->axis) {
+ args.push_back(iv->var);
+ }
+ CHECK_EQ(compute->body.size(), 1U)
+ << "can only inline compute op with 1 output";
+ body = compute->body[0];
+ }
+ for (size_t j = i; j < sch->stages.size(); ++j) {
+ Stage s = sch->stages[j];
+ const ComputeOpNode* compute = s->op.as<ComputeOpNode>();
+ const HybridOpNode* hybrid = s->op.as<HybridOpNode>();
+ if (compute) {
+ if (!new_body[j].size()) {
+ new_body[j] = compute->body;
+ }
+ if (new_body[j][0]->IsInstance<ir::ReduceNode>()) {
+          // specially handle reduction inline for multiple reductions.
+ const ir::ReduceNode* reduce = new_body[j][0].as<ir::ReduceNode>();
+ for (size_t k = 1; k < new_body[j].size(); ++k) {
+ const ir::ReduceNode* reduce_ = new_body[j][k].as<ir::ReduceNode>();
+ CHECK(reduce_);
+ CHECK(ReduceEqual(reduce_, reduce))
+ << "The Reduce inputs of ComputeOp should "
+ << "have the same attribute except value_index";
+ }
+ PrimExpr new_value = ir::Inline(ir::EvaluateNode::make(new_body[j][0]),
+ stage->op, args, body).as<ir::EvaluateNode>()->value;
+ if (!new_value.same_as(new_body[j][0])) {
+ changed[j] = true;
+ const ir::ReduceNode* r = new_value.as<ir::ReduceNode>();
+            CHECK(r != nullptr);
+            CHECK_EQ(new_body[j].size(), r->source.size());
+ for (size_t k = 0; k < new_body[j].size(); ++k) {
+ auto n = make_object<ir::ReduceNode>(*r);
+ n->value_index = static_cast<int>(k);
+ n->dtype = r->source[k].dtype();
+ new_body[j].Set(k, PrimExpr(n));
+ }
+ }
+ } else {
+ for (size_t k = 0; k < new_body[j].size(); ++k) {
+ PrimExpr new_value = ir::Inline(ir::EvaluateNode::make(new_body[j][k]),
+ stage->op, args, body).as<ir::EvaluateNode>()->value;
+ if (!new_value.same_as(new_body[j][k])) {
+ new_body[j].Set(k, new_value);
+ changed[j] = true;
+ }
+ }
+ }
+ } else if (hybrid) {
+ if (!new_hybrid_body[j].defined()) {
+ new_hybrid_body[j] = hybrid->body;
+ }
+ Stmt new_stmt = ir::Inline(new_hybrid_body[j], stage->op, args, body);
+ if (!new_stmt.same_as(new_hybrid_body[j])) {
+ new_hybrid_body[j] = new_stmt;
+ hybrid_changed[j] = true;
+ }
+ }
+ }
+ }
+ }
+ std::unordered_map<Tensor, Tensor> repl;
+ // rewrite dataflow
+ for (size_t i = 0; i < sch->stages.size(); ++i) {
+ Stage s = sch->stages[i];
+ if (s->attach_type == kInlinedAlready) continue;
+ if (new_body[i].size()) {
+ // Logics from ReplaceDataFlow
+ const ComputeOpNode* compute = sch->stages[i]->op.as<ComputeOpNode>();
+ CHECK(compute);
+ Operation op = s->op;
+ if (changed[i]) {
+ op = ComputeOpNode::make(
+ compute->name, compute->tag, compute->attrs,
+ compute->axis, new_body[i]);
+ }
+ op = op->ReplaceInputs(op, repl);
+ if (!op.same_as(s->op)) {
+ for (int idx = 0; idx < s->op->num_outputs(); ++idx) {
+ repl[s->op.output(idx)] = op.output(idx);
+ }
+ s->op = op;
+ }
+ } else if (hybrid_changed[i]) {
+ const HybridOpNode* hybrid = sch->stages[i]->op.as<HybridOpNode>();
+ CHECK(hybrid);
+ Operation op = HybridOpNode::make(
+ hybrid->name, hybrid->tag, hybrid->attrs, hybrid->inputs,
+ hybrid->outputs, new_hybrid_body[i]);
+ op = op->ReplaceInputs(op, repl);
+ for (int idx = 0; idx < s->op->num_outputs(); ++idx) {
+ repl[s->op.output(idx)] = op.output(idx);
+ }
+ s->op = op;
+ } else {
+ Operation op = s->op->ReplaceInputs(s->op, repl);
+ if (!op.same_as(s->op)) {
+ for (int j = 0; j < op->num_outputs(); ++j) {
+ repl[s->op.output(j)] = op.output(j);
+ }
+ s->op = op;
+ }
+ }
+ }
+}
+
+Schedule Schedule::normalize() {
+ Schedule sn = copy();
+ InjectInline(sn.operator->());
+ RebaseNonZeroMinLoop(sn);
+ return sn;
+}
+
+// Handle reduction factor.
+Array<Tensor> Schedule::rfactor(const Tensor& tensor,
+ const IterVar& axis,
+ int factor_axis) {
+ (*this)->InvalidateCache();
+ using ir::ReduceNode;
+ CHECK_EQ(axis->iter_type, kCommReduce)
+ << "Can only factor reduction axis";
+ Stage reduce_stage = operator[](tensor->op);
+ const ComputeOpNode* compute_op = reduce_stage->op.as<ComputeOpNode>();
+ CHECK(compute_op) << "Can only factor ComputeOp";
+ ArrayNode* leaf_vars = reduce_stage->leaf_iter_vars.CopyOnWrite();
+ {
+ size_t axis_pos = FindNodeRef(leaf_vars, axis);
+ CHECK_NE(axis_pos, leaf_vars->data.size())
+ << "Cannot find IterVar " << axis << " in leaf iter vars";
+ }
+ // Find touched reduction axis.
+ std::unordered_map<IterVar, int> touch_map;
+ touch_map[axis] = 1;
+ top::PassUpBitMaskOr(reduce_stage, &touch_map, true);
+ top::PassDownBitMaskOr(reduce_stage, &touch_map, true);
+ // skip reduction iteration.
+ std::unordered_set<IterVar> skip_bound_check;
+ // Verify normal axis are not touched.
+ for (IterVar iv : compute_op->axis) {
+ CHECK(!touch_map.count(iv))
+ << "Factor axis touches normal axis.";
+ skip_bound_check.insert(iv);
+ }
+ // get analyzer.
+ arith::Analyzer analyzer;
+ // Get the replace index
+ std::unordered_map<IterVar, Range> dom_map;
+ std::unordered_map<IterVar, PrimExpr> value_map;
+ for (IterVar iv : compute_op->reduce_axis) {
+ if (touch_map.count(iv)) {
+ dom_map[iv] = iv->dom;
+ } else {
+ skip_bound_check.insert(iv);
+ }
+ analyzer.Bind(iv->var, iv->dom);
+ }
+ top::PassDownDomain(reduce_stage, &dom_map, &analyzer, true);
+ for (IterVar iv : reduce_stage->leaf_iter_vars) {
+ if (touch_map.count(iv)) {
+ Range dom = dom_map.at(iv);
+ if (is_one(dom->extent)) {
+ value_map[iv] = dom->min;
+ } else {
+ value_map[iv] = iv->var;
+ }
+ }
+ }
+ top::PassUpIndex(reduce_stage, dom_map, &value_map, true);
+ std::vector<PrimExpr> predicates = MakeBoundCheck(
+ reduce_stage, dom_map, value_map, true, skip_bound_check);
+
+ // Get the factored op node.
+ const int factor_axis_pos = \
+ factor_axis >= 0 ? factor_axis : static_cast<int>(compute_op->axis.size() + 1) + factor_axis;
+ CHECK_LE(factor_axis_pos, compute_op->axis.size());
+ auto n = make_object<ComputeOpNode>();
+ n->name = compute_op->name + ".rf";
+ {
+    // axis replacement.
+ auto iv_node = make_object<IterVarNode>();
+ iv_node->dom = dom_map.at(axis);
+ CHECK(is_zero(iv_node->dom->min))
+ << "Can only factor reduction domain starting from 0";
+ iv_node->var = axis->var;
+ iv_node->iter_type = kDataPar;
+
+ const int size = compute_op->axis.size();
+ for (int idx = 0; idx < size; ++idx) {
+ if (factor_axis_pos == idx) {
+ n->axis.push_back(IterVar(iv_node));
+ }
+ n->axis.push_back(compute_op->axis[idx]);
+ }
+ if (factor_axis_pos == size) {
+ n->axis.push_back(IterVar(iv_node));
+ }
+ }
+ // predicate generation, copy not touched axis.
+ int idx = tensor->value_index;
+ const ReduceNode* reduce = compute_op->body[idx].as<ReduceNode>();
+ CHECK(reduce) << "Can only rfactor non-inline reductions";
+ predicates.push_back(reduce->condition);
+ PrimExpr predicate = likely(arith::ComputeReduce<ir::AndNode>(predicates, PrimExpr()));
+
+ std::unordered_map<const VarNode*, PrimExpr> vsub;
+
+ for (IterVar iv : compute_op->reduce_axis) {
+ if (!touch_map.count(iv)) {
+ n->reduce_axis.push_back(iv);
+ } else {
+ CHECK(value_map.count(iv));
+ PrimExpr index = value_map.at(iv);
+ vsub[iv->var.get()] = index;
+ }
+ }
+
+ // Copy touched axis.
+ for (IterVar iv : reduce_stage->leaf_iter_vars) {
+ if (touch_map.count(iv) && !iv.same_as(axis)) {
+ CHECK_EQ(iv->iter_type, kCommReduce);
+ auto ncpy = make_object<IterVarNode>(*iv.operator->());
+ ncpy->dom = dom_map.at(iv);
+ n->reduce_axis.push_back(IterVar(ncpy));
+ }
+ }
+ VarReplacer replacer(vsub);
+ Array<PrimExpr> new_source = ir::UpdateArray(reduce->source,
+ [&replacer] (const PrimExpr& e) { return replacer(e); });
+
+ PrimExpr new_pred = replacer(predicate);
+
+ std::vector<PrimExpr> body;
+ for (size_t idx = 0; idx < reduce->source.size(); ++idx) {
+ body.emplace_back(ReduceNode::make(reduce->combiner,
+ new_source,
+ n->reduce_axis,
+ new_pred,
+ idx));
+ }
+ n->body = Array<PrimExpr>(body);
+ // refresh relations, keep the un-touched relations.
+ Array<IterVarRelation> rels;
+ for (IterVarRelation rel : reduce_stage->relations) {
+ bool touched = false;
+ if (const SplitNode* r = rel.as<SplitNode>()) {
+ if (touch_map.count(r->parent)) touched = true;
+ } else if (const FuseNode* r = rel.as<FuseNode>()) {
+ if (touch_map.count(r->fused)) touched = true;
+ } else if (const RebaseNode* r = rel.as<RebaseNode>()) {
+ if (touch_map.count(r->parent)) touched = true;
+ } else {
+ LOG(FATAL) << "unknown relation type";
+ }
+ if (!touched) {
+ rels.push_back(rel);
+ }
+ }
+ // initialize the factored stage.
+ Operation factor_op(n);
+ ArrayNode* stages = (*this)->stages.CopyOnWrite();
+ size_t stage_pos = FindNodeRef(stages, reduce_stage);
+ Stage factor_stage = Stage(factor_op);
+ factor_stage->relations = rels;
+ CHECK_LT(stage_pos, stages->data.size());
+ stages->data.insert(stages->data.begin() + stage_pos,
+ factor_stage);
+ (*this)->stage_map.Set(factor_op, factor_stage);
+ factor_stage->group = reduce_stage->group;
+ if (factor_stage->group.defined()) {
+ ++factor_stage->group->num_child_stages;
+ }
+ // Replace the old reduction.
+ IterVar repl_red_axis = reduce_axis(
+ dom_map.at(axis), axis->var->name_hint + ".v");
+ Array<Tensor> factor_tensors;
+ Array<Tensor> old_tensors;
+ int size = factor_op->num_outputs();
+ for (int idx = 0; idx < size; ++idx) {
+ factor_tensors.push_back(factor_op.output(idx));
+ old_tensors.push_back(reduce_stage->op.output(idx));
+ }
+ Array<Tensor> repl_tensors = compute(old_tensors[0]->shape,
+ [&](const Array<Var>& i) {
+ Array<PrimExpr> indices;
+ const int idx_size = static_cast<int>(i.size());
+ for (int idx = 0; idx < idx_size; ++idx) {
+ if (factor_axis_pos == idx) {
+ indices.push_back(repl_red_axis->var);
+ }
+ indices.push_back(i[idx]);
+ }
+ if (factor_axis_pos == idx_size) {
+ indices.push_back(repl_red_axis->var);
+ }
+ Array<PrimExpr> factor_exprs;
+ for (int idx = 0; idx < size; ++idx) {
+ factor_exprs.push_back(factor_tensors[idx](indices));
+ }
+ Array<PrimExpr> reductions;
+ Array<IterVar> axis = {repl_red_axis};
+ PrimExpr cond = const_true();
+ for (int idx = 0; idx < size; ++idx) {
+ reductions.push_back(ReduceNode::make(reduce->combiner,
+ factor_exprs, axis, cond, idx));
+ }
+ return reductions;
+ }, reduce_stage->op->name + ".repl");
+
+ std::unordered_map<Tensor, Tensor> vmap;
+ std::unordered_map<Tensor, Tensor> rvmap;
+ for (int idx = 0; idx < size; ++idx) {
+ vmap[old_tensors[idx]] = repl_tensors[idx];
+ rvmap[repl_tensors[idx]] = old_tensors[idx];
+ }
+ ReplaceDataFlow((*this)->stages, &vmap, &rvmap);
+ // revamp the reduction stage.
+ reduce_stage->op = repl_tensors[0]->op;
+ reduce_stage->all_iter_vars = repl_tensors[0]->op->root_iter_vars();
+ reduce_stage->leaf_iter_vars = reduce_stage->all_iter_vars;
+ reduce_stage->relations = Array<IterVarRelation>();
+ return factor_tensors;
+}
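+
+// Rough usage sketch of rfactor (C, k and sch are hypothetical names):
+// split a reduction axis and factor the inner part into a separate stage.
+//
+//   IterVar ko, ki;
+//   sch[C->op].split(k, 16, &ko, &ki);
+//   Array<Tensor> CF = sch.rfactor(C, ki);
+//   // CF[0] reduces over ko and carries ki as a new data-parallel axis;
+//   // the original stage is rewritten to combine CF over a fresh ".v" axis.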
+} // namespace top
+} // namespace tvm
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file schedule_lang.cc
+ */
+#include <tvm/top/schedule.h>
+#include <tvm/top/operation.h>
+#include <unordered_set>
+#include "graph.h"
+
+namespace tvm {
+namespace top {
+
+// find first occurance location in leaf
+template<typename T>
+size_t FindNodeRef(ArrayNode* array_node, const T& v) {
+ const Object* n = v.get();
+ for (size_t i = 0; i < array_node->data.size(); ++i) {
+ if (array_node->data[i].get() == n) return i;
+ }
+ return array_node->data.size();
+}
+
+size_t FindLeafVar(ArrayNode* all_vars, ArrayNode* leaf_vars, const IterVar& v) {
+ size_t pos = FindNodeRef(leaf_vars, v);
+ if (pos < leaf_vars->data.size()) return pos;
+
+ if (FindNodeRef(all_vars, v) < all_vars->data.size()) {
+ LOG(FATAL) << "Operate on iter var " << v
+ << "that has already been split";
+ } else {
+ LOG(FATAL) << "Operate on iter var " << v
+ << "that is not part of the schedule";
+ }
+ return 0;
+}
+
+void Split(StageNode* self,
+ IterVar parent,
+ PrimExpr factor,
+ PrimExpr nparts,
+ IterVar* p_outer,
+ IterVar* p_inner) {
+ // Check if split is valid.
+ CHECK(parent->iter_type == kDataPar ||
+ parent->iter_type == kCommReduce ||
+ parent->iter_type == kOrdered)
+ << "Cannot split on " << IterVarType2String(parent->iter_type);
+ IterVar outer = IterVarNode::make(
+ Range(), parent->var.copy_with_suffix(".outer"), parent->iter_type);
+ IterVar inner = IterVarNode::make(
+ Range(), parent->var.copy_with_suffix(".inner"), parent->iter_type);
+ *p_outer = outer;
+ *p_inner = inner;
+ // The splits
+ ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
+ ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
+ size_t pos = FindLeafVar(all_vars, leaf_vars, parent);
+ self->relations.push_back(SplitNode::make(parent, outer, inner, factor, nparts));
+ // add vars to all vars
+ all_vars->data.push_back(outer);
+ all_vars->data.push_back(inner);
+ // replace the position.
+ leaf_vars->data.erase(leaf_vars->data.begin() + pos);
+ leaf_vars->data.insert(leaf_vars->data.begin() + pos, inner);
+ leaf_vars->data.insert(leaf_vars->data.begin() + pos, outer);
+}
+
+Stage::Stage(Operation op) {
+ auto n = make_object<StageNode>();
+ n->op = op;
+ n->origin_op = op;
+ n->all_iter_vars = op->root_iter_vars();
+ // remove opaque var from leaf.
+ Array<IterVar> clean;
+ for (IterVar iv : n->all_iter_vars) {
+ if (iv->iter_type != kOpaque) clean.push_back(iv);
+ }
+ if (clean.size() == n->all_iter_vars.size()) {
+ n->leaf_iter_vars = n->all_iter_vars;
+ } else {
+ n->leaf_iter_vars = clean;
+ }
+ data_ = std::move(n);
+}
+
+bool Stage::is_scheduled() const {
+ const StageNode* n = operator->();
+ return !(n->relations.empty() && n->attach_type == kGroupRoot &&
+ n->all_iter_vars.same_as(n->leaf_iter_vars));
+}
+
+Stage Stage::GetAttachSpec() const {
+ Stage attach_spec = *this;
+ while (attach_spec->attach_type == kGroupRoot &&
+ attach_spec->group.defined()) {
+ attach_spec = attach_spec->group;
+ }
+ return attach_spec;
+}
+
+Stage& Stage::set_scope(std::string scope) { // NOLINT(*)
+ (*this)->scope = scope;
+ return *this;
+}
+
+Stage& Stage::compute_at(Stage parent, IterVar scope) { // NOLINT(*)
+ CHECK_NE((*this)->attach_type, kScanUpdate)
+ << "Cannot specify compute_at for scan updates";
+ // Group constraint checking.
+ Stage group = (*this)->group;
+ if (group.defined()) {
+ Stage pg = parent->group;
+ while (pg.defined() && !pg.same_as(group)) {
+ pg = pg->group;
+ }
+ CHECK(pg.same_as(group))
+ << "Can only assign compute_at to stages within the same group";
+ }
+
+ (*this)->attach_type = kScope;
+ (*this)->attach_ivar = scope;
+ (*this)->attach_stage = parent;
+ bool found = false;
+ for (size_t i = 0; i < parent->leaf_iter_vars.size(); ++i) {
+ if (scope == parent->leaf_iter_vars[i]) {
+ found = true; break;
+ }
+ }
+ CHECK(found)
+ << "Cannot find the axis " << scope
+ << " in parent's leaf_iter_vars"
+ << " parent=" << parent;
+ return *this;
+}
+
+Stage& Stage::compute_inline() { // NOLINT(*)
+ CHECK_NE((*this)->attach_type, kScanUpdate)
+ << "Cannot specify compute_at for scan updates";
+ (*this)->attach_type = kInline;
+ return *this;
+}
+
+Stage& Stage::compute_root() { // NOLINT(*)
+ CHECK_NE((*this)->attach_type, kScanUpdate)
+ << "Cannot specify compute_at for scan updates";
+ (*this)->attach_type = kGroupRoot;
+ return *this;
+}
+
+Stage& Stage::bind(IterVar ivar, IterVar thread_ivar) { // NOLINT(*)
+ StageNode* self = operator->();
+ CHECK(ivar->iter_type == kDataPar ||
+ ivar->iter_type == kCommReduce)
+ << "Cannot bind " << IterVarType2String(ivar->iter_type) << " to thread";
+ CHECK(thread_ivar->iter_type == kThreadIndex)
+ << "Cannot rebase by " << IterVarType2String(ivar->iter_type)
+ << ", only thread axis is allowed so far";
+ ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
+ ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
+ FindLeafVar(all_vars, leaf_vars, ivar);
+
+ auto it = self->iter_var_attrs.find(ivar);
+ ObjectPtr<IterVarAttrNode> n;
+ if (it != self->iter_var_attrs.end()) {
+ n = make_object<IterVarAttrNode>(*(*it).second.operator->());
+ if (n->bind_thread.defined() &&
+ !n->bind_thread.same_as(thread_ivar)) {
+ LOG(WARNING) << "Axis " << ivar
+ << " is already bind to another thread " << n->bind_thread;
+ }
+ } else {
+ n = make_object<IterVarAttrNode>();
+ }
+ n->bind_thread = thread_ivar;
+ self->iter_var_attrs.Set(ivar, IterVarAttr(n));
+ return *this;
+}
+
+Stage& Stage::env_threads(Array<IterVar> threads) {
+ StageNode* self = operator->();
+ CHECK(self->op.defined() && self->op.as<ScanOpNode>())
+ << "env_threads is only valid for composite ops such as ScanOp";
+ CHECK_EQ(self->env_threads.size(), 0U)
+ << "Already set env_threads";
+ ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
+ ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
+ std::vector<ObjectRef> temp;
+ for (IterVar iv : threads) {
+ temp.push_back(iv);
+ }
+ leaf_vars->data.insert(
+ leaf_vars->data.begin(), temp.begin(), temp.end());
+ all_vars->data.insert(
+ all_vars->data.end(), temp.begin(), temp.end());
+ self->env_threads = threads;
+ return *this;
+}
+
+Stage& Stage::set_store_predicate(PrimExpr predicate) {
+ StageNode* self = operator->();
+ self->store_predicate = predicate;
+ return *this;
+}
+
+Stage& Stage::split(
+ IterVar parent, PrimExpr factor, IterVar* p_outer, IterVar* p_inner) { // NOLINT(*)
+ Split(operator->(), parent, factor, PrimExpr(), p_outer, p_inner);
+ return *this;
+}
+
+Stage& Stage::split_by_nparts(
+ IterVar parent, PrimExpr nparts, IterVar* p_outer, IterVar* p_inner) { // NOLINT(*)
+ Split(operator->(), parent, PrimExpr(), nparts, p_outer, p_inner);
+ return *this;
+}
+
+Stage& Stage::fuse(IterVar outer, IterVar inner, IterVar* p_target) { // NOLINT(*)
+ StageNode* self = operator->();
+ CHECK(outer->iter_type == kDataPar ||
+ outer->iter_type == kCommReduce ||
+ outer->iter_type == kOrdered)
+ << "Cannot fuse " << IterVarType2String(outer->iter_type);
+ CHECK(inner->iter_type == kDataPar ||
+ inner->iter_type == kCommReduce ||
+ inner->iter_type == kOrdered)
+ << "Cannot fuse " << IterVarType2String(inner->iter_type);
+
+ IterVarType iter_type = outer->iter_type;
+ if (inner->iter_type > iter_type) iter_type = inner->iter_type;
+ std::string fused_name =
+ outer->var->name_hint + "." + inner->var->name_hint + ".fused";
+
+ IterVar fused = IterVarNode::make(
+ Range(), Var(fused_name, outer->var.dtype()), iter_type);
+
+ ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
+ ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
+
+ size_t pos_inner = FindLeafVar(all_vars, leaf_vars, inner);
+ size_t pos_outer = FindLeafVar(all_vars, leaf_vars, outer);
+ if (pos_inner + 1 == pos_outer) {
+ std::swap(outer, inner);
+ std::swap(pos_inner, pos_outer);
+ }
+ self->relations.push_back(FuseNode::make(outer, inner, fused));
+ all_vars->data.push_back(fused);
+ CHECK_EQ(pos_inner, pos_outer + 1)
+ << "Can only fuse iterations that are consecutive between each other";
+ leaf_vars->data.erase(leaf_vars->data.begin() + pos_outer,
+ leaf_vars->data.begin() + pos_inner + 1);
+ leaf_vars->data.insert(leaf_vars->data.begin() + pos_outer,
+ fused);
+ *p_target = fused;
+ return *this;
+}
+
+Stage& Stage::fuse(const Array<IterVar>& axes, IterVar* p_target) { // NOLINT(*)
+ if (axes.size() != 0) {
+ IterVar fused = axes[0];
+ for (size_t i = 1; i < axes.size(); ++i) {
+ this->fuse(fused, axes[i], &fused);
+ }
+ *p_target = std::move(fused);
+ } else {
+ StageNode* self = operator->();
+ // special handle fuse empty array.
+ // insert at the outer most loop
+ IterVar singleton = IterVarNode::make(
+ Range::make_by_min_extent(0, 1),
+ Var("singleton", DataType::Int(32)), kDataPar);
+ self->relations.push_back(SingletonNode::make(singleton));
+ ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
+ ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
+ all_vars->data.push_back(singleton);
+ leaf_vars->data.insert(leaf_vars->data.begin(), singleton);
+ *p_target = singleton;
+ }
+ return *this;
+}
+
+Stage& Stage::reorder(const Array<IterVar>& order) { // NOLINT(*)
+ std::unordered_set<IterVar> seen_var;
+ StageNode* self = operator->();
+ for (IterVar iv : order) {
+ CHECK(iv->iter_type == kDataPar ||
+ iv->iter_type == kCommReduce ||
+ iv->iter_type == kThreadIndex)
+ << "Cannot reorder IterVar("
+ << IterVarType2String(iv->iter_type) << ")";
+
+ CHECK_EQ(seen_var.count(iv), 0)
+ << "Same axis can not appear more than once " << iv;
+ seen_var.insert(iv);
+ }
+ ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
+ ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
+ std::vector<size_t> pos;
+
+ for (size_t i = 0; i < order.size(); ++i) {
+ pos.push_back(FindLeafVar(all_vars, leaf_vars, order[i]));
+ }
+ std::vector<ObjectRef> temp;
+ for (size_t i = 0; i < pos.size(); ++i) {
+ temp.emplace_back(leaf_vars->data[pos[i]]);
+ }
+ std::sort(pos.begin(), pos.end());
+ for (size_t i = 0; i < pos.size(); ++i) {
+ leaf_vars->data[pos[i]] = temp[i];
+ }
+ return *this;
+}
+
+Stage& Stage::tile(IterVar x_parent, IterVar y_parent,
+ PrimExpr x_factor, PrimExpr y_factor,
+ IterVar* p_x_outer, IterVar* p_y_outer,
+ IterVar* p_x_inner, IterVar* p_y_inner) {
+ split(x_parent, x_factor, p_x_outer, p_x_inner);
+ split(y_parent, y_factor, p_y_outer, p_y_inner);
+ reorder(Array<IterVar>({*p_x_outer, *p_y_outer, *p_x_inner, *p_y_inner}));
+ return *this;
+}
+
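+// Rough equivalence sketch (x, y and stage are hypothetical names): tile is
+// two splits followed by a reorder, so
+//
+//   IterVar xo, yo, xi, yi;
+//   stage.tile(x, y, 8, 8, &xo, &yo, &xi, &yi);
+//
+// behaves the same as
+//
+//   stage.split(x, 8, &xo, &xi);
+//   stage.split(y, 8, &yo, &yi);
+//   stage.reorder({xo, yo, xi, yi});
+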
+template<typename FUpdate>
+inline void UpdateIterVarAttr(StageNode* self,
+ IterVar var,
+ FUpdate fupdate,
+ bool need_leaf = true) {
+ if (need_leaf) {
+ ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
+ ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
+ FindLeafVar(all_vars, leaf_vars, var);
+ }
+ auto it = self->iter_var_attrs.find(var);
+ ObjectPtr<IterVarAttrNode> n;
+ if (it != self->iter_var_attrs.end()) {
+ n = make_object<IterVarAttrNode>(*(*it).second.operator->());
+ } else {
+ n = make_object<IterVarAttrNode>();
+ }
+ fupdate(n.get());
+ self->iter_var_attrs.Set(var, IterVarAttr(n));
+}
+
+inline void SetAttrIterType(StageNode* self, IterVar var, IterVarType iter_type) {
+ UpdateIterVarAttr(self, var, [iter_type](IterVarAttrNode* n) {
+ n->iter_type = iter_type;
+ });
+}
+
+Stage& Stage::vectorize(IterVar var) { // NOLINT(*)
+ CHECK(var->iter_type == kDataPar ||
+ var->iter_type == kOpaque ||
+ var->iter_type == kUnrolled ||
+ var->iter_type == kVectorized ||
+ var->iter_type == kTensorized ||
+ var->iter_type == kParallelized)
+ << "Cannot vectorize on " << IterVarType2String(var->iter_type);
+ SetAttrIterType(operator->(), var, kVectorized);
+ return *this;
+}
+
+Stage& Stage::tensorize(IterVar var, TensorIntrin f) { // NOLINT(*)
+ UpdateIterVarAttr(operator->(), var, [f](IterVarAttrNode* n) {
+ n->iter_type = kTensorized;
+ n->tensor_intrin = f;
+ });
+ return *this;
+}
+
+Stage& Stage::unroll(IterVar var) { // NOLINT(*)
+ SetAttrIterType(operator->(), var, kUnrolled);
+ return *this;
+}
+
+Stage& Stage::parallel(IterVar var) { // NOLINT(*)
+ SetAttrIterType(operator->(), var, kParallelized);
+ return *this;
+}
+
+Stage& Stage::pragma(IterVar var,
+ const std::string& pragma_type,
+ const PrimExpr& pragma_value) { // NOLINT(*)
+ if (pragma_type == "unroll") {
+ this->unroll(var);
+ } else if (pragma_type == "vectorize") {
+ this->vectorize(var);
+ } else {
+ UpdateIterVarAttr(
+ operator->(), var, [pragma_type, pragma_value](IterVarAttrNode* n) {
+ n->pragma_keys.push_back(ir::StringImmNode::make(pragma_type));
+ n->pragma_values.push_back(pragma_value);
+ });
+ }
+ return *this;
+}
+
+Stage& Stage::prefetch(const Tensor &tensor, IterVar var, PrimExpr offset) {
+ StageNode *self = operator->();
+ ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
+ ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
+ FindLeafVar(all_vars, leaf_vars, var);
+ auto it = self->iter_var_attrs.find(var);
+ ObjectPtr<IterVarAttrNode> n;
+ if (it != self->iter_var_attrs.end()) {
+ n = make_object<IterVarAttrNode>(*(*it).second.operator->());
+ } else {
+ n = make_object<IterVarAttrNode>();
+ }
+ n->prefetch_data.push_back(tensor);
+ n->prefetch_offset.push_back(offset);
+ self->iter_var_attrs.Set(var, IterVarAttr(n));
+ return *this;
+}
+
+Stage& Stage::storage_align(IterVar axis, int factor, int offset) {
+ StageNode *self = operator->();
+ UpdateIterVarAttr(self, axis, [factor, offset](IterVarAttrNode* n) {
+ n->dim_align_factor = factor;
+ n->dim_align_offset = offset;
+ }, false);
+ return *this;
+}
+
+Stage& Stage::double_buffer() {
+ StageNode *self = operator->();
+ CHECK(!self->is_output) << "Cannot apply double buffer on output";
+ self->double_buffer = true;
+ return *this;
+}
+
+Stage& Stage::opengl() {
+ CHECK(!is_scheduled()) << "Must be a fresh schedule";
+ StageNode *self = operator->();
+
+ auto all_iter_vars = self->all_iter_vars; // curr version of all_iter_vars
+ CHECK(!all_iter_vars.empty()) << "At least one iter var";
+
+ // Fuse all data parallel dimensions to 1.
+ IterVar fused = all_iter_vars[0];
+ for (size_t i = 1; i != all_iter_vars.size(); ++i) {
+ auto iter_var = all_iter_vars[i];
+ switch (iter_var->iter_type) {
+ case IterVarType::kDataPar: {
+ fuse(fused, all_iter_vars[i], &fused);
+ break;
+ }
+ case IterVarType::kThreadIndex: {
+ LOG(ERROR) << "A fresh schedule shouldn't have thread index iter var";
+ break;
+ }
+ case IterVarType::kCommReduce:
+ case IterVarType::kOrdered:
+ case IterVarType::kOpaque: {
+ break;
+ }
+ default: {
+ LOG(ERROR) << "Invalid iter var type "
+ << IterVarType2String(iter_var->iter_type);
+ break;
+ }
+ }
+ }
+
+ // Bind the only dimension to threadIdx.x.
+ bind(fused, thread_axis(Range(nullptr), "threadIdx.x"));
+
+ // Mark this stage as OpenGL.
+ (*this)->is_opengl = true;
+
+ return *this;
+}
+
+Stage CopyStage(const Stage& s) {
+ ObjectPtr<StageNode> n =
+ make_object<StageNode>(*s.operator->());
+ return Stage(n);
+}
+
+Schedule Schedule::copy() const {
+ // map of stages.
+ const ScheduleNode* self = operator->();
+ std::unordered_map<Stage, Stage, ObjectHash, ObjectEqual> smap;
+ ObjectPtr<ScheduleNode> n = make_object<ScheduleNode>();
+ n->outputs = self->outputs;
+ // Copy the stages.
+ for (Stage s : self->stages) {
+ Stage scopy = CopyStage(s);
+ smap[s] = scopy;
+ n->stages.push_back(scopy);
+ }
+ for (Stage g : self->groups) {
+ Stage gcopy = CopyStage(g);
+ smap[g] = gcopy;
+ n->groups.push_back(gcopy);
+ }
+ // Remaps the reference relations.
+ for (auto kv : self->stage_map) {
+ n->stage_map.Set(kv.first, smap.at(kv.second));
+ }
+ for (Stage s : n->stages) {
+ if (s->attach_stage.defined()) {
+ CHECK(smap.find(s->attach_stage) != smap.end())
+ << s->attach_stage << " not found in " << (*this);
+ s->attach_stage = smap.at(s->attach_stage);
+ }
+ if (s->group.defined()) {
+ CHECK(smap.find(s->group) != smap.end())
+ << s->group << " not found in " << (*this);
+ s->group = smap.at(s->group);
+ }
+ }
+ for (Stage s : n->groups) {
+ if (s->attach_stage.defined()) {
+ CHECK(smap.find(s->attach_stage) != smap.end())
+ << s->attach_stage << " not found in " << (*this);
+ s->attach_stage = smap.at(s->attach_stage);
+ }
+ if (s->group.defined()) {
+ CHECK(smap.find(s->group) != smap.end())
+ << s->group << " not found in " << (*this);
+ s->group = smap.at(s->group);
+ }
+ }
+ return Schedule(n);
+}
+
+Stage Schedule::operator[](const Operation& op) {
+ auto it = (*this)->stage_map.find(op);
+ CHECK(it != (*this)->stage_map.end())
+ << "Cannot find Stage for operator " << op
+ << " in the schedule";
+ return (*it).second;
+}
+
+Stage LeastCommonAncestor(Stage g1, Stage g2) {
+ if (!g1.defined()) return g1;
+ if (!g2.defined()) return g2;
+ if (g1.same_as(g2)) return g1;
+ Stage g = g1;
+ while (g.defined()) {
+ if (g.same_as(g2)) return g2;
+ g = g->group;
+ }
+ g = g2;
+ while (g.defined()) {
+ if (g.same_as(g1)) return g1;
+ g = g->group;
+ }
+ return g;
+}
+
+Array<Tensor> RemapTensor(ScheduleNode* self,
+ const Array<Tensor>& arr) {
+ self->InitCache();
+ const auto& op2stage_cache = self->op2stage_cache_;
+ Array<Tensor> ret;
+ for (Tensor t : arr) {
+ if (!op2stage_cache.count(t->op.get())) {
+ CHECK(self->stage_map.count(t->op))
+ << "Given tensor is not in the schedule plan";
+ t = self->stage_map[t->op]->op.output(t->value_index);
+ }
+ ret.push_back(t);
+ }
+ return ret;
+}
+
+// Group the schedule stages.
+Stage Schedule::create_group(const Array<Tensor>& outputs,
+ const Array<Tensor>& inputs,
+ bool include_inputs) {
+ ScheduleNode* self = operator->();
+ self->InitCache();
+ const auto& op2stage_cache = self->op2stage_cache_;
+ // Get the ops.
+ Array<Operation> ops = top::GetSubGraph(
+ RemapTensor(self, outputs),
+ RemapTensor(self, inputs),
+ include_inputs);
+  // Local counter entry.
+  // Automatically initialized to 0 during creation.
+ struct Entry {
+ int count{0};
+ };
+ // Map of group->touched counter
+ std::unordered_map<Stage, Entry, ObjectHash, ObjectEqual> counter;
+  // The parent group.
+ Stage parent_group;
+ // Detect common parent and child.
+ for (size_t i = 0; i < ops.size(); ++i) {
+ Operation op = ops[i];
+ auto it = op2stage_cache.find(op.get());
+ CHECK(it != op2stage_cache.end());
+ Stage op_group = it->second->group;
+ if (i == 0) {
+ parent_group = op_group;
+ } else {
+ parent_group = LeastCommonAncestor(parent_group, op_group);
+ }
+ if (op_group.defined()) {
+ ++counter[op_group].count;
+ }
+ }
+ // Create the new group stage.
+ Stage gstage(make_object<StageNode>());
+ gstage->group = parent_group;
+ if (parent_group.defined()) {
+ ++parent_group->num_child_stages;
+ }
+  // Propagate the counter statistics upward by checking
+  // whether each subgroup is full.
+ std::vector<Stage> stack;
+ for (auto &kv : counter) {
+ if (!kv.first.same_as(parent_group)) {
+ if (kv.first->num_child_stages == kv.second.count) {
+ stack.push_back(kv.first);
+ }
+ }
+ }
+ while (!stack.empty()) {
+ Stage g = stack.back();
+ stack.pop_back();
+ if (g->group.defined() && !g->group.same_as(parent_group)) {
+ Entry& e = counter[g->group];
+ ++e.count;
+ if (e.count == g->group->num_child_stages) {
+ stack.push_back(g->group);
+ }
+ }
+ }
+  // Verify and remap the subgroups.
+ for (auto &kv : counter) {
+ if (kv.first.same_as(parent_group)) continue;
+ CHECK_EQ(kv.first->num_child_stages, kv.second.count)
+ << "Trying to group region that intersect with an already existed group";
+ if (kv.first->group.same_as(parent_group)) {
+ Stage s = kv.first;
+ s->group = gstage;
+ ++gstage->num_child_stages;
+ if (parent_group.defined()) {
+ --parent_group->num_child_stages;
+ }
+ }
+ }
+ // Remap the group of op stages.
+ for (Operation op : ops) {
+ auto it = op2stage_cache.find(op.get());
+ CHECK(it != op2stage_cache.end());
+ Stage s = it->second;
+ if (s->group.same_as(parent_group)) {
+ s->group = gstage;
+ ++gstage->num_child_stages;
+ if (parent_group.defined()) {
+ --parent_group->num_child_stages;
+ }
+ }
+ }
+  // Correct the attach points to keep everything inside the group.
+ for (Operation op : ops) {
+ auto it = op2stage_cache.find(op.get());
+ CHECK(it != op2stage_cache.end());
+ Stage s = it->second;
+ if (s->attach_type == kScope) {
+ Stage cg = LeastCommonAncestor(s->attach_stage->group, gstage);
+ if (!cg.same_as(gstage)) {
+        LOG(WARNING) << "create_group invalidates some previous compute_at relations; "
+                     << "resetting the stage to be computed inside the group";
+ s.compute_root();
+ }
+ }
+ }
+
+ self->groups.push_back(gstage);
+ return gstage;
+}
+
+void ScheduleNode::InvalidateCache() {
+ op2stage_cache_.clear();
+}
+
+void ScheduleNode::InitCache() {
+ if (op2stage_cache_.size() == stages.size()) return;
+ InvalidateCache();
+ for (Stage s : stages) {
+ if (s->op.defined()) {
+ op2stage_cache_[s->op.get()] = s;
+ }
+ }
+ CHECK_EQ(op2stage_cache_.size(), stages.size());
+}
+
+bool ScheduleNode::Contain(const Operation& op) const {
+ return stage_map.find(op) != stage_map.end();
+}
+
+Schedule ScheduleNode::make(Array<Operation> ops) {
+ auto n = make_object<ScheduleNode>();
+ Schedule sch(n);
+ n->outputs = ops;
+ auto g = top::CreateReadGraph(n->outputs);
+ Array<Operation> post_order = top::PostDFSOrder(n->outputs, g);
+ // output set.
+ std::unordered_set<Operation> output_set;
+ for (Operation x : ops) {
+ output_set.insert(x);
+ }
+ for (Operation op : post_order) {
+ Stage stage(op);
+ stage->is_output = output_set.count(op) != 0;
+ n->stages.push_back(stage);
+ n->stage_map.Set(op, stage);
+ // mark scan updates.
+ if (const ScanOpNode* scan = op.as<ScanOpNode>()) {
+ Array<Tensor> inputs;
+ for (Tensor t : scan->state_placeholder) {
+ inputs.push_back(t);
+ }
+ for (Tensor t : scan->inputs) {
+ inputs.push_back(t);
+ }
+ // Create the scan group.
+ Stage scan_group = sch.create_group(scan->update, inputs, false);
+ scan_group->attach_type = kScanUpdate;
+ scan_group->attach_stage = stage;
+
+ for (size_t i = 0; i < scan->update.size(); ++i) {
+ Stage s = n->stage_map[scan->update[i]->op];
+ CHECK(scan_group.same_as(s->group));
+ }
+ }
+ }
+ return sch;
+}
+
+IterVarRelation SplitNode::make(IterVar parent,
+ IterVar outer,
+ IterVar inner,
+ PrimExpr factor,
+ PrimExpr nparts) {
+ auto n = make_object<SplitNode>();
+ n->parent = parent;
+ n->outer = outer;
+ n->inner = inner;
+ n->factor = factor;
+ n->nparts = nparts;
+ return IterVarRelation(n);
+}
+
+IterVarRelation FuseNode::make(
+ IterVar outer, IterVar inner, IterVar fused) {
+ auto n = make_object<FuseNode>();
+ n->outer = outer;
+ n->inner = inner;
+ n->fused = fused;
+ return IterVarRelation(n);
+}
+
+IterVarRelation RebaseNode::make(IterVar parent, IterVar rebased) {
+ auto n = make_object<RebaseNode>();
+ n->parent = parent;
+ n->rebased = rebased;
+ return IterVarRelation(n);
+}
+
+IterVarRelation SingletonNode::make(IterVar iter) {
+ auto n = make_object<SingletonNode>();
+ n->iter = iter;
+ return IterVarRelation(n);
+}
+
+TVM_REGISTER_NODE_TYPE(StageNode);
+TVM_REGISTER_NODE_TYPE(IterVarAttrNode);
+TVM_REGISTER_NODE_TYPE(SplitNode);
+TVM_REGISTER_NODE_TYPE(FuseNode);
+TVM_REGISTER_NODE_TYPE(RebaseNode);
+TVM_REGISTER_NODE_TYPE(SingletonNode);
+TVM_REGISTER_NODE_TYPE(ScheduleNode);
+
+// Printer
+TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
+.set_dispatch<StageNode>([](const ObjectRef& node, NodePrinter* p) {
+ auto* op = static_cast<const StageNode*>(node.get());
+ if (op->op.defined()) {
+ p->stream << "stage(" << op->origin_op->name << ", " << op << ")";
+ } else {
+ p->stream << "group-stage(" << op << ")";
+ }
+})
+.set_dispatch<IterVarAttrNode>([](const ObjectRef& node, NodePrinter* p) {
+ auto* op = static_cast<const IterVarAttrNode*>(node.get());
+ p->stream << IterVarType2String(op->iter_type);
+})
+.set_dispatch<SplitNode>([](const ObjectRef& node, NodePrinter* p) {
+ auto* op = static_cast<const SplitNode*>(node.get());
+ p->stream << "split(parent=";
+ p->Print(op->parent);
+ p->stream << ", outer=";
+ p->Print(op->outer);
+ p->stream << ", inner=";
+ p->Print(op->inner);
+ p->stream << ')';
+})
+.set_dispatch<FuseNode>([](const ObjectRef& node, NodePrinter* p) {
+ auto* op = static_cast<const FuseNode*>(node.get());
+ p->stream << "split(";
+ p->stream << "outer=";
+ p->Print(op->outer);
+ p->stream << ", inner=";
+ p->Print(op->inner);
+ p->stream << ", fused=";
+ p->Print(op->fused);
+ p->stream << ')';
+})
+.set_dispatch<RebaseNode>([](const ObjectRef& node, NodePrinter* p) {
+ auto* op = static_cast<const RebaseNode*>(node.get());
+ p->stream << "rebase(";
+ p->stream << "parent=";
+ p->Print(op->parent);
+ p->stream << ", rebased=";
+ p->Print(op->rebased);
+ p->stream << ')';
+})
+.set_dispatch<SingletonNode>([](const ObjectRef& node, NodePrinter* p) {
+ auto* op = static_cast<const SingletonNode*>(node.get());
+ p->stream << "singleton(";
+ p->Print(op->iter);
+ p->stream << ')';
+})
+.set_dispatch<ScheduleNode>([](const ObjectRef& node, NodePrinter* p) {
+ auto* op = static_cast<const ScheduleNode*>(node.get());
+ p->stream << "schedule(" << op << ")";
+ });
+} // namespace top
+} // namespace tvm
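After the move, schedule construction lives under tvm::top. A minimal usage sketch of the relocated API; tensor names, shapes, and the direct use of ScheduleNode::make are illustrative, not part of the patch:

  using namespace tvm;
  Var m("m"), n("n");
  top::Tensor A = top::placeholder({m, n}, DataType::Float(32), "A");
  top::Tensor B = top::compute({m, n}, [&](Var i, Var j) { return A[i][j] + 1; }, "B");
  top::Schedule s = top::ScheduleNode::make({B->op});
  // Put B's stage into its own sub-group, excluding the input placeholder.
  top::Stage g = s.create_group({B}, {A}, /*include_inputs=*/false);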
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file schedule_ops.cc
+ */
+#include <tvm/ir.h>
+#include <tvm/ir_pass.h>
+#include <tvm/ir_functor_ext.h>
+#include <tvm/top/operation.h>
+#include <tvm/top/schedule_pass.h>
+#include <utility>
+#include <unordered_map>
+#include <unordered_set>
+#include "graph.h"
+#include "../operation/op_util.h"
+#include "../../pass/ir_util.h"
+
+namespace tvm {
+namespace top {
+
+using namespace ir;
+
+Stmt MakePipeline(const Stage& s,
+ const std::unordered_map<IterVar, Range>& dom_map,
+ Stmt consumer,
+ bool debug_keep_trivial_loop) {
+ Stmt producer = s->op->BuildProvide(s, dom_map, debug_keep_trivial_loop);
+ if (producer.defined()) {
+ producer = ProducerConsumerNode::make(s->op, true, producer);
+ }
+ if (s->double_buffer) {
+ producer = AttrStmtNode::make(
+ s->op, ir::attr::double_buffer_scope, 1, producer);
+ }
+ Stmt pipeline = producer;
+
+ if (consumer.defined() && !is_no_op(consumer)) {
+ consumer = ProducerConsumerNode::make(s->op, false, consumer);
+ pipeline = SeqStmt({producer, consumer});
+ }
+ pipeline = s->op->BuildRealize(s, dom_map, pipeline);
+ // use attribute to mark scope of the operation.
+ pipeline = AttrStmtNode::make(
+ s->op, ir::attr::realize_scope,
+ StringImmNode::make(s->scope),
+ pipeline);
+
+ if (s->is_opengl) {
+ pipeline = AttrStmtNode::make(
+ s->op, ir::attr::opengl_stage_scope, StringImmNode::make(""), pipeline);
+ }
+ return pipeline;
+}
+
+// inject the operator's realization on the stmt.
+class InjectAttach : public StmtMutator {
+ public:
+ InjectAttach(const Stage& stage,
+ const Stage& attach_spec,
+ const std::unordered_map<IterVar, Range>& dom_map,
+ bool debug_keep_trivial_loop)
+ : stage_(stage), attach_spec_(attach_spec), dom_map_(dom_map),
+ debug_keep_trivial_loop_(debug_keep_trivial_loop) {}
+
+ Stmt VisitStmt(const Stmt& input_stmt) final {
+ CHECK(input_stmt.defined());
+ auto stmt = StmtMutator::VisitStmt(input_stmt);
+ const AttrStmtNode* op = stmt.as<AttrStmtNode>();
+ if (op != nullptr &&
+ op->attr_key == attr::loop_scope) {
+ if (attach_spec_->attach_type == kScope &&
+ op->node == attach_spec_->attach_ivar) {
+ CHECK(!found_attach)
+ << "Find IterVar" << attach_spec_->attach_ivar
+ << " in multiple places in the IR";
+ found_attach = true;
+ stmt = AttrStmtNode::make(
+ op->node, op->attr_key, op->value,
+ MakePipeline(stage_, dom_map_, op->body, debug_keep_trivial_loop_));
+ }
+ }
+ return stmt;
+ }
+ // whether attach point is found
+ bool found_attach{false};
+
+ private:
+ // The stage.
+ const Stage& stage_;
+ // The attach spec, may not contain op.
+ const Stage& attach_spec_;
+ // domain map
+ const std::unordered_map<IterVar, Range>& dom_map_;
+  // Whether to keep trivial loops with extent of 1 during lowering.
+  // This is a debug feature for dataflow/axis analysis.
+ bool debug_keep_trivial_loop_;
+};
+
+// Inject the realization of a scan's init or update step into the stmt.
+class InjectScanStep : public StmtMutator {
+ public:
+ InjectScanStep(const Stage& stage,
+ const Operation& scan_op,
+ const std::unordered_map<IterVar, Range>& dom_map,
+ bool is_init,
+ bool debug_keep_trivial_loop)
+ : stage_(stage), scan_op_(scan_op),
+ dom_map_(dom_map), is_init_(is_init), debug_keep_trivial_loop_(debug_keep_trivial_loop) {}
+
+ Stmt VisitStmt(const Stmt& input_stmt) final {
+ CHECK(input_stmt.defined());
+ auto stmt = StmtMutator::VisitStmt(input_stmt);
+ // update
+ const AttrStmtNode* op = stmt.as<AttrStmtNode>();
+ if (op != nullptr &&
+ ((op->attr_key == attr::scan_update_scope && !is_init_) ||
+ (op->attr_key == attr::scan_init_scope && is_init_))) {
+ if (op->node.same_as(scan_op_)) {
+ found_attach = true;
+ stmt = AttrStmtNode::make(
+ op->node, op->attr_key, op->value,
+ MakePipeline(stage_, dom_map_, op->body, debug_keep_trivial_loop_));
+ }
+ }
+ return stmt;
+ }
+
+ // whether attach point is found
+ bool found_attach{false};
+
+ private:
+  // the stage to be injected and the scan operation it belongs to
+ const Stage& stage_;
+ const Operation& scan_op_;
+ // domain map
+ const std::unordered_map<IterVar, Range>& dom_map_;
+ // whether it is init.
+ bool is_init_;
+  // Whether to keep trivial loops with extent of 1 during lowering.
+  // This is a debug feature for dataflow/axis analysis.
+ bool debug_keep_trivial_loop_;
+};
+
+// Postprocessing of the scheduled ops:
+// replace the init and update expressions with the scan's buffer.
+class SchedulePostProc : public StmtExprMutator {
+ public:
+ Stmt VisitStmt_(const ProducerConsumerNode* op) final {
+ auto it = replace_op_.find(op->func.get());
+ if (it != replace_op_.end()) {
+ Stmt body = this->VisitStmt(op->body);
+ if (it->second.defined()) {
+ return ProducerConsumerNode::make(
+ it->second, op->is_producer, body);
+ } else {
+ return body;
+ }
+ } else {
+ return StmtExprMutator::VisitStmt_(op);
+ }
+ }
+ Stmt VisitStmt_(const LetStmtNode* op) final {
+ if (!HasSideEffect(op->value)) {
+ var_value_[op->var.get()] = this->VisitExpr(op->value);
+ return this->VisitStmt(op->body);
+ } else {
+ return StmtExprMutator::VisitStmt_(op);
+ }
+ }
+
+ Stmt VisitStmt_(const AttrStmtNode* op) final {
+ if (op->attr_key == attr::loop_scope ||
+ op->attr_key == attr::scan_init_scope) {
+ return this->VisitStmt(op->body);
+ } else if (op->attr_key == attr::scan_update_scope) {
+ const ScanOpNode* scan = op->node.as<ScanOpNode>();
+ CHECK(scan);
+ var_value_[scan->scan_axis->var.get()] = op->value;
+ return this->VisitStmt(op->body);
+ } else if (op->attr_key == attr::thread_extent) {
+ // delete duplicated thread extent attr
+ auto it = thread_extent_scope_.find(op->node.get());
+ if (it != thread_extent_scope_.end()) {
+ CHECK(is_zero(ir::Simplify(it->second - op->value)));
+ return this->VisitStmt(op->body);
+ } else {
+ thread_extent_scope_[op->node.get()] = op->value;
+ Stmt ret = StmtExprMutator::VisitStmt_(op);
+ thread_extent_scope_.erase(op->node.get());
+ return ret;
+ }
+ } else if (op->attr_key == ir::attr::realize_scope ||
+ op->attr_key == ir::attr::double_buffer_scope) {
+ auto it = replace_op_.find(op->node.get());
+ if (it != replace_op_.end()) {
+ if (it->second.defined()) {
+ Stmt ret = AttrStmtNode::make(
+ it->second, op->attr_key, op->value, op->body);
+ return this->VisitStmt(ret);
+ } else {
+ return this->VisitStmt(op->body);
+ }
+ }
+ } else if (op->attr_key == ir::attr::buffer_bind_scope) {
+ Array<ObjectRef> tuple = Downcast<Array<ObjectRef> >(op->node);
+ Tensor tensor = Downcast<Tensor>(tuple[1]);
+ auto it = replace_op_.find(tensor->op.get());
+ if (it != replace_op_.end()) {
+ if (it->second.defined()) {
+ return AttrStmtNode::make(
+ Array<ObjectRef>{tuple[0], it->second.output(tensor->value_index)},
+ op->attr_key, op->value, this->VisitStmt(op->body));
+ } else {
+ return this->VisitStmt(op->body);
+ }
+ }
+ } else if (op->attr_key == ir::attr::buffer_dim_align) {
+ Tensor tensor = Downcast<Tensor>(op->node);
+ auto it = replace_op_.find(tensor->op.get());
+ if (it != replace_op_.end()) {
+ if (it->second.defined()) {
+ return AttrStmtNode::make(
+ it->second.output(tensor->value_index),
+ op->attr_key, op->value, this->VisitStmt(op->body));
+ } else {
+ return this->VisitStmt(op->body);
+ }
+ }
+ }
+ return StmtExprMutator::VisitStmt_(op);
+ }
+
+ Stmt VisitStmt_(const RealizeNode* op) final {
+ TensorKey key{op->func, op->value_index};
+ auto it = replace_realize_.find(key);
+ if (it != replace_realize_.end()) {
+ if (it->second.defined()) {
+ Stmt ret = RealizeNode::make(
+ it->second->op, it->second->value_index,
+ op->dtype, op->bounds, op->condition, op->body);
+ return this->VisitStmt(ret);
+ } else {
+ return this->VisitStmt(op->body);
+ }
+ } else {
+ return StmtExprMutator::VisitStmt_(op);
+ }
+ }
+
+ Stmt VisitStmt_(const ProvideNode* op) final {
+ TensorKey key{op->func, op->value_index};
+ auto it = replace_buffer_.find(key);
+ if (it != replace_buffer_.end()) {
+ const Tensor& dst = it->second;
+ Stmt ret = ProvideNode::make(
+ dst->op, dst->value_index, op->value, op->args);
+ return this->VisitStmt(ret);
+ } else {
+ return StmtExprMutator::VisitStmt_(op);
+ }
+ }
+
+ PrimExpr VisitExpr_(const CallNode* op) final {
+ if (op->call_type == CallNode::Halide) {
+ TensorKey key{op->func, op->value_index};
+ auto it = replace_buffer_.find(key);
+ if (it != replace_buffer_.end()) {
+ const Tensor& dst = it->second;
+ PrimExpr ret = CallNode::make(
+ op->dtype, dst->op->name, op->args,
+ op->call_type, dst->op, dst->value_index);
+ return this->VisitExpr(ret);
+ }
+ }
+ return StmtExprMutator::VisitExpr_(op);
+ }
+
+ PrimExpr VisitExpr_(const VarNode* op) final {
+ auto it = var_value_.find(op);
+ if (it != var_value_.end()) {
+ return it->second;
+ } else {
+ return GetRef<PrimExpr>(op);
+ }
+ }
+
+ void Init(const Schedule& sch) {
+ for (Stage s : sch->stages) {
+ for (auto kv : s->iter_var_attrs) {
+ // Update bind thread information.
+ if (kv.second->bind_thread.defined()) {
+ const Var& from = kv.first->var;
+ const Var& to = kv.second->bind_thread->var;
+ CHECK(!var_value_.count(from.get()));
+ var_value_[from.get()] = to;
+ }
+ }
+ // This must be checked for all ops, including scan.
+ if (!s->op.same_as(s->origin_op)) {
+ for (int i = 0; i < s->op->num_outputs(); ++i) {
+ Tensor target = s->origin_op.output(i);
+ AddReplace(s->op.output(i), target,
+ target, s->origin_op);
+ }
+ }
+ // Specially add replacements for scan op.
+ if (const ScanOpNode* scan = s->op.as<ScanOpNode>()) {
+ for (size_t i = 0; i < scan->update.size(); ++i) {
+ Tensor t = s->origin_op.output(i);
+ AddReplace(scan->init[i], t);
+ AddReplace(scan->update[i], t);
+ AddReplace(scan->state_placeholder[i], t);
+ }
+ }
+ }
+ }
+
+ private:
+ void AddReplace(Tensor src,
+ Tensor dst,
+ Tensor repl_realize = Tensor(),
+ Operation repl_op = Operation()) {
+ TensorKey key{src->op, src->value_index};
+ replace_buffer_[key] = dst;
+ replace_realize_[key] = repl_realize;
+ replace_op_[src->op.get()] = repl_op;
+ }
+ // The thread extent scope.
+ std::unordered_map<const Object*, PrimExpr> thread_extent_scope_;
+  // Map of variable substitutions (scan axes, bound threads, let values).
+ std::unordered_map<const VarNode*, PrimExpr> var_value_;
+ // buffer replacement
+ std::unordered_map<TensorKey, Tensor> replace_buffer_;
+  // buffer realization to be replaced
+ std::unordered_map<TensorKey, Tensor> replace_realize_;
+ // replace producer consumer.
+ std::unordered_map<const Object*, Operation> replace_op_;
+};
+
+Stmt ScheduleOps(
+ Schedule sch, Map<IterVar, Range> dom_map_, bool debug_keep_trivial_loop) {
+ Stmt body = Stmt();
+ std::unordered_map<IterVar, Range> dom_map = as_unordered_map(dom_map_);
+ // scan init and scan updates
+ std::unordered_map<Operation, Operation> scan_init;
+ for (Stage s : sch->stages) {
+ const ScanOpNode* scan = s->op.as<ScanOpNode>();
+ if (!scan) continue;
+ for (Tensor t : scan->init) {
+ if (scan_init.count(t->op)) {
+ CHECK(scan_init.at(t->op).same_as(s->op))
+ << "Scan init tensor can only belong to one scan";
+ } else {
+ scan_init[t->op] = s->op;
+ }
+ }
+ }
+ // verify correctness of group.
+ for (Stage g : sch->groups) {
+ CHECK(!g->op.defined());
+ CHECK_EQ(g->leaf_iter_vars.size(), 0U);
+ }
+ // reverse the post DFS order.
+ for (size_t i = sch->stages.size(); i != 0; --i) {
+ Stage s = sch->stages[i - 1];
+ CHECK_NE(s->attach_type, kInline)
+ << "call schedule.normalize before scheduleops";
+ CHECK(s->op.defined());
+    // no need to schedule placeholder ops.
+ if (s->op.as<PlaceholderOpNode>()) continue;
+ // Remove grouping sugar, get the real attach spec.
+ Stage attach_spec = s.GetAttachSpec();
+
+ if (scan_init.count(s->op)) {
+ CHECK(body.defined());
+ InjectScanStep mu(s, scan_init.at(s->op), dom_map, true, debug_keep_trivial_loop);
+ body = mu(std::move(body));
+ CHECK(mu.found_attach)
+ << "did not find attachment point for scan.init";
+ } else if (attach_spec->attach_type == kScanUpdate) {
+ // Handle scan update
+ CHECK(body.defined());
+ InjectScanStep mu(s, attach_spec->attach_stage->op, dom_map, false, debug_keep_trivial_loop);
+ body = mu(std::move(body));
+ CHECK(mu.found_attach)
+ << "did not find attachment point for scan.update";
+ } else if (attach_spec->attach_type == kInlinedAlready) {
+ // do nothing
+ } else if (attach_spec->attach_type == kGroupRoot) {
+ CHECK(!s->group.defined());
+ body = MakePipeline(s, dom_map, body, debug_keep_trivial_loop);
+ } else {
+ CHECK_EQ(attach_spec->attach_type, kScope);
+ CHECK(body.defined());
+ InjectAttach mutator(s, attach_spec, dom_map, debug_keep_trivial_loop);
+ body = mutator(std::move(body));
+ CHECK(mutator.found_attach)
+ << "did not find attachment point for " << s << " in "
+ << attach_spec->attach_stage->op << " x " << attach_spec->attach_ivar
+ << ", body:\n"
+ << body;
+ }
+ }
+ SchedulePostProc post_proc;
+ post_proc.Init(sch);
+ return post_proc(std::move(body));
+}
+
+} // namespace top
+} // namespace tvm
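ScheduleOps above is the entry point that stitches the per-stage pipelines into a single Stmt. A minimal lowering sketch, assuming InferBound is exposed from tvm::top alongside it; the helper name is hypothetical:

  Stmt LowerScheduleToStmt(top::Schedule sch) {
    sch = sch.normalize();                              // required before ScheduleOps
    Map<IterVar, Range> bounds = top::InferBound(sch);  // per-IterVar iteration ranges
    return top::ScheduleOps(sch, bounds, /*debug_keep_trivial_loop=*/false);
  }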
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tensor.cc
+ */
+#include <tvm/top/tensor.h>
+#include <tvm/top/operation.h>
+#include <tvm/top/tensor_intrin.h>
+#include <memory>
+
+namespace tvm {
+namespace top {
+// Tensor
+PrimExpr Tensor::operator()(Array<Var> indices) const {
+ Array<PrimExpr> arr(indices.begin(), indices.end());
+ return operator()(arr);
+}
+
+PrimExpr Tensor::operator()(Array<PrimExpr> indices) const {
+ using ir::CallNode;
+ if (ndim() != 0) {
+ CHECK_EQ(ndim(), indices.size())
+ << "Tensor dimension mismatch in read"
+ << "ndim = " << ndim() << ", indices.size=" << indices.size();
+ }
+ auto n = CallNode::make(
+ (*this)->dtype, (*this)->op->name, indices, CallNode::Halide,
+ (*this)->op, (*this)->value_index);
+ return n;
+}
+
+Tensor Operation::output(size_t i) const {
+ auto node = make_object<TensorNode>();
+ node->op = *this;
+ node->value_index = i;
+ node->dtype = (*this)->output_dtype(i);
+ node->shape = (*this)->output_shape(i);
+ return Tensor(node);
+}
+
+Tensor TensorNode::make(Array<PrimExpr> shape,
+ DataType dtype,
+ Operation op,
+ int value_index) {
+ auto n = make_object<TensorNode>();
+ n->shape = std::move(shape);
+ n->dtype = dtype;
+ n->op = op;
+ n->value_index = value_index;
+ return Tensor(n);
+}
+
+TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
+.set_dispatch<TensorNode>([](const ObjectRef& node, NodePrinter* p) {
+ auto* t = static_cast<const TensorNode*>(node.get());
+ p->stream << "Tensor(shape=" << t->shape
+ << ", op.name=" << t->op->name << ')';
+ });
+
+TVM_REGISTER_NODE_TYPE(TensorNode);
+
+
+// TensorIntrin
+
+TensorIntrin TensorIntrinNode::make(std::string name,
+ Operation op,
+ Array<Tensor> inputs,
+ Array<Buffer> buffers,
+ Array<Var> scalar_params,
+ Stmt body,
+ Stmt reduce_init,
+ Stmt reduce_update) {
+ auto n = make_object<TensorIntrinNode>();
+ n->name = std::move(name);
+ n->op = std::move(op);
+ n->inputs = std::move(inputs);
+ n->buffers = std::move(buffers);
+ n->scalar_params = std::move(scalar_params);
+ n->body = std::move(body);
+ n->reduce_init = std::move(reduce_init);
+ n->reduce_update = std::move(reduce_update);
+ return TensorIntrin(n);
+}
+
+TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
+.set_dispatch<TensorIntrinNode>([](const ObjectRef& node, NodePrinter* p) {
+ auto* op = static_cast<const TensorIntrinNode*>(node.get());
+ p->stream << "TensorIntrin(name=" << op->name << ", " << op << ")";
+ });
+
+TVM_REGISTER_NODE_TYPE(TensorIntrinNode);
+
+
+// TensorIntrinCall
+
+TensorIntrinCall TensorIntrinCallNode::make(TensorIntrin intrin,
+ Array<Tensor> tensors,
+ Array<Region> regions,
+ Array<IterVar> reduce_axis,
+ Array<PrimExpr> scalar_inputs) {
+ auto n = make_object<TensorIntrinCallNode>();
+ n->intrin = std::move(intrin);
+ n->tensors = std::move(tensors);
+ n->regions = std::move(regions);
+ n->reduce_axis = std::move(reduce_axis);
+ n->scalar_inputs = std::move(scalar_inputs);
+ return TensorIntrinCall(n);
+}
+
+TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
+.set_dispatch<TensorIntrinCallNode>([](const ObjectRef& node, NodePrinter* p) {
+ auto* n = static_cast<const TensorIntrinCallNode*>(node.get());
+ p->stream << "TensorIntrinCall(intrin=" << n->intrin << ", " << n << ")";
+ });
+
+TVM_REGISTER_NODE_TYPE(TensorIntrinCallNode);
+
+} // namespace top
+} // namespace tvm
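The Tensor::operator() overloads above turn an indexed read into a Halide CallNode on the tensor's op, and Operation::output() goes the other way. A small illustrative sketch; the names are assumptions:

  using namespace tvm;
  Var m("m"), n("n"), i("i"), j("j");
  top::Tensor A = top::placeholder({m, n}, DataType::Float(32), "A");
  PrimExpr read = A(Array<PrimExpr>{i, j});  // lowers to a CallNode::Halide on A->op
  top::Tensor same = A->op.output(0);        // rebuilds the tensor from its operation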
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <topi/cuda/injective.h>
-#include <tvm/operation.h>
+#include <tvm/top/operation.h>
#include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
#include <tvm/build_module.h>
TEST(BuildModule, Basic) {
using namespace tvm;
+ using namespace tvm::top;
auto n = var("n");
Array<PrimExpr> shape;
shape.push_back(n);
*/
using namespace tvm;
+ using namespace tvm::top;
const runtime::PackedFunc* pf = runtime::Registry::Get("module._Enabled");
bool enabled = (*pf)("cuda");
if (!enabled) {
#include <dmlc/logging.h>
#include <gtest/gtest.h>
-#include <tvm/operation.h>
+#include <tvm/top/operation.h>
TEST(Expr, Basic) {
using namespace tvm;
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <tvm/ir_pass.h>
-#include <tvm/operation.h>
+#include <tvm/top/operation.h>
TEST(IRSIMPLIFY, MinMax) {
auto x = tvm::var("x");
#include <gtest/gtest.h>
#include <tvm/build_module.h>
-#include <tvm/operation.h>
+#include <tvm/top/operation.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/type.h>
#include <tvm/relay/analysis.h>
*/
#include <gtest/gtest.h>
-#include <tvm/operation.h>
+#include <tvm/top/operation.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/type.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/type.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
-#include <tvm/operation.h>
+#include <tvm/top/operation.h>
TVM_REGISTER_GLOBAL("schedule")
.set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) {
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <tvm/ir_pass.h>
-#include <tvm/operation.h>
+#include <tvm/top/operation.h>
TEST(SimplePasses, HasSideEffect) {
using namespace tvm;
Array<PrimExpr> shape;
shape.push_back(n);
- auto A = placeholder(shape, DataType::Float(32), "A");
+ auto A = top::placeholder(shape, DataType::Float(32), "A");
CHECK(!tvm::ir::HasSideEffect(A[0]));
}
#include <dmlc/logging.h>
#include <gtest/gtest.h>
-#include <tvm/operation.h>
+#include <tvm/top/operation.h>
TEST(Tensor, Basic) {
using namespace tvm;
+ using namespace tvm::top;
+
Var m("m"), n("n"), l("l");
Tensor A = placeholder({m, l}, DataType::Float(32), "A");
TEST(Tensor, Reduce) {
using namespace tvm;
+ using namespace tvm::top;
+
Var m("m"), n("n"), l("l");
- Tensor A = placeholder({m, l}, DataType::Float(32), "A");
- Tensor B = placeholder({n, l}, DataType::Float(32), "B");
+ top::Tensor A = top::placeholder({m, l}, DataType::Float(32), "A");
+ top::Tensor B = top::placeholder({n, l}, DataType::Float(32), "B");
IterVar rv = reduce_axis(Range{0, l}, "k");
- auto C = compute({m, n}, [&](Var i, Var j) {
+ auto C = top::compute({m, n}, [&](Var i, Var j) {
return sum(max(1 + A[i][rv] + 1, B[j][rv]), {rv});
}, "C");
- LOG(INFO) << C->op.as<ComputeOpNode>()->body;
+ LOG(INFO) << C->op.as<top::ComputeOpNode>()->body;
}
int main(int argc, char ** argv) {
* under the License.
*/
-#include <tvm/operation.h>
+#include <tvm/top/operation.h>
#include <topi/elemwise.h>
#include <gtest/gtest.h>
#include <gtest/gtest.h>
#include <topi/generic/injective.h>
#include <tvm/build_module.h>
-#include <tvm/operation.h>
+#include <tvm/top/operation.h>
#include <tvm/packed_func_ext.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr.h>
*
* \return A Tensor whose op member is a broadcast operation
*/
-inline tvm::Tensor broadcast_to(const tvm::Tensor& t,
+inline tvm::top::Tensor broadcast_to(const tvm::top::Tensor& t,
const tvm::Array<tvm::PrimExpr>& output_shape,
std::string name = "T_broadcast_to",
std::string tag = kBroadcast) {
auto l = [&](tvm::Array<tvm::Var> ovars) {
return t(detail::InputIndexFromBroadcast(ovars, t, bh.vars2, bh.all_vars));
};
- return tvm::compute(
+ return tvm::top::compute(
tvm::Array<tvm::PrimExpr>(bh.common_shape.begin(), bh.common_shape.end()),
l,
name,
const tvm::PrimExpr& b) { \
ComputeRule; \
} \
- inline tvm::Tensor Name(const tvm::Tensor& A, \
- const tvm::Tensor& B, \
- std::string name = "T_" #Name, \
- std::string tag = kBroadcast) { \
+ inline tvm::top::Tensor Name(const tvm::top::Tensor& A, \
+ const tvm::top::Tensor& B, \
+ std::string name = "T_" #Name, \
+ std::string tag = kBroadcast) { \
auto l = [](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; }; \
return detail::WithBroadcast(l, A, B, name, tag); \
} \
- inline tvm::Tensor Name(const tvm::Tensor& A, \
+ inline tvm::top::Tensor Name(const tvm::top::Tensor& A, \
const tvm::PrimExpr& B, \
std::string name = "T_" #Name, \
std::string tag = kElementWise) { \
- auto l = [](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; }; \
- return compute(A->shape, [&](const ::tvm::Array<::tvm::Var>& i) { \
+ auto l = [](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; }; \
+ return tvm::top::compute(A->shape, [&](const ::tvm::Array<::tvm::Var>& i) { \
return l(A(i), B); \
}, name, tag); \
} \
- inline tvm::Tensor Name(const tvm::PrimExpr& A, \
- const tvm::Tensor& B, \
- std::string name = "T_" #Name, \
- std::string tag = kElementWise) { \
+ inline tvm::top::Tensor Name(const tvm::PrimExpr& A, \
+ const tvm::top::Tensor& B, \
+ std::string name = "T_" #Name, \
+ std::string tag = kElementWise) { \
auto l = [&](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; }; \
- return compute(B->shape, [&](const ::tvm::Array<::tvm::Var>& i) { \
+ return tvm::top::compute(B->shape, [&](const ::tvm::Array<::tvm::Var>& i) { \
return l(A, B(i)); \
}, name, tag); \
}
#define TOPI_DEFINE_OP_OVERLOAD(Name, OpName) \
- inline tvm::Tensor Name(const tvm::Tensor& A, \
- const tvm::Tensor& B) { \
+ inline tvm::top::Tensor Name(const tvm::top::Tensor& A, \
+ const tvm::top::Tensor& B) { \
return topi::OpName(A, B); \
} \
- inline tvm::Tensor Name(const tvm::PrimExpr& A, \
- const tvm::Tensor& B) { \
+ inline tvm::top::Tensor Name(const tvm::PrimExpr& A, \
+ const tvm::top::Tensor& B) { \
return topi::OpName(A, B); \
} \
- inline tvm::Tensor Name(const tvm::Tensor& A, \
- const tvm::PrimExpr& B) { \
+ inline tvm::top::Tensor Name(const tvm::top::Tensor& A, \
+ const tvm::PrimExpr& B) { \
return topi::OpName(A, B); \
}
#ifndef TOPI_CONTRIB_CUBLAS_H_
#define TOPI_CONTRIB_CUBLAS_H_
-#include "tvm/operation.h"
+#include "tvm/top/operation.h"
#include "topi/detail/extern.h"
namespace topi {
namespace contrib {
using namespace tvm;
+using namespace tvm::top;
using namespace topi::detail;
/*!
* \brief Create an op that multiplies lhs and rhs with cuBLAS
#ifndef TOPI_CONTRIB_ROCBLAS_H_
#define TOPI_CONTRIB_ROCBLAS_H_
-#include "tvm/operation.h"
+#include "tvm/top/operation.h"
#include "topi/detail/extern.h"
namespace topi {
namespace contrib {
using namespace tvm;
+using namespace tvm::top;
/*!
* \brief Create an op that multiplies lhs and rhs with rocBLAS
*
#ifndef TOPI_CUDA_DENSE_H_
#define TOPI_CUDA_DENSE_H_
-#include "tvm/operation.h"
+#include "tvm/top/operation.h"
#include "tvm/build_module.h"
#include "topi/tags.h"
#include "topi/detail/array_utils.h"
namespace topi {
using namespace tvm;
+using namespace tvm::top;
namespace cuda {
/*!
*
* \return Tensor with shape [batch, out_dim]
*/
-inline tvm::Tensor dense_cuda(const Target& target,
- const tvm::Tensor& data,
- const tvm::Tensor& weight,
- const tvm::Tensor& bias,
+inline tvm::top::Tensor dense_cuda(const Target& target,
+ const tvm::top::Tensor& data,
+ const tvm::top::Tensor& weight,
+ const tvm::top::Tensor& bias,
const DataType& out_dtype) {
CHECK_EQ(data->shape.size(), 2) << "dense requires 2-D data";
CHECK_EQ(weight->shape.size(), 2) << "dense requires 2-D weight";
CHECK_EQ(data->dtype, out_dtype) << "Mixed precision not supported.";
auto mm = topi::contrib::cublas_matmul(data, weight, false, true);
if (bias.defined()) {
- mm = tvm::compute({ batch, out_dim },
+ mm = tvm::top::compute({ batch, out_dim },
[&](Var i, Var j) {
return mm(i, j) + bias(j);
}, "tensor", kBroadcast);
#include "topi/tags.h"
#include "topi/detail/fuse.h"
-#include "tvm/operation.h"
+#include "tvm/top/operation.h"
#include "tvm/build_module.h"
namespace topi {
using namespace tvm;
+using namespace tvm::top;
namespace cuda {
*
* \param sch The schedule to update.
* \param out The tensor representing the injective op.
- *
+ *
* \return The updated schedule.
*/
inline Schedule schedule_injective_from_existing(Schedule sch, const Tensor& out) {
out_ops.push_back(t->op);
}
auto s = create_schedule(out_ops);
- tvm::schedule::AutoInlineInjective(s);
+ tvm::top::AutoInlineInjective(s);
for (auto out : outs) {
schedule_injective_from_existing(s, out);
}
#ifndef TOPI_CUDA_NORMALIZATION_H_
#define TOPI_CUDA_NORMALIZATION_H_
-#include "tvm/operation.h"
+#include "tvm/top/operation.h"
#include "tvm/build_module.h"
#include "topi/tags.h"
namespace topi {
using namespace tvm;
+using namespace tvm::top;
namespace cuda {
/*!
* \brief Create a CUDA schedule for LRN
#include "topi/tags.h"
#include "topi/detail/fuse.h"
#include "topi/detail/array_utils.h"
-#include "tvm/operation.h"
+#include "tvm/top/operation.h"
#include "tvm/build_module.h"
namespace topi {
using namespace tvm;
+using namespace tvm::top;
namespace cuda {
#include "topi/tags.h"
#include "topi/detail/fuse.h"
-#include "tvm/operation.h"
+#include "tvm/top/operation.h"
#include "tvm/build_module.h"
namespace topi {
using namespace tvm;
+using namespace tvm::top;
namespace cuda {
/*!
#include "topi/tags.h"
#include "topi/detail/fuse.h"
-#include "tvm/operation.h"
+#include "tvm/top/operation.h"
#include "tvm/build_module.h"
namespace topi {
using namespace tvm;
+using namespace tvm::top;
namespace cuda {
auto s = create_schedule(out_ops);
auto softmax = outs[0];
- tvm::Tensor max_elem;
- tvm::Tensor expsum;
- tvm::Tensor exp;
+ tvm::top::Tensor max_elem;
+ tvm::top::Tensor expsum;
+ tvm::top::Tensor exp;
bool has_exp = false;
auto tag = softmax->op.as<ComputeOpNode>()->tag;
#ifndef TOPI_DETAIL_ARRAY_UTILS_H_
#define TOPI_DETAIL_ARRAY_UTILS_H_
-#include "tvm/operation.h"
+#include "tvm/top/operation.h"
namespace topi {
namespace detail {
using namespace tvm;
+using namespace tvm::top;
/*!
* \brief Search an array for a specific item
#include <string>
#include "tvm/ir_pass.h"
-#include "tvm/operation.h"
+#include "tvm/top/operation.h"
#include "tvm/expr_operator.h"
#include "topi/detail/constant_utils.h"
inline tvm::Array<tvm::PrimExpr> InputIndexFromBroadcast(
const tvm::Array<tvm::Var>& ovars,
- const tvm::Tensor& T,
+ const tvm::top::Tensor& T,
const std::deque<tvm::Var>& my_vars,
const std::deque<tvm::Var>& all_vars) {
tvm::Array<tvm::PrimExpr> ivars;
}
template <typename FBinaryExpr>
-inline tvm::Tensor WithBroadcast(FBinaryExpr op,
- const tvm::Tensor& A,
- const tvm::Tensor& B,
+inline tvm::top::Tensor WithBroadcast(FBinaryExpr op,
+ const tvm::top::Tensor& A,
+ const tvm::top::Tensor& B,
const std::string& name = "tensor",
const std::string& tag = "") {
auto bh = BroadcastShape(A->shape, B->shape);
return op(A(InputIndexFromBroadcast(ovars, A, bh.vars1, bh.all_vars)),
B(InputIndexFromBroadcast(ovars, B, bh.vars2, bh.all_vars)));
};
- return tvm::compute(
+ return tvm::top::compute(
tvm::Array<tvm::PrimExpr>(bh.common_shape.begin(), bh.common_shape.end()),
l,
name,
namespace topi {
namespace detail {
using namespace tvm;
+using namespace tvm::top;
/*!
* \brief Test whether the given Expr is a constant integer
#ifndef TOPI_DETAIL_EXTERN_H_
#define TOPI_DETAIL_EXTERN_H_
-#include <tvm/operation.h>
+#include <tvm/top/operation.h>
#include <vector>
#include <string>
namespace topi {
namespace detail {
using namespace tvm;
+using namespace tvm::top;
/*!
* \brief Construct a buffer to pass to an external function
#ifndef TOPI_DETAIL_FUSE_H_
#define TOPI_DETAIL_FUSE_H_
-#include "tvm/operation.h"
+#include "tvm/top/operation.h"
namespace topi {
namespace detail {
using namespace tvm;
+using namespace tvm::top;
/*!
* \brief Fuse all of the given args
namespace topi {
namespace detail {
using namespace tvm;
+using namespace tvm::top;
/*!
* \brief Get padding size for each side given padding height and width
#include <vector>
-#include "tvm/operation.h"
+#include "tvm/top/operation.h"
#include "tvm/expr_operator.h"
namespace topi {
namespace detail {
using namespace tvm;
+using namespace tvm::top;
/*!
* \brief Flatten the indices to 1D
namespace topi {
namespace detail {
using namespace tvm;
+using namespace tvm::top;
/*!
* \brief Check whether input shape has dimension of size 0;
namespace topi {
using namespace tvm;
+using namespace tvm::top;
// Unary intrinsic operators
#define TOPI_DECLARE_UNARY_OP(OpName) \
#include "topi/tags.h"
#include "topi/detail/fuse.h"
-#include "tvm/operation.h"
+#include "tvm/top/operation.h"
#include "tvm/build_module.h"
namespace topi {
using namespace tvm;
+using namespace tvm::top;
namespace generic {
/*!
}
auto s = create_schedule(out_ops);
auto x = outs[0];
- tvm::schedule::AutoInlineInjective(s);
+ tvm::top::AutoInlineInjective(s);
auto axis = s[x]->op.as<ComputeOpNode>()->axis;
if (axis.size() > 0) {
detail::Fuse(s[x], axis);
#include "topi/tags.h"
#include "topi/detail/fuse.h"
-#include "tvm/operation.h"
+#include "tvm/top/operation.h"
#include "tvm/build_module.h"
#include "injective.h"
namespace topi {
using namespace tvm;
+using namespace tvm::top;
namespace generic {
/*!
}
auto s = create_schedule(out_ops);
- tvm::schedule::AutoInlineInjective(s);
+ tvm::top::AutoInlineInjective(s);
for (auto out : outs) {
if (out->op->IsInstance<ExternOpNode>()) {
continue;
#include "topi/tags.h"
#include "topi/detail/fuse.h"
-#include "tvm/operation.h"
+#include "tvm/top/operation.h"
#include "tvm/build_module.h"
namespace topi {
using namespace tvm;
+using namespace tvm::top;
namespace generic {
*
* \param sch The schedule to update.
* \param out The tensor representing the injective op.
- *
+ *
* \return The updated schedule.
*/
inline Schedule schedule_injective_from_existing(Schedule sch, const Tensor& out) {
out_ops.push_back(t->op);
}
auto s = create_schedule(out_ops);
- tvm::schedule::AutoInlineInjective(s);
+ tvm::top::AutoInlineInjective(s);
auto x = outs[0];
schedule_injective_from_existing(s, x);
#include "topi/elemwise.h"
#include "topi/detail/ravel_unravel.h"
#include "topi/detail/constant_utils.h"
-#include "tvm/operation.h"
+#include "tvm/top/operation.h"
#include "tvm/expr_operator.h"
namespace topi {
namespace image {
using namespace tvm;
+using namespace tvm::top;
/*!
* \brief Sample a point in a tensor using bilinear interpolation.
#include "topi/detail/constant_utils.h"
#include "tvm/ir.h"
#include "tvm/ir_pass.h"
-#include "tvm/operation.h"
+#include "tvm/top/operation.h"
#include "tvm/expr_operator.h"
namespace topi {
using namespace tvm;
+using namespace tvm::top;
namespace detail {
template <typename T>
* \return A Tensor whose op member is the relu operation
*/
template <typename T>
-inline tvm::Tensor relu(const tvm::Tensor& t,
+inline tvm::top::Tensor relu(const tvm::top::Tensor& t,
T threshold = static_cast<T>(0),
std::string name = "T_relu",
std::string tag = kElementWise) {
- return tvm::compute(
+ return tvm::top::compute(
t->shape,
[&](const tvm::Array<tvm::Var>& i) {
auto threshold_const = tvm::make_const(t->dtype, threshold);
*
* \return A Tensor whose op member is the leaky relu operation
*/
-inline tvm::Tensor leaky_relu(const tvm::Tensor& t,
+inline tvm::top::Tensor leaky_relu(const tvm::top::Tensor& t,
double alpha = 0.1,
std::string name = "T_leaky_relu",
std::string tag = kElementWise) {
- return tvm::compute(
+ return tvm::top::compute(
t->shape,
[&](const tvm::Array<tvm::Var>& i) {
auto value = t(i);
*
* \return A Tensor whose op member is the parametric relu operation
*/
-inline tvm::Tensor prelu(const tvm::Tensor &x,
- const tvm::Tensor &slope,
+inline tvm::top::Tensor prelu(const tvm::top::Tensor &x,
+ const tvm::top::Tensor &slope,
const int axis = 1,
std::string name = "T_prelu",
std::string tag = kBroadcast) {
topi::detail::GetConstInt(x->shape[axis]))
<< "Wrong slope shape received.";
- return tvm::compute(x->shape,
+ return tvm::top::compute(x->shape,
[&](const tvm::Array<tvm::Var> &indices) {
auto xval = x(indices);
return tvm::ir::SelectNode::make(
*
*
*/
-inline tvm::Tensor pad(const tvm::Tensor& t,
+inline tvm::top::Tensor pad(const tvm::top::Tensor& t,
const tvm::Array<tvm::PrimExpr>& pad_before,
tvm::Array<tvm::PrimExpr> pad_after = tvm::Array<tvm::PrimExpr>(),
PrimExpr pad_value = PrimExpr(),
}
return t(indices);
};
- return tvm::compute(output_shape, l, name, tag);
+ return tvm::top::compute(output_shape, l, name, tag);
}
/*!
* \return A Tensor whose op member is the 2-D convolution operation (NCHW
* layout)
*/
-inline tvm::Tensor conv2d_nchw(const tvm::Tensor& I,
- const tvm::Tensor& W,
+inline tvm::top::Tensor conv2d_nchw(const tvm::top::Tensor& I,
+ const tvm::top::Tensor& W,
int pad_h = 0,
int pad_w = 0,
int stride_h = 1,
T(b, i, stride_h * h + kh, stride_w * w + kw) * W(o, i, kh, kw),
{i, kh, kw});
};
- return tvm::compute(output_shape, l, name, tag);
+ return tvm::top::compute(output_shape, l, name, tag);
}
/*!
* \return A Tensor whose op member is the 2-D convolution operation
* (HWCN layout)
*/
-inline tvm::Tensor conv2d_hwcn(const tvm::Tensor& I,
- const tvm::Tensor& W,
+inline tvm::top::Tensor conv2d_hwcn(const tvm::top::Tensor& I,
+ const tvm::top::Tensor& W,
int pad_h = 0,
int pad_w = 0,
int stride_h = 1,
T(stride_h * h + kh, stride_w * w + kw, i, b) * W(kh, kw, i, o),
{i, kh, kw});
};
- return tvm::compute(output_shape, l, name, tag);
+ return tvm::top::compute(output_shape, l, name, tag);
}
* \return A Tensor whose op member is the 2-D depthwise convolution operation
* (NCHW layout)
*/
-inline tvm::Tensor depthwise_conv2d_nchw(const tvm::Tensor& I,
- const tvm::Tensor& W,
+inline tvm::top::Tensor depthwise_conv2d_nchw(const tvm::top::Tensor& I,
+ const tvm::top::Tensor& W,
int pad_h = 0,
int pad_w = 0,
int stride_h = 1,
W(indexdiv(i, pCM), indexmod(o, pCM), kh, kw),
{i, kh, kw});
};
- return tvm::compute(output_shape, l, name, tag);
+ return tvm::top::compute(output_shape, l, name, tag);
}
-inline tvm::Tensor depthwise_conv2d_nhwc(const tvm::Tensor& I,
- const tvm::Tensor& W,
+inline tvm::top::Tensor depthwise_conv2d_nhwc(const tvm::top::Tensor& I,
+ const tvm::top::Tensor& W,
int pad_h = 0,
int pad_w = 0,
int stride_h = 1,
W(kh, kw, indexdiv(i, pCM), indexmod(o, pCM)),
{kh, kw, i});
};
- return tvm::compute(output_shape, l, name, tag);
+ return tvm::top::compute(output_shape, l, name, tag);
}
/*!
* \return A Tensor whose op member is the 2-D groupconvolution operation
* (NCHW layout)
*/
-inline tvm::Tensor group_conv2d_ngchw(const tvm::Tensor& I,
- const tvm::Tensor& W,
+inline tvm::top::Tensor group_conv2d_ngchw(const tvm::top::Tensor& I,
+ const tvm::top::Tensor& W,
int pad_h = 0,
int pad_w = 0,
int stride_h = 1,
I(b, g, i, stride_h * h + kh, stride_w * w + kw) * W(g, i, o, kh, kw),
{i, kh, kw});
};
- return tvm::compute(output_shape, l, name, tag);
+ return tvm::top::compute(output_shape, l, name, tag);
}
} // namespace topi
#include <string>
#include "topi/tags.h"
-#include "tvm/operation.h"
+#include "tvm/top/operation.h"
namespace topi {
namespace nn {
using namespace tvm;
+using namespace tvm::top;
/*!
* \brief Creates an operation that calculates matrix multiplication in batch.
*
* \return Tensor with shape [batch, M, N]
*/
-inline tvm::Tensor batch_matmul(const tvm::Tensor& x,
- const tvm::Tensor& y) {
+inline tvm::top::Tensor batch_matmul(const tvm::top::Tensor& x,
+ const tvm::top::Tensor& y) {
CHECK_EQ(x->shape.size(), 3) << "batch_matmul requires 3-D data";
CHECK_EQ(y->shape.size(), 3) << "batch_matmul requires 3-D data";
auto N = y->shape[1];
auto k = tvm::reduce_axis(Range(0, K), "k");
- auto result = tvm::compute(
+ auto result = tvm::top::compute(
{ batch, M, N },
[&](Var b, Var i, Var j) {
return tvm::sum(x(b, i, k) * y(b, j, k), { k });
#include "topi/tags.h"
#include "topi/broadcast.h"
#include "topi/transform.h"
-#include "tvm/operation.h"
+#include "tvm/top/operation.h"
#include "tvm/expr_operator.h"
namespace topi {
* \param axis The axis to add the bias to.
* \return Tensor with shape [batch, in_dim]
*/
-inline tvm::Tensor bias_add(const tvm::Tensor& data, const tvm::Tensor& bias, int axis) {
+inline tvm::top::Tensor bias_add(const tvm::top::Tensor& data,
+ const tvm::top::Tensor& bias,
+ int axis) {
int data_ndim = data->shape.size();
if (axis < 0) {
axis += data_ndim;
#include <string>
-#include "tvm/operation.h"
+#include "tvm/top/operation.h"
#include "tvm/ir_pass.h"
#include "topi/tags.h"
#include "topi/detail/constant_utils.h"
namespace topi {
namespace nn {
using namespace tvm;
+using namespace tvm::top;
/*!
* \brief Binarization and bit-packing along a certain axis.
*
* \return Output tensor with dtype uint32
*/
-inline tvm::Tensor binarize_pack(const tvm::Tensor& data,
+inline tvm::top::Tensor binarize_pack(const tvm::top::Tensor& data,
int axis,
std::string name = "PackedInput",
std::string tag = "binarize_pack") {
ishape[i]);
}
- return tvm::compute(
+ return tvm::top::compute(
oshape,
[&](const Array<Var>& indices) {
Array<PrimExpr> start_idx;
*
* \return Tensor with shape [batch, out_dim], dtype is float32
*/
-inline tvm::Tensor binary_dense(const tvm::Tensor& data,
- const tvm::Tensor& weight) {
+inline tvm::top::Tensor binary_dense(const tvm::top::Tensor& data,
+ const tvm::top::Tensor& weight) {
CHECK_EQ(data->shape.size(), 2) << "binary_dense requires 2-D data";
CHECK_EQ(weight->shape.size(), 2) << "binary_dense requires 2-D weight";
CHECK_EQ(data->dtype, DataType::UInt(32)) << "binary_dense requires uint32 data";
auto out_dim = weight->shape[0];
auto k = tvm::reduce_axis(Range(0, in_dim), "k");
- auto matmul = tvm::compute(
+ auto matmul = tvm::top::compute(
{ batch, out_dim },
[&](Var i, Var j) {
return tvm::sum(popcount(data(i, k) ^ weight(j, k)), { k });
}, "tensor", "binary_dense");
- return tvm::compute(
+ return tvm::top::compute(
{ batch, out_dim },
[&](Var i, Var j) {
return 32 * in_dim - 2.0f * matmul(i, j);
#include <string>
#include "topi/tags.h"
-#include "tvm/operation.h"
+#include "tvm/top/operation.h"
namespace topi {
namespace nn {
using namespace tvm;
+using namespace tvm::top;
/*!
* \brief Creates an operation that calculates data * weight^T + bias
*
* \return Tensor with shape [batch, out_dim]
*/
-inline tvm::Tensor dense(const tvm::Tensor& data,
- const tvm::Tensor& weight,
- const tvm::Tensor& bias,
+inline tvm::top::Tensor dense(const tvm::top::Tensor& data,
+ const tvm::top::Tensor& weight,
+ const tvm::top::Tensor& bias,
const DataType& out_dtype) {
CHECK_EQ(data->shape.size(), 2) << "dense requires 2-D data";
CHECK_EQ(weight->shape.size(), 2) << "dense requires 2-D weight";
auto out_dim = weight->shape[0];
auto k = tvm::reduce_axis(Range(0, in_dim), "k");
- auto matmul = tvm::compute(
+ auto matmul = tvm::top::compute(
{ batch, out_dim },
[&](Var i, Var j) {
return tvm::sum(tvm::cast(out_dtype, data(i, k)) *
}, "tensor", "dense");
if (bias.defined()) {
- matmul = tvm::compute(
+ matmul = tvm::top::compute(
{ batch, out_dim },
[&](Var i, Var j) {
return matmul(i, j) + tvm::cast(out_dtype, bias(j));
#include <string>
-#include "tvm/operation.h"
+#include "tvm/top/operation.h"
#include "tvm/ir_pass.h"
#include "topi/tags.h"
namespace topi {
namespace nn {
using namespace tvm;
+using namespace tvm::top;
/*!
* \brief Create a new expression of the logical and of all
(x->shape[i] - 1) * cast(DataType::Int(32), strides[i] + 1)));
}
- return tvm::compute(
+ return tvm::top::compute(
out_shape,
[&](const Array<Var>& indices) {
Array<PrimExpr> not_zero;
#include "topi/tags.h"
#include "topi/detail/constant_utils.h"
-#include "tvm/operation.h"
+#include "tvm/top/operation.h"
#include "tvm/expr_operator.h"
namespace topi {
namespace nn {
using namespace tvm;
+using namespace tvm::top;
/*!
* \brief Flattens the input tensor into a 2-D tensor by collapsing higher dimensions.
}
std::reverse(extra_shape.begin(), extra_shape.end());
- return tvm::compute(
+ return tvm::top::compute(
oshape, [&](Var i, Var j) {
PrimExpr idx = j;
std::vector<PrimExpr> index;
#include <string>
#include <algorithm>
#include "topi/tags.h"
-#include "tvm/operation.h"
+#include "tvm/top/operation.h"
namespace topi {
namespace nn {
using namespace tvm;
+using namespace tvm::top;
/*!
* \brief L2 normalization inference operator
Tensor sum_value = topi::sum(dot_value, axis, true);
Tensor expand_sum = topi::broadcast_to(sum_value, input_shape);
return topi::divide(data,
- topi::sqrt(tvm::compute(expand_sum->shape,
+ topi::sqrt(tvm::top::compute(expand_sum->shape,
[&](const Array<Var>& i){
return (max(expand_sum(i), eps));
}, name, tag)));
#include <string>
#include "topi/tags.h"
-#include "tvm/operation.h"
+#include "tvm/top/operation.h"
namespace topi {
namespace nn {
using namespace tvm;
+using namespace tvm::top;
/*!
* \brief Local response normalization inference operator
auto rxs = tvm::reduce_axis(Range(0, size), "rxs");
Tensor sqr_sum;
if (axis == 1) {
- sqr_sum = tvm::compute(input_shape,
+ sqr_sum = tvm::top::compute(input_shape,
[&](Var i, Var l, Var j, Var k) {
return tvm::sum(pad_data(i, l + rxs, j, k) *
pad_data(i, l + rxs, j, k),
{rxs});
});
} else if (axis == 3) {
- sqr_sum = tvm::compute(input_shape,
+ sqr_sum = tvm::top::compute(input_shape,
[&](Var i, Var l, Var j, Var k) {
return tvm::sum(pad_data(i, l, j, k + rxs) *
pad_data(i, l, j, k + rxs),
{rxs});
});
}
- auto sqrt_sum_up = tvm::compute(
+ auto sqrt_sum_up = tvm::top::compute(
input_shape,
[&](Var i, Var j, Var k, Var l) {
return tvm::pow(bias +
#include <string>
#include "topi/tags.h"
-#include "tvm/operation.h"
+#include "tvm/top/operation.h"
namespace topi {
namespace nn {
using namespace tvm;
+using namespace tvm::top;
/*!
* \brief Scale and shift with NCHW order
const Tensor& shift,
std::string name = "ScaleShift",
std::string tag = kBroadcast) {
- return tvm::compute(
+ return tvm::top::compute(
x->shape,
[&](Var b, Var c, Var h, Var w) {
return x(b, c, h, w) * scale(c) + shift(w);
const Tensor& shift,
std::string name = "ScaleShift",
std::string tag = kBroadcast) {
- return tvm::compute(
+ return tvm::top::compute(
x->shape,
[&](Var b, Var h, Var w, Var c) {
return x(b, h, w, c) * scale(c) + shift(w);
namespace topi {
namespace nn {
using namespace tvm;
+using namespace tvm::top;
/*! \brief Pooling type */
enum PoolType : int {
if (pool_type == kMaxPool) {
auto temp = do_pad ? pad(
x, pad_before, pad_after, tvm::min_value(x->dtype), "pad_temp") : x;
- return tvm::compute(out_shape, [&](const Array<Var>& output) {
+ return tvm::top::compute(out_shape, [&](const Array<Var>& output) {
Array<PrimExpr> indices;
for (const Var& var : output) indices.push_back(var);
indices.Set(height_axis, output[height_axis] * stride_height + dheight);
auto temp = do_pad ? pad(x, pad_before, pad_after, 0, "pad_temp") : x;
// TVM compute for summing the pooling window.
- auto pool_sum = tvm::compute(out_shape,
+ auto pool_sum = tvm::top::compute(out_shape,
[&](const Array<Var>& output) {
Array<PrimExpr> indices;
for (const Var& var : output) indices.push_back(var);
}, "tensor", "pool_sum");
// TVM compute for dividing the reduced window sum by kernel size.
- return tvm::compute(out_shape,
+ return tvm::top::compute(out_shape,
[&](const Array<Var>& output) {
Array<PrimExpr> indices;
for (const Var& var : output) indices.push_back(var);
x, pad_before, pad_after, tvm::min_value(x->dtype), "pad_temp") : x;
auto mp_argmax =
- tvm::compute(out_shape,
+ tvm::top::compute(out_shape,
[&](const Array<Var>& inds) {
Array<PrimExpr> window_inds{inds.begin(), inds.end()};
window_inds.Set(height_axis, inds[height_axis] * stride_height + dheight);
auto mp_inds = mp_argmax[0];
- return tvm::compute(
+ return tvm::top::compute(
x->shape,
[&](const Array<Var>& inds) {
Array<PrimExpr> pad_inds {inds.begin(), inds.end()};
} else if (pool_type == kAvgPool) {
auto windowh = tvm::reduce_axis(Range(0, (kernel_height + stride_height - 1) / stride_height));
auto windoww = tvm::reduce_axis(Range(0, (kernel_width + stride_width - 1) / stride_width));
- return tvm::compute(
+ return tvm::top::compute(
x->shape,
[&](const Array<Var>& inds) {
PrimExpr pad_h_idx = inds[height_axis] + pad_top;
out_shape.Set(width_axis, out_width);
if (pool_type == kMaxPool) {
- return tvm::compute(out_shape, [&](const Array<Var>& output) {
+ return tvm::top::compute(out_shape, [&](const Array<Var>& output) {
Array<PrimExpr> indices;
for (const Var& var : output) indices.push_back(var);
auto i_start_h = start_index(output[height_axis], out_height, height);
return tvm::max(x(indices), { dheight, dwidth }); // NOLINT(*)
}, "tensor", "adaptive_pool_max");
} else if (pool_type == kAvgPool) {
- auto pool_sum = tvm::compute(out_shape, [&](const Array<Var>& output) {
+ auto pool_sum = tvm::top::compute(out_shape, [&](const Array<Var>& output) {
Array<PrimExpr> indices;
for (const Var& var : output) indices.push_back(var);
auto i_start_h = start_index(output[height_axis], out_height, height);
return tvm::sum(x(indices), { dheight, dwidth });
}, "tensor", "adaptive_pool_sum");
- return tvm::compute(out_shape, [&](const Array<Var>& output) {
+ return tvm::top::compute(out_shape, [&](const Array<Var>& output) {
Array<PrimExpr> indices;
for (const Var& var : output) indices.push_back(var);
auto i_start_h = start_index(output[height_axis], out_height, height);
if (pool_type == kMaxPool) {
auto temp = do_pad ? pad(
x, pad_before, pad_after, tvm::min_value(x->dtype), "pad_temp") : x;
- return tvm::compute(out_shape, [&](const Array<Var>& output) {
+ return tvm::top::compute(out_shape, [&](const Array<Var>& output) {
Array<PrimExpr> indices;
for (const Var& var : output) indices.push_back(var);
auto temp = do_pad ? pad(x, pad_before, pad_after, 0, "pad_temp") : x;
// TVM compute for summing the pooling window.
- auto pool_sum = tvm::compute(out_shape,
+ auto pool_sum = tvm::top::compute(out_shape,
[&](const Array<Var>& output) {
Array<PrimExpr> indices;
for (const Var& var : output) indices.push_back(var);
}, "tensor", "pool_sum");
// TVM compute for dividing the reduced window sum by kernel size.
- return tvm::compute(out_shape,
+ return tvm::top::compute(out_shape,
[&](const Array<Var>& output) {
Array<PrimExpr> indices;
for (const Var& var : output) indices.push_back(var);
#include "topi/reduction.h"
#include "topi/tags.h"
-#include "tvm/operation.h"
+#include "tvm/top/operation.h"
#include "tvm/expr_operator.h"
namespace topi {
namespace nn {
using namespace tvm;
+using namespace tvm::top;
/*!
* \brief Softmax activation
return exp(indices) / expsum(non_reduce_indices);
};
- auto max_elem = tvm::compute(reduced_shape, _compute_max);
- auto exp = tvm::compute(input_shape, [&](const Array<Var> &indices) {
+ auto max_elem = tvm::top::compute(reduced_shape, _compute_max);
+ auto exp = tvm::top::compute(input_shape, [&](const Array<Var> &indices) {
return _compute_exp(max_elem, indices);
});
- auto expsum = tvm::compute(reduced_shape, [&](const Array<Var> &indices) {
+ auto expsum = tvm::top::compute(reduced_shape, [&](const Array<Var> &indices) {
return _compute_expsum(exp, indices);
});
- return tvm::compute(input_shape, [&](const Array<Var> &indices) {
+ return tvm::top::compute(input_shape, [&](const Array<Var> &indices) {
return _normalize(exp, expsum, indices);
}, name, tag, attrs);
}
PrimExpr n = x->shape[1];
auto k = tvm::reduce_axis(Range(0, n), "k");
- auto max_elem = tvm::compute(
+ auto max_elem = tvm::top::compute(
{ m }, [&](Var i) {
return tvm::max(x(i, k), Array<IterVar>{ k }); });
k = tvm::reduce_axis(Range(0, n), "k");
- auto expsum = tvm::compute(
+ auto expsum = tvm::top::compute(
{ m }, [&](Var i) {
return tvm::sum(tvm::exp(x(i, k) - max_elem(i)), { k }); });
- return tvm::compute(
+ return tvm::top::compute(
x->shape, [&](Var i, Var j) {
return x(i, j) - max_elem(i) - tvm::log(expsum(i));
}, name, tag);
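For orientation, the call-site change is mechanical: Tensor, compute and placeholder now resolve under tvm::top, while reduce_axis, sum, exp and the other expression helpers stay under tvm. A minimal standalone sketch of a reduction written against the moved headers (the function name rowsum, the tag string, and the shapes are illustrative only, not part of this patch):

// Illustrative sketch only -- not part of this patch.
#include <tvm/top/operation.h>
#include <tvm/expr_operator.h>

inline tvm::top::Tensor rowsum(const tvm::top::Tensor& x) {
  using namespace tvm;
  PrimExpr m = x->shape[0];
  PrimExpr n = x->shape[1];
  // Reduction axis helpers remain in the tvm namespace.
  auto k = reduce_axis(Range(0, n), "k");
  // The compute DSL entry point now lives under tvm::top.
  return top::compute(
      { m }, [&](Var i) { return sum(x(i, k), { k }); },
      "T_rowsum", "comm_reduce");
}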
namespace topi {
namespace nn {
using namespace tvm;
+using namespace tvm::top;
using namespace topi::image;
/*!
#include "topi/transform.h"
#include "topi/detail/ravel_unravel.h"
#include "topi/detail/constant_utils.h"
-#include "tvm/operation.h"
+#include "tvm/top/operation.h"
#include "tvm/expr_operator.h"
namespace topi {
using namespace tvm;
+using namespace tvm::top;
/*! \brief The operation to use for CommReduce */
using FReduce = std::function<PrimExpr(PrimExpr source, const Array<IterVar>& axis)>;
return func(data(eval_range), r_axes);
};
- return tvm::compute(target_shape, compute, data->op->name + "_red", kCommReduce);
+ return tvm::top::compute(target_shape, compute, data->op->name + "_red", kCommReduce);
}
/*!
return func({ idx, data(eval_range) }, reduce_axes, nullptr);
};
- auto temp_idx_val = tvm::compute(target_shape, compute,
+ auto temp_idx_val = tvm::top::compute(target_shape, compute,
data->op->name + "_red_temp", kCommReduceIdx);
auto temp_idx = temp_idx_val[0];
auto temp_val = temp_idx_val[1];
- return tvm::compute(
+ return tvm::top::compute(
target_shape,
[&temp_idx](const Array<Var>& indices) { return temp_idx(indices); },
data->op->name + "_red",
#ifndef TOPI_ROCM_DENSE_H_
#define TOPI_ROCM_DENSE_H_
-#include "tvm/operation.h"
+#include "tvm/top/operation.h"
#include "tvm/build_module.h"
#include "topi/tags.h"
#include "topi/detail/array_utils.h"
namespace topi {
using namespace tvm;
+using namespace tvm::top;
namespace rocm {
/*!
*
* \return Tensor with shape [batch, out_dim]
*/
-inline tvm::Tensor dense_rocm(const Target& target,
- const tvm::Tensor& data,
- const tvm::Tensor& weight,
- const tvm::Tensor& bias,
+inline tvm::top::Tensor dense_rocm(const Target& target,
+ const tvm::top::Tensor& data,
+ const tvm::top::Tensor& weight,
+ const tvm::top::Tensor& bias,
const DataType& out_dtype) {
CHECK_EQ(data->shape.size(), 2) << "dense requires 2-D data";
CHECK_EQ(weight->shape.size(), 2) << "dense requires 2-D weight";
CHECK_EQ(data->dtype, out_dtype) << "Mixed precision not supported.";
auto mm = topi::contrib::rocblas_matmul(data, weight, false, true);
if (bias.defined()) {
- mm = tvm::compute({ batch, out_dim },
+ mm = tvm::top::compute({ batch, out_dim },
[&](Var i, Var j) {
return mm(i, j) + bias(j);
}, "tensor", kBroadcast);
#include "topi/tags.h"
#include "topi/detail/fuse.h"
-#include "tvm/operation.h"
+#include "tvm/top/operation.h"
#include "tvm/build_module.h"
#include "topi/cuda/injective.h"
namespace topi {
using namespace tvm;
+using namespace tvm::top;
namespace rocm {
*
* \param sch The schedule to update.
* \param out The tensor representing the injective op.
- *
+ *
* \return The updated schedule.
*/
inline Schedule schedule_injective_from_existing(Schedule sch, const Tensor& out) {
#ifndef TOPI_ROCM_NORMALIZATION_H_
#define TOPI_ROCM_NORMALIZATION_H_
-#include "tvm/operation.h"
+#include "tvm/top/operation.h"
#include "tvm/build_module.h"
#include "topi/tags.h"
namespace topi {
using namespace tvm;
+using namespace tvm::top;
namespace rocm {
/*!
* \brief Create a rocm schedule for LRN
#include "topi/tags.h"
#include "topi/detail/fuse.h"
#include "topi/detail/array_utils.h"
-#include "tvm/operation.h"
+#include "tvm/top/operation.h"
#include "tvm/build_module.h"
#include "topi/cuda/pooling.h"
namespace topi {
using namespace tvm;
+using namespace tvm::top;
namespace rocm {
#include "topi/tags.h"
#include "topi/detail/fuse.h"
-#include "tvm/operation.h"
+#include "tvm/top/operation.h"
#include "tvm/build_module.h"
#include "topi/cuda/reduction.h"
namespace topi {
using namespace tvm;
+using namespace tvm::top;
namespace rocm {
/*!
#include "topi/tags.h"
#include "topi/detail/fuse.h"
-#include "tvm/operation.h"
+#include "tvm/top/operation.h"
#include "tvm/build_module.h"
#include "topi/cuda/softmax.h"
namespace topi {
using namespace tvm;
+using namespace tvm::top;
namespace rocm {
#include "topi/detail/ravel_unravel.h"
#include "topi/detail/constant_utils.h"
#include "topi/detail/tensor_utils.h"
-#include "tvm/operation.h"
+#include "tvm/top/operation.h"
#include "tvm/expr_operator.h"
#include "tvm/data_layout.h"
namespace topi {
using namespace tvm;
+using namespace tvm::top;
using namespace topi::detail;
/*!
*
* \return A Tensor whose op member is the matmul operation
*/
-inline tvm::Tensor matmul(const tvm::Tensor& A,
- const tvm::Tensor& B,
+inline tvm::top::Tensor matmul(const tvm::top::Tensor& A,
+ const tvm::top::Tensor& B,
bool trans_a = false,
bool trans_b = false,
std::string name = "T_matmul",
return tvm::sum((trans_a ? A[k][i] : A[i][k]) * (trans_b ? B[j][k] : B[k][j]),
{k});
};
- return tvm::compute(output_shape, l, name, tag);
+ return tvm::top::compute(output_shape, l, name, tag);
}
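A usage sketch of the updated matmul signature, assuming placeholder moves along with operation.h into tvm::top (shapes and names are illustrative, not from this patch):

// Illustrative sketch only -- not part of this patch.
inline tvm::top::Tensor example_matmul() {
  using tvm::DataType;
  // Placeholders are now spelled under tvm::top as well.
  auto A = tvm::top::placeholder({16, 32}, DataType::Float(32), "A");
  auto B = tvm::top::placeholder({32, 64}, DataType::Float(32), "B");
  return topi::matmul(A, B);  // result shape: [16, 64]
}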
/*!
* \return A Tensor computing the result
*/
inline Tensor tensordot(const Tensor& A,
- const tvm::Tensor& B,
+ const tvm::top::Tensor& B,
int axes = 2,
std::string name = "T_tensordot",
std::string tag = kMatMul) {
* \return A Tensor computing the result
*/
inline Tensor tensordot(const Tensor& A,
- const tvm::Tensor& B,
+ const tvm::top::Tensor& B,
Array<PrimExpr> A_axes,
Array<PrimExpr> B_axes,
std::string name = "T_tensordot",
#include "topi/reduction.h"
#include "topi/tags.h"
#include "topi/transform.h"
-#include "tvm/operation.h"
+#include "tvm/top/operation.h"
#include "tvm/expr_operator.h"
namespace topi {
namespace vision {
using namespace tvm;
+using namespace tvm::top;
/*!
* \brief Reorg operation
int w_in = GetConstInt(input_shape[3]);
int out_c = c_in / (stride * stride);
- auto out = tvm::compute(input_shape,
+ auto out = tvm::top::compute(input_shape,
[&](Var b, Var k, Var j, Var i) {
return data(b * stride * stride,
indexmod(k, out_c) * stride * stride,
#include "topi/tags.h"
#include "topi/detail/fuse.h"
-#include "tvm/operation.h"
+#include "tvm/top/operation.h"
#include "tvm/build_module.h"
namespace topi {
using namespace tvm;
+using namespace tvm::top;
namespace x86 {
/*!
#include "topi/tags.h"
#include "topi/detail/fuse.h"
-#include "tvm/operation.h"
+#include "tvm/top/operation.h"
#include "tvm/build_module.h"
namespace topi {
using namespace tvm;
+using namespace tvm::top;
namespace x86 {
/*!
auto axis = s[x]->op.as<ComputeOpNode>()->axis;
if (auto_inline) {
- tvm::schedule::AutoInlineInjective(s);
+ tvm::top::AutoInlineInjective(s);
if (axis.size() > 0) {
detail::Fuse(s[x], axis);
}
#include "topi/tags.h"
#include "topi/detail/fuse.h"
-#include "tvm/operation.h"
+#include "tvm/top/operation.h"
#include "tvm/build_module.h"
namespace topi {
using namespace tvm;
+using namespace tvm::top;
namespace x86 {
*
* \param sch The schedule to update.
* \param out The tensor representing the injective op.
- *
+ *
* \return The updated schedule.
*/
inline Schedule schedule_injective_from_existing(Schedule sch, const Tensor& out) {
out_ops.push_back(t->op);
}
auto s = create_schedule(out_ops);
- tvm::schedule::AutoInlineInjective(s);
+ tvm::top::AutoInlineInjective(s);
auto x = outs[0];
schedule_injective_from_existing(s, x);
inline bool IsTensorType(TVMArgValue arg) {
return (arg.type_code() == kTVMObjectHandle &&
static_cast<Object*>(
- arg.value().v_handle)->IsInstance<tvm::TensorNode>());
+ arg.value().v_handle)->IsInstance<tvm::top::TensorNode>());
}
bool lhs_is_tensor = IsTensorType(args[0]); \
bool rhs_is_tensor = IsTensorType(args[1]); \
if (lhs_is_tensor && rhs_is_tensor) { \
- *rv = Op(args[0].operator tvm::Tensor(), \
- args[1].operator tvm::Tensor()); \
+ *rv = Op(args[0].operator tvm::top::Tensor(), \
+ args[1].operator tvm::top::Tensor()); \
} else if (!lhs_is_tensor && rhs_is_tensor) { \
*rv = Op(args[0].operator tvm::PrimExpr(), \
- args[1].operator tvm::Tensor()); \
+ args[1].operator tvm::top::Tensor()); \
} else if (lhs_is_tensor && !rhs_is_tensor) { \
- *rv = Op(args[0].operator tvm::Tensor(), \
+ *rv = Op(args[0].operator tvm::top::Tensor(), \
args[1].operator tvm::PrimExpr()); \
} else if (!lhs_is_tensor && !rhs_is_tensor) { \
*rv = Op(args[0].operator tvm::PrimExpr(), \
/*! \brief Builder function for instantiating schedules. */
using FTVMScheduleBuilder = std::function<
- tvm::Schedule(const tvm::Target& target, const tvm::Array<tvm::Tensor>& outs)>;
+ tvm::top::Schedule(const tvm::Target& target, const tvm::Array<tvm::top::Tensor>& outs)>;
/*!
* \brief Helper function for registering generic functions matching the
/*! \brief Builder function for instantiating schedules from existing schedules. */
using FTVMScheduleFromExistingBuilder = std::function<
- tvm::Schedule(tvm::Schedule sch, const tvm::Tensor& out)>;
+ tvm::top::Schedule(tvm::top::Schedule sch, const tvm::top::Tensor& out)>;
/*!
* \brief Helper function for registering generic functions matching the
topi::cuda::schedule_injective_from_existing));
/*! \brief Builder function for instantiating dense ops. */
-using FTVMDenseOpBuilder = std::function<tvm::Tensor(const Target& target,
- const tvm::Tensor& data,
- const tvm::Tensor& weight,
- const tvm::Tensor& bias,
+using FTVMDenseOpBuilder = std::function<tvm::top::Tensor(const Target& target,
+ const tvm::top::Tensor& data,
+ const tvm::top::Tensor& weight,
+ const tvm::top::Tensor& bias,
const DataType& out_dtype)>;
/*!
TVM_REGISTER_GENERIC_FUNC(dense)
.set_default(WrapDenseOp([](const Target& target,
- const tvm::Tensor& data,
- const tvm::Tensor& weight,
- const tvm::Tensor& bias,
+ const tvm::top::Tensor& data,
+ const tvm::top::Tensor& weight,
+ const tvm::top::Tensor& bias,
const DataType& out_dtype) {
return topi::nn::dense(data, weight, bias, out_dtype);
}))
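By way of illustration, a target-specialized builder registered against the same generic function spells its tensors the same way. The chained register_func call and the "rocm" tag below are an assumed use of the existing GenericFunc registration pattern, not part of this patch:

// Illustrative sketch only -- not part of this patch. Registers the rocm dense
// implementation updated earlier in this change, assuming GenericFunc::register_func
// and WrapDenseOp behave as in the existing registrations.
TVM_REGISTER_GENERIC_FUNC(dense)
.register_func({ "rocm" }, WrapDenseOp([](const Target& target,
                                          const tvm::top::Tensor& data,
                                          const tvm::top::Tensor& weight,
                                          const tvm::top::Tensor& bias,
                                          const DataType& out_dtype) {
  return topi::rocm::dense_rocm(target, data, weight, bias, out_dtype);
}));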