[REFACTOR] top - namespace for Tensor Operation DSL (#4727)
authorTianqi Chen <tqchen@users.noreply.github.com>
Thu, 16 Jan 2020 23:23:54 +0000 (15:23 -0800)
committerGitHub <noreply@github.com>
Thu, 16 Jan 2020 23:23:54 +0000 (15:23 -0800)
* [REFACTOR] introduce top - Tensor Operation DSL.

Historically we put Tensor, Schedule and compute under the root tvm namespace.
This is no longer a good idea now that the project's scope has grown
well beyond the tensor operation DSL.

This PR introduces top -- a namespace for tensor operation
DSL concepts such as schedule, tensor and compute.
We moved the related files into the new top subfolder.

* Move relevant files into include/tvm/top and src/top
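
As a rough illustration of what the move means for downstream C++ code,
here is a minimal sketch assuming placeholder/compute keep their old
signatures under the new tvm::top namespace (the shape, names and the
doubling body below are invented for illustration):

    // Before #4727: the tensor operation DSL lived at the tvm root.
    //   #include <tvm/operation.h>
    //   tvm::Tensor A = tvm::placeholder({n});
    //
    // After #4727: the same API is reached through tvm::top.
    #include <tvm/top/operation.h>

    tvm::top::Tensor MakeDoubled(tvm::PrimExpr n) {
      using namespace tvm;
      top::Tensor A = top::placeholder({n}, DataType::Float(32), "A");
      // B(i) = A(i) * 2 -- a hypothetical elementwise compute.
      return top::compute({n}, [&](Var i) { return A(i) * 2; }, "B");
    }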

152 files changed:
CMakeLists.txt
include/tvm/arith/bound.h
include/tvm/build_module.h
include/tvm/ir_pass.h
include/tvm/lowered_func.h
include/tvm/operation.h [deleted file]
include/tvm/packed_func_ext.h
include/tvm/relay/op_attr_types.h
include/tvm/schedule.h [deleted file]
include/tvm/schedule_pass.h [deleted file]
include/tvm/tensor.h [deleted file]
include/tvm/tensor_intrin.h [deleted file]
include/tvm/top/operation.h [new file with mode: 0644]
include/tvm/top/schedule.h [new file with mode: 0644]
include/tvm/top/schedule_pass.h [new file with mode: 0644]
include/tvm/top/tensor.h [new file with mode: 0644]
include/tvm/top/tensor_intrin.h [new file with mode: 0644]
src/README.md
src/api/api_arith.cc
src/api/api_base.cc
src/api/api_lang.cc
src/api/api_pass.cc
src/api/api_schedule.cc
src/api/api_test.cc
src/arith/domain_touched.cc
src/codegen/build_module.cc
src/contrib/hybrid/codegen_hybrid.h
src/lang/tensor.cc [deleted file]
src/op/compute_op.cc [deleted file]
src/op/compute_op.h [deleted file]
src/op/cross_thread_reduction.cc [deleted file]
src/op/extern_op.cc [deleted file]
src/op/hybrid_op.cc [deleted file]
src/op/hybrid_op.h [deleted file]
src/op/op_util.cc [deleted file]
src/op/op_util.h [deleted file]
src/op/placeholder_op.cc [deleted file]
src/op/scan_op.cc [deleted file]
src/op/tensor_compute_op.cc [deleted file]
src/op/tensorize.cc [deleted file]
src/pass/inject_prefetch.cc
src/pass/storage_flatten.cc
src/pass/tensor_core.cc
src/pass/verify_compact_buffer.cc
src/relay/backend/compile_engine.cc
src/relay/backend/compile_engine.h
src/relay/backend/utils.h
src/relay/backend/vm/compiler.cc
src/relay/op/annotation/annotation.cc
src/relay/op/debug.cc
src/relay/op/memory/memory.cc
src/relay/op/nn/nn.cc
src/relay/op/nn/pad.cc
src/relay/op/nn/pooling.cc
src/relay/op/tensor/binary.cc
src/relay/op/tensor/reduce.cc
src/relay/op/tensor/transform.cc
src/relay/op/tensor/unary.cc
src/relay/op/vision/yolo.cc
src/relay/pass/alter_op_layout.cc
src/relay/pass/convert_layout.cc
src/relay/pass/gradient.cc
src/relay/pass/legalize.cc
src/schedule/auto_inline_elem_wise.cc [deleted file]
src/schedule/bound.cc [deleted file]
src/schedule/graph.cc [deleted file]
src/schedule/graph.h [deleted file]
src/schedule/message_passing.cc [deleted file]
src/schedule/message_passing.h [deleted file]
src/schedule/schedule_dataflow_rewrite.cc [deleted file]
src/schedule/schedule_lang.cc [deleted file]
src/schedule/schedule_ops.cc [deleted file]
src/top/operation/compute_op.cc [new file with mode: 0644]
src/top/operation/compute_op.h [new file with mode: 0644]
src/top/operation/cross_thread_reduction.cc [new file with mode: 0644]
src/top/operation/extern_op.cc [new file with mode: 0644]
src/top/operation/hybrid_op.cc [new file with mode: 0644]
src/top/operation/hybrid_op.h [new file with mode: 0644]
src/top/operation/op_util.cc [new file with mode: 0644]
src/top/operation/op_util.h [new file with mode: 0644]
src/top/operation/placeholder_op.cc [new file with mode: 0644]
src/top/operation/scan_op.cc [new file with mode: 0644]
src/top/operation/tensor_compute_op.cc [new file with mode: 0644]
src/top/operation/tensorize.cc [new file with mode: 0644]
src/top/schedule/auto_inline_elem_wise.cc [new file with mode: 0644]
src/top/schedule/bound.cc [new file with mode: 0644]
src/top/schedule/graph.cc [new file with mode: 0644]
src/top/schedule/graph.h [new file with mode: 0644]
src/top/schedule/message_passing.cc [new file with mode: 0644]
src/top/schedule/message_passing.h [new file with mode: 0644]
src/top/schedule/schedule_dataflow_rewrite.cc [new file with mode: 0644]
src/top/schedule/schedule_lang.cc [new file with mode: 0644]
src/top/schedule/schedule_ops.cc [new file with mode: 0644]
src/top/tensor.cc [new file with mode: 0644]
tests/cpp/build_module_test.cc
tests/cpp/expr_test.cc
tests/cpp/ir_simplify_test.cc
tests/cpp/relay_build_module_test.cc
tests/cpp/relay_pass_type_infer_test.cc
tests/cpp/relay_transform_sequential.cc
tests/cpp/simple_passes_test.cc
tests/cpp/tensor_test.cc
tests/cpp/topi_ewise_test.cc
tests/cpp/utvm_runtime_standalone_test.cc
topi/include/topi/broadcast.h
topi/include/topi/contrib/cublas.h
topi/include/topi/contrib/rocblas.h
topi/include/topi/cuda/dense.h
topi/include/topi/cuda/injective.h
topi/include/topi/cuda/normalization.h
topi/include/topi/cuda/pooling.h
topi/include/topi/cuda/reduction.h
topi/include/topi/cuda/softmax.h
topi/include/topi/detail/array_utils.h
topi/include/topi/detail/broadcast.h
topi/include/topi/detail/constant_utils.h
topi/include/topi/detail/extern.h
topi/include/topi/detail/fuse.h
topi/include/topi/detail/pad_utils.h
topi/include/topi/detail/ravel_unravel.h
topi/include/topi/detail/tensor_utils.h
topi/include/topi/elemwise.h
topi/include/topi/generic/default.h
topi/include/topi/generic/extern.h
topi/include/topi/generic/injective.h
topi/include/topi/image/resize.h
topi/include/topi/nn.h
topi/include/topi/nn/batch_matmul.h
topi/include/topi/nn/bias_add.h
topi/include/topi/nn/bnn.h
topi/include/topi/nn/dense.h
topi/include/topi/nn/dilate.h
topi/include/topi/nn/flatten.h
topi/include/topi/nn/l2_normalize.h
topi/include/topi/nn/local_response_norm.h
topi/include/topi/nn/mapping.h
topi/include/topi/nn/pooling.h
topi/include/topi/nn/softmax.h
topi/include/topi/nn/upsampling.h
topi/include/topi/reduction.h
topi/include/topi/rocm/dense.h
topi/include/topi/rocm/injective.h
topi/include/topi/rocm/normalization.h
topi/include/topi/rocm/pooling.h
topi/include/topi/rocm/reduction.h
topi/include/topi/rocm/softmax.h
topi/include/topi/transform.h
topi/include/topi/vision/reorg.h
topi/include/topi/x86/bnn.h
topi/include/topi/x86/default.h
topi/include/topi/x86/injective.h
topi/src/topi.cc

index f48e7d18d094f051629a2f91048a5e215d0fce85..a62173e13e785efa94c2be208c20e7c0770d35b7 100644 (file)
@@ -124,20 +124,24 @@ assign_source_group("Source" ${GROUP_SOURCE})
 assign_source_group("Include" ${GROUP_INCLUDE})
 
 # Source file lists
-file(GLOB COMPILER_SRCS
+file(GLOB_RECURSE COMPILER_SRCS
     src/node/*.cc
     src/ir/*.cc
     src/target/*.cc
-    src/api/*.cc
     src/arith/*.cc
+    src/top/*.cc
+    src/api/*.cc
     src/autotvm/*.cc
-    src/codegen/*.cc
     src/lang/*.cc
     src/pass/*.cc
-    src/op/*.cc
-    src/schedule/*.cc
     )
 
+file(GLOB CODEGEN_SRCS
+    src/codegen/*.cc
+    )
+
+list(APPEND COMPILER_SRCS ${CODEGEN_SRCS})
+
 file(GLOB_RECURSE RELAY_OP_SRCS
     src/relay/op/*.cc
     )
index 73c0733c41d34a5030cd88d775a9702ad1a8bf6b..e0695481614859b6ba949f91179c1028d91c9a10 100644 (file)
@@ -32,7 +32,9 @@
 
 namespace tvm {
 // forward declare Tensor
+namespace top {
 class Tensor;
+}
 namespace arith {
 
 /*!
@@ -75,7 +77,10 @@ IntSet DeduceBound(PrimExpr v, PrimExpr cond,
  * \param consider_provides If provides (write) are considered.
  * \return The domain that covers all the calls or provides within the given statement.
  */
-Domain DomainTouched(Stmt body, const Tensor &tensor, bool consider_calls, bool consider_provides);
+Domain DomainTouched(Stmt body,
+                     const top::Tensor &tensor,
+                     bool consider_calls,
+                     bool consider_provides);
 
 }  // namespace arith
 }  // namespace tvm
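
The hunk above shows the forward-declaration pattern used across the headers:
since bound.h only refers to Tensor, it declares the class inside the new
tvm::top namespace instead of including <tvm/top/tensor.h>. A condensed
sketch of the resulting header shape (paraphrased, not copied from the tree):

    namespace tvm {
    namespace top {
    class Tensor;  // forward declaration; the definition stays in tvm/top/tensor.h
    }  // namespace top
    namespace arith {
    // The signature only takes Tensor by reference, so the forward
    // declaration suffices and the header avoids a heavier include.
    Domain DomainTouched(Stmt body,
                         const top::Tensor& tensor,
                         bool consider_calls,
                         bool consider_provides);
    }  // namespace arith
    }  // namespace tvm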
index 29903786a2b3948f3dc66ffc9feb42e7f1c6b56f..4e6e51744b6c9a186a951e664c7b5ef03b4de4d2 100644 (file)
@@ -26,6 +26,8 @@
 
 #include <tvm/target/target.h>
 #include <tvm/support/with.h>
+#include <tvm/top/schedule_pass.h>
+
 #include <string>
 #include <vector>
 #include <utility>
@@ -33,7 +35,7 @@
 #include <unordered_set>
 
 #include "runtime/packed_func.h"
-#include "schedule_pass.h"
+
 #include "lowered_func.h"
 
 namespace tvm {
@@ -172,10 +174,10 @@ class BuildConfig : public ::tvm::ObjectRef {
 * \param config The build configuration.
 * \return The lowered function.
 */
-TVM_DLL Array<LoweredFunc> lower(Schedule sch,
-                                 const Array<Tensor>& args,
+TVM_DLL Array<LoweredFunc> lower(top::Schedule sch,
+                                 const Array<top::Tensor>& args,
                                  const std::string& name,
-                                 const std::unordered_map<Tensor, Buffer>& binds,
+                                 const std::unordered_map<top::Tensor, Buffer>& binds,
                                  const BuildConfig& config);
 /*!
 * \brief Split host/device function and running necessary pass before build
index 891d3245cff75a69b88b1b06a0836f0421da6b42..bf444265b0783961a77212dd4fb3e055318a9334 100644 (file)
 #ifndef TVM_IR_PASS_H_
 #define TVM_IR_PASS_H_
 
+#include <tvm/top/schedule.h>
+
 #include <unordered_map>
 #include <unordered_set>
 #include <vector>
 #include <string>
 #include "expr.h"
 #include "buffer.h"
-#include "schedule.h"
 #include "lowered_func.h"
 
 namespace tvm {
@@ -203,7 +204,7 @@ Stmt Inline(Stmt stmt,
  * \return Transformed stmt.
  */
 Stmt StorageFlatten(Stmt stmt,
-                    Map<Tensor, Buffer> extern_buffer,
+                    Map<top::Tensor, Buffer> extern_buffer,
                     int cache_line_size,
                     bool create_bound_attribute = false);
 
@@ -217,8 +218,8 @@ Stmt StorageFlatten(Stmt stmt,
  * \return Transformed stmt.
  */
 Stmt RewriteForTensorCore(Stmt stmt,
-                          Schedule schedule,
-                          Map<Tensor, Buffer> extern_buffer);
+                          top::Schedule schedule,
+                          Map<top::Tensor, Buffer> extern_buffer);
 
 /*!
  * \brief Verify if there is any argument bound to compact buffer.
index 2b643d75f08b504953693f6a52b310bb06431153..b0350ae8b4dc211eedafa7f59c7f78aab52b5cc6 100644 (file)
 #ifndef TVM_LOWERED_FUNC_H_
 #define TVM_LOWERED_FUNC_H_
 
+#include <tvm/top/tensor.h>
+
 #include <string>
 
 #include "expr.h"
-#include "tensor.h"
 #include "tvm/node/container.h"
 
 namespace tvm {
diff --git a/include/tvm/operation.h b/include/tvm/operation.h
deleted file mode 100644 (file)
index 54c98a3..0000000
+++ /dev/null
@@ -1,659 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file tvm/operation.h
- * \brief Operation node can generate one or multiple Tensors
- */
-#ifndef TVM_OPERATION_H_
-#define TVM_OPERATION_H_
-
-#include <tvm/arith/analyzer.h>
-
-#include <string>
-#include <vector>
-#include <unordered_map>
-
-#include "expr.h"
-#include "expr_operator.h"
-#include "tensor.h"
-#include "schedule.h"
-#include "buffer.h"
-
-namespace tvm {
-
-using arith::IntSet;
-
-/*!
- * \brief Temporary data structure to store union
- *  of bounds of each axis of Tensor.
- */
-struct TensorDom {
-  // constructor
-  explicit TensorDom(int ndim)
-      : data(ndim) {}
-  /*! \brief The domain data */
-  std::vector<std::vector<IntSet> > data;
-};
-
-/*!
- * \brief Base class of all operation nodes
- */
-class OperationNode : public ir::FunctionBaseNode {
- public:
-  /*! \brief optional name of the operation */
-  std::string name;
-  /*! \brief optional tag of the operation */
-  std::string tag;
-  /*! \brief additional attributes of the operation*/
-  Map<std::string, ObjectRef> attrs;
-  /*! \return name of the operation */
-  const std::string& func_name() const final {
-    return name;
-  }
-  /*!
-   * \return The list of iteration variables at root.
-   * \note root_iter_vars decides the shape of the outputs.
-   */
-  virtual Array<IterVar> root_iter_vars() const = 0;
-  /*!
-   * \brief Get data type of i-th output tensor.
-   * \param i The output index.
-   * \return type of i-th output.
-   */
-  virtual DataType output_dtype(size_t i) const = 0;
-  /*!
-   * \brief Get shape of i-th output tensor.
-   * \param i The output index.
-   * \return shape of i-th output.
-   */
-  virtual Array<PrimExpr> output_shape(size_t i) const = 0;
-  /*!
-   * \brief List all the input Tensors.
-   * \return List of input tensors.
-   */
-  virtual Array<Tensor> InputTensors() const = 0;
-  /*!
-   * \brief Replace the input of the operation by pattern specified by rmap.
-   *
-   * \param self The reference to self.
-   * \param rmap The replacement map.
-   * \return self if nothing is replaced, otherwise return replaced op.
-   */
-  virtual Operation ReplaceInputs(
-      const Operation& self,
-      const std::unordered_map<Tensor, Tensor>& rmap) const = 0;
-  /*!
-   * \brief Propagate the bounds to inputs
-   * \param self The reference to self.
-   * \param analyzer The analyzer to be used in the function.
-   * \param dom_map The domain map of Variables (corresponding to root_iter_vars).
-   * \param out_dom_map The output domain.
-   *  The function is only asked to fill the bounds for Tensors that
-   *  are already in out_dom_map.
-   */
-  virtual void PropBoundToInputs(
-      const Operation& self,
-      arith::Analyzer* analyzer,
-      const std::unordered_map<const VarNode*, IntSet>& dom_map,
-      std::unordered_map<Tensor, TensorDom>* out_dom_map) const = 0;
-  /*!
-   * \brief Gather the bound from output tensor.
-   *  Set the range of each root_iter_vars in the op to out_dom_map
-   *
-   * \param self The reference to self.
-   * \param tensor_dom Domain map of Tensor->access set of each dimension.
-   * \param out_dom_map The output domain map of each IterVar to be set.
-   */
-  virtual void GatherBound(
-      const Operation& self,
-      const std::unordered_map<Tensor, TensorDom>& tensor_dom,
-      std::unordered_map<IterVar, Range>* out_dom_map) const = 0;
-  /*!
-   * \brief Build the Realize statement that realizes
-   *   the op's output tensors.
-   * \param stage the op's stage.
-   * \param realize_map The realization domain map of the operators.
-   * \param body The body that is going to be wrapped.
-   * \return A realization statement that wraps body.
-   */
-  virtual Stmt BuildRealize(
-      const Stage& stage,
-      const std::unordered_map<IterVar, Range>& realize_map,
-      const Stmt& body) const = 0;
-  /*!
-   * \brief Build the statement that provide the output tensors.
-   * \param stage The schedule stage of the op.
-   * \param dom_map The domain map of all iteration domains.
-   * \param debug_keep_trivial_loop Whether to keep trivial loops with extent of 1.
-   * \return A statement that adds production and wraps the consumer.
-   */
-  virtual Stmt BuildProvide(
-      const Stage& stage,
-      const std::unordered_map<IterVar, Range>& dom_map,
-      bool debug_keep_trivial_loop) const = 0;
-
-  static constexpr const char* _type_key = "Operation";
-
-  TVM_DECLARE_BASE_OBJECT_INFO(OperationNode, Object);
-};
-
-/*!
- * \brief A placeholder op represents an input placeholder.
- */
-class PlaceholderOpNode : public OperationNode {
- public:
-  /*! \brief The shape of the input */
-  Array<PrimExpr> shape;
-  /*! \brief The data type of the input. */
-  DataType dtype;
-  // override behavior.
-  int num_outputs() const final;
-  Array<IterVar> root_iter_vars() const final;
-  DataType output_dtype(size_t i) const final;
-  Array<PrimExpr> output_shape(size_t i) const final;
-  Array<Tensor> InputTensors() const final;
-  Operation ReplaceInputs(
-      const Operation& self,
-      const std::unordered_map<Tensor, Tensor>& rmap) const final;
-  void PropBoundToInputs(
-      const Operation& self,
-      arith::Analyzer* analyzer,
-      const std::unordered_map<const VarNode*, IntSet>& dom_map,
-      std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
-  void GatherBound(
-      const Operation& self,
-      const std::unordered_map<Tensor, TensorDom>& tensor_dom,
-      std::unordered_map<IterVar, Range>* out_dom_map) const final;
-  Stmt BuildRealize(
-      const Stage& stage,
-      const std::unordered_map<IterVar, Range>& realize_map,
-      const Stmt& body) const final;
-  Stmt BuildProvide(
-      const Stage& stage,
-      const std::unordered_map<IterVar, Range>& dom_map,
-      bool debug_keep_trivial_loop) const final;
-
-  void VisitAttrs(AttrVisitor* v) {
-    v->Visit("name", &name);
-    v->Visit("tag", &tag);
-    v->Visit("attrs", &attrs);
-    v->Visit("shape", &shape);
-    v->Visit("dtype", &dtype);
-  }
-  static Operation make(std::string name,
-                        Array<PrimExpr> shape,
-                        DataType dtype);
-
-  static constexpr const char* _type_key = "PlaceholderOp";
-  TVM_DECLARE_FINAL_OBJECT_INFO(PlaceholderOpNode, OperationNode);
-};
-
-/*!
- * \brief A Compute op that computes a tensor on a certain domain.
- * This is the base class for ComputeOp (operating on a scalar at a time) and
- * TensorComputeOp (operating on a TensorSlice at a time)
- */
-class TVM_DLL BaseComputeOpNode : public OperationNode {
- public:
-  /*! \brief IterVar on each axis */
-  Array<IterVar> axis;
-  /*! \brief IterVar on each reduction axis, if the body is a Reduce */
-  Array<IterVar> reduce_axis;
-  // override functions
-  Array<IterVar> root_iter_vars() const final;
-  Array<PrimExpr> output_shape(size_t idx) const final;
-  void GatherBound(
-          const Operation& self,
-          const std::unordered_map<Tensor, TensorDom>& tensor_dom,
-          std::unordered_map<IterVar, Range>* out_dom_map) const final;
-  Stmt BuildRealize(
-          const Stage& stage,
-          const std::unordered_map<IterVar, Range>& realize_map,
-          const Stmt& body) const final;
-  virtual size_t num_schedulable_dims() const = 0;
-
-  static constexpr const char* _type_key = "BaseComputeOp";
-  TVM_DECLARE_BASE_OBJECT_INFO(BaseComputeOpNode, OperationNode);
-};
-
-
-/*!
- * \brief A Compute op that computes a tensor on a certain domain.
- */
-class TVM_DLL ComputeOpNode : public BaseComputeOpNode {
- public:
-  /*! \brief the compute expression */
-  Array<PrimExpr> body;
-  /*! \brief constructor */
-  ComputeOpNode() {}
-  // override functions
-  int num_outputs() const final;
-  DataType output_dtype(size_t i) const final;
-  Array<Tensor> InputTensors() const final;
-  Operation ReplaceInputs(
-      const Operation& self,
-      const std::unordered_map<Tensor, Tensor>& rmap) const final;
-  void PropBoundToInputs(
-      const Operation& self,
-      arith::Analyzer* analyzer,
-      const std::unordered_map<const VarNode*, IntSet>& dom_map,
-      std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
-  Stmt BuildProvide(
-      const Stage& stage,
-      const std::unordered_map<IterVar, Range>& dom_map,
-      bool debug_keep_trivial_loop) const final;
-  size_t num_schedulable_dims() const final;
-
-  void VisitAttrs(AttrVisitor* v) {
-    v->Visit("name", &name);
-    v->Visit("tag", &tag);
-    v->Visit("attrs", &attrs);
-    v->Visit("axis", &axis);
-    v->Visit("reduce_axis", &reduce_axis);
-    v->Visit("body", &body);
-  }
-  static Operation make(std::string name,
-                        std::string tag,
-                        Map<std::string, ObjectRef> attrs,
-                        Array<IterVar> axis,
-                        Array<PrimExpr> body);
-
-  static constexpr const char* _type_key = "ComputeOp";
-  TVM_DECLARE_FINAL_OBJECT_INFO(ComputeOpNode, BaseComputeOpNode);
-};
-
-/*!
- * \brief A TensorCompute op that computes a tensor with a tensor intrinsic.
- */
-class TensorComputeOpNode : public BaseComputeOpNode {
- public:
-  /*! \brief number of axes that can be scheduled */
-  int schedulable_ndim;
-  /*! \brief TensorIntrin used to compute */
-  TensorIntrin intrin;
-  /*! \brief input tensors of intrin */
-  Array<Tensor> inputs;
-  /*! \brief region of input tensors */
-  Array<Region> input_regions;
-  /*! \brief scalar expression inputs */
-  Array<PrimExpr> scalar_inputs;
-  /*! \brief constructor */
-  TensorComputeOpNode() {}
-  // override functions
-  int num_outputs() const final;
-  DataType output_dtype(size_t i) const final;
-  Array<Tensor> InputTensors() const final;
-  Operation ReplaceInputs(
-      const Operation& self,
-      const std::unordered_map<Tensor, Tensor>& rmap) const final;
-  void PropBoundToInputs(
-      const Operation& self,
-      arith::Analyzer* analyzer,
-      const std::unordered_map<const VarNode*, IntSet>& dom_map,
-      std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
-  Stmt BuildProvide(
-      const Stage& stage,
-      const std::unordered_map<IterVar, Range>& dom_map,
-      bool debug_keep_trivial_loop) const final;
-  size_t num_schedulable_dims() const final;
-
-  void VisitAttrs(AttrVisitor* v) {
-    v->Visit("name", &name);
-    v->Visit("tag", &tag);
-    v->Visit("axis", &axis);
-    v->Visit("reduce_axis", &reduce_axis);
-    v->Visit("schedulable_ndim", &schedulable_ndim);
-    v->Visit("intrin", &intrin);
-    v->Visit("inputs", &inputs);
-    v->Visit("input_regions", &input_regions);
-    v->Visit("scalar_inputs", &scalar_inputs);
-  }
-  static Operation make(std::string name,
-                        std::string tag,
-                        Array<IterVar> axis,
-                        Array<IterVar> reduce_axis,
-                        int schedulable_ndim,
-                        TensorIntrin intrin,
-                        Array<Tensor> tensors,
-                        Array<Region> regions,
-                        Array<PrimExpr> scalar_inputs);
-
-  static constexpr const char* _type_key = "TensorComputeOp";
-  TVM_DECLARE_FINAL_OBJECT_INFO(TensorComputeOpNode, BaseComputeOpNode);
-};
-
-/*!
- * \brief Symbolic scan.
- */
-class ScanOpNode : public OperationNode {
- public:
-  /*! \brief IterVar to scan over */
-  IterVar scan_axis;
-  /*! \brief the initialization tensors */
-  Array<Tensor> init;
-  /*! \brief the update function represented by tensor */
-  Array<Tensor> update;
-  /*! \brief The placeholder to refer as states in update. */
-  Array<Tensor> state_placeholder;
-  /*!
-   * \brief The inputs to the scan. These are optionally provided,
-   *  but they can be helpful hints to speed up analysis of the scan body.
-   */
-  Array<Tensor> inputs;
-  /*!
-   * \brief Spatial axis to indicate spatial dimension of each output.
-   *  They correspond to the flattened spatial axes of the outputs.
-   *
-   *  [output[0].axis[1], output[0].axis[2]... output[k].axis[j]...]
-   *  These are auxiliary data structure for storing result of bound inference.
-   *  They do not correspond to splittable iterations, thus the name comes
-   *  with an underscore.
-   */
-  Array<IterVar> spatial_axis_;
-  /*! \brief constructor */
-  ScanOpNode() {}
-  // override behavior.
-  int num_outputs() const final;
-  Array<IterVar> root_iter_vars() const final;
-  DataType output_dtype(size_t i) const final;
-  Array<PrimExpr> output_shape(size_t i) const final;
-  Array<Tensor> InputTensors() const final;
-  Operation ReplaceInputs(
-      const Operation& self,
-      const std::unordered_map<Tensor, Tensor>& rmap) const final;
-  void PropBoundToInputs(
-      const Operation& self,
-      arith::Analyzer* analyzer,
-      const std::unordered_map<const VarNode*, IntSet>& dom_map,
-      std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
-  void GatherBound(
-      const Operation& self,
-      const std::unordered_map<Tensor, TensorDom>& tensor_dom,
-      std::unordered_map<IterVar, Range>* out_dom_map) const final;
-  Stmt BuildRealize(
-      const Stage& stage,
-      const std::unordered_map<IterVar, Range>& realize_map,
-      const Stmt& body) const final;
-  Stmt BuildProvide(
-      const Stage& stage,
-      const std::unordered_map<IterVar, Range>& dom_map,
-      bool debug_keep_trivial_loop) const final;
-
-  void VisitAttrs(AttrVisitor* v) {
-    v->Visit("name", &name);
-    v->Visit("tag", &tag);
-    v->Visit("attrs", &attrs);
-    v->Visit("scan_axis", &scan_axis);
-    v->Visit("init", &init);
-    v->Visit("update", &update);
-    v->Visit("state_placeholder", &state_placeholder);
-    v->Visit("inputs", &inputs);
-    v->Visit("spatial_axis_", &spatial_axis_);
-  }
-  static Operation make(std::string name,
-                        std::string tag,
-                        Map<std::string, ObjectRef> attrs,
-                        IterVar axis,
-                        Array<Tensor> init,
-                        Array<Tensor> update,
-                        Array<Tensor> state_placeholder,
-                        Array<Tensor> input);
-
-  static constexpr const char* _type_key = "ScanOp";
-  TVM_DECLARE_FINAL_OBJECT_INFO(ScanOpNode, OperationNode);
-};
-
-/*!
- * \brief External computation that cannot be split.
- */
-class ExternOpNode : public OperationNode {
- public:
-  /*! \brief The input tensors */
-  Array<Tensor> inputs;
-  /*! \brief Symbolic placeholder representation of inputs */
-  Array<Buffer> input_placeholders;
-  /*! \brief Symbolic placeholder representation of outputs */
-  Array<Buffer> output_placeholders;
-  /*! \brief the statement that generates the computation. */
-  Stmt body;
-
-  /*! \brief constructor */
-  ExternOpNode() {}
-  // override functions
-  int num_outputs() const final;
-  Array<IterVar> root_iter_vars() const final;
-  DataType output_dtype(size_t i) const final;
-  Array<PrimExpr> output_shape(size_t i) const final;
-  Array<Tensor> InputTensors() const final;
-  Operation ReplaceInputs(
-      const Operation& self,
-      const std::unordered_map<Tensor, Tensor>& rmap) const final;
-  void PropBoundToInputs(
-      const Operation& self,
-      arith::Analyzer* analyzer,
-      const std::unordered_map<const VarNode*, IntSet>& dom_map,
-      std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
-  void GatherBound(
-      const Operation& self,
-      const std::unordered_map<Tensor, TensorDom>& tensor_dom,
-      std::unordered_map<IterVar, Range>* out_dom_map) const final;
-  Stmt BuildRealize(
-      const Stage& stage,
-      const std::unordered_map<IterVar, Range>& realize_map,
-      const Stmt& body) const final;
-  Stmt BuildProvide(
-      const Stage& stage,
-      const std::unordered_map<IterVar, Range>& dom_map,
-      bool debug_keep_trivial_loop) const final;
-
-  void VisitAttrs(AttrVisitor* v) {
-    v->Visit("name", &name);
-    v->Visit("tag", &tag);
-    v->Visit("attrs", &attrs);
-    v->Visit("inputs", &inputs);
-    v->Visit("input_placeholders", &input_placeholders);
-    v->Visit("output_placeholders", &output_placeholders);
-    v->Visit("body", &body);
-  }
-  TVM_DLL static Operation make(std::string name,
-                               std::string tag,
-                               Map<std::string, ObjectRef> attrs,
-                               Array<Tensor> inputs,
-                               Array<Buffer> input_placeholders,
-                               Array<Buffer> output_placeholders,
-                               Stmt body);
-
-  static constexpr const char* _type_key = "ExternOp";
-  TVM_DECLARE_FINAL_OBJECT_INFO(ExternOpNode, OperationNode);
-};
-
-/*!
- * \brief A computation operator that generated by hybrid script.
- */
-class HybridOpNode : public OperationNode {
- public:
-  /*! \brief The input tensors */
-  Array<Tensor> inputs;
-  /*! \brief Symbolic placeholder representation of outputs */
-  Array<Tensor> outputs;
-  /*! \brief The axis of iterations */
-  Array<IterVar> axis;
-  /*! \brief The statement that generates the computation. This is
-   * slightly different from the body in ExternOpNode. All the output
-   * tensors keep the names specified by users in the script.
-   * However, during compilation, these tensors will be replaced by the
-   * actual output tensors. */
-  Stmt body;
-
-  /*! \brief constructor */
-  HybridOpNode() {}
-  // override functions
-  int num_outputs() const final;
-  Array<IterVar> root_iter_vars() const final;
-  DataType output_dtype(size_t i) const final;
-  Array<PrimExpr> output_shape(size_t i) const final;
-  Array<Tensor> InputTensors() const final;
-  Operation ReplaceInputs(
-      const Operation& self,
-      const std::unordered_map<Tensor, Tensor>& rmap) const final;
-  void PropBoundToInputs(
-      const Operation& self,
-      arith::Analyzer* analyzer,
-      const std::unordered_map<const VarNode*, IntSet>& dom_map,
-      std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
-  void GatherBound(
-      const Operation& self,
-      const std::unordered_map<Tensor, TensorDom>& tensor_dom,
-      std::unordered_map<IterVar, Range>* out_dom_map) const final;
-  Stmt BuildRealize(
-      const Stage& stage,
-      const std::unordered_map<IterVar, Range>& realize_map,
-      const Stmt& body) const final;
-  Stmt BuildProvide(
-      const Stage& stage,
-      const std::unordered_map<IterVar, Range>& dom_map,
-      bool debug_keep_trivial_loop) const final;
-
-  void VisitAttrs(AttrVisitor* v) {
-    v->Visit("name", &name);
-    v->Visit("tag", &tag);
-    v->Visit("attrs", &attrs);
-    v->Visit("inputs", &inputs);
-    v->Visit("outputs", &outputs);
-    v->Visit("axis", &axis);
-    v->Visit("body", &body);
-  }
-  TVM_DLL static Operation make(std::string name,
-                                std::string tag,
-                                Map<std::string, ObjectRef> attrs,
-                                Array<Tensor> inputs,
-                                Array<Tensor> outputs,
-                                Stmt body);
-
-  static constexpr const char* _type_key = "HybridOp";
-  TVM_DECLARE_FINAL_OBJECT_INFO(HybridOpNode, OperationNode);
-};
-
-/*! \brief The compute function to specify the input source of a Tensor */
-using FCompute = std::function<PrimExpr (const Array<Var>& i)>;
-
-/*! \brief The compute function to specify the input sources of Tensors */
-using FBatchCompute = std::function<Array<PrimExpr> (const Array<Var>& i)>;
-
-/*!
- * \brief create a placeholder tensor.
- * \param shape The shape of the tensor.
- * \param dtype the data type of the tensor.
- * \param name The name of the Tensor.
- */
-TVM_DLL Tensor placeholder(Array<PrimExpr> shape,
-                           DataType dtype = DataType::Float(32),
-                           std::string name = "placeholder");
-
-/*!
- * \brief Construct a new tensor by computing over shape,
- *  using the computation rule: result_tensor[axis] = fcompute(axis)
- * \param shape Shape of the tensor.
- * \param fcompute The compute function to create the tensor.
- * \param name The optional name of the tensor.
- * \param tag The optional tag of the tensor.
- * \param attrs Optional additional attributes of the compute.
- */
-TVM_DLL Tensor compute(Array<PrimExpr> shape,
-                       FCompute fcompute,
-                       std::string name = "tensor",
-                       std::string tag = "",
-                       Map<std::string, ObjectRef> attrs = {});
-
-/*!
- * \brief Construct a new tensor by computing over shape,
- *  using the computation rule: result_tensor[axis] = fcompute(axis)
- * \param shape Shape of the tensor.
- * \param fcompute The compute function to create the tensors.
- * \param name The optional name of the tensor.
- * \param tag The optional tag of the tensor.
- * \param attrs Optional additional attributes of the compute.
- */
-TVM_DLL Array<Tensor> compute(Array<PrimExpr> shape,
-                              FBatchCompute fcompute,
-                              std::string name = "tensor",
-                              std::string tag = "",
-                              Map<std::string, ObjectRef> attrs = {});
-
-/*!
- * \brief Construct new tensors by scan.
- *
- * \param init The initialization tensors for the first K steps.
- * \param update The update tensors indicating the updated result after each time step.
- * \param state_placeholder The placeholder for the states.
- * \param inputs The inputs to the scan body, this is optional,
- *    but recommended to provide concrete information about scan body.
- * \param name The optional name of the tensor.
- * \param tag The optional tag of the tensor.
- * \param attrs Optional additional attributes of the compute.
- */
-TVM_DLL Array<Tensor> scan(Array<Tensor> init,
-                           Array<Tensor> update,
-                           Array<Tensor> state_placeholder,
-                           Array<Tensor> inputs = Array<Tensor>(),
-                           std::string name = "scan",
-                           std::string tag = "",
-                           Map<std::string, ObjectRef> attrs = {});
-
-// same as compute, specialized for different fcompute function
-inline Tensor compute(Array<PrimExpr> shape,
-                      std::function<PrimExpr(Var)> f,
-                      std::string name = "tensor",
-                      std::string tag = "",
-                      Map<std::string, ObjectRef> attrs = {}) {
-  FCompute fc = [f] (const Array<Var>& i) { return f(i[0]); };
-  return compute(shape, fc, name, tag, attrs);
-}
-inline Tensor compute(Array<PrimExpr> shape,
-                      std::function<PrimExpr(Var, Var)> f,
-                      std::string name = "tensor",
-                      std::string tag = "",
-                      Map<std::string, ObjectRef> attrs = {}) {
-  FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1]); };
-  return compute(shape, fc, name, tag, attrs);
-}
-inline Tensor compute(Array<PrimExpr> shape,
-                      std::function<PrimExpr(Var, Var, Var)> f,
-                      std::string name = "tensor",
-                      std::string tag = "",
-                      Map<std::string, ObjectRef> attrs = {}) {
-  FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1], i[2]); };
-  return compute(shape, fc, name, tag, attrs);
-}
-inline Tensor compute(Array<PrimExpr> shape,
-                      std::function<PrimExpr(Var, Var, Var, Var)> f,
-                      std::string name = "tensor",
-                      std::string tag = "",
-                      Map<std::string, ObjectRef> attrs = {}) {
-  FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1], i[2], i[3]); };
-  return compute(shape, fc, name, tag, attrs);
-}
-
-// inline function.
-inline const OperationNode* Operation::operator->() const {
-  return static_cast<const OperationNode*>(get());
-}
-}  // namespace tvm
-#endif  // TVM_OPERATION_H_
index cc380cdbd8aa7e3f435f95c6f7975a069e4a8567..f7b0d0820980ca97d3b12a7dc8ca446973cc76eb 100644 (file)
 #ifndef TVM_PACKED_FUNC_EXT_H_
 #define TVM_PACKED_FUNC_EXT_H_
 
+#include <tvm/top/tensor.h>
+
 #include <string>
 #include <memory>
 #include <limits>
 #include <type_traits>
 
 #include "expr.h"
-#include "tensor.h"
 #include "runtime/packed_func.h"
 
 namespace tvm {
@@ -116,8 +117,8 @@ inline TVMPODValue_::operator tvm::PrimExpr() const {
   if (ptr->IsInstance<IterVarNode>()) {
     return IterVar(ObjectPtr<Object>(ptr))->var;
   }
-  if (ptr->IsInstance<TensorNode>()) {
-    return Tensor(ObjectPtr<Object>(ptr))();
+  if (ptr->IsInstance<top::TensorNode>()) {
+    return top::Tensor(ObjectPtr<Object>(ptr))();
   }
   CHECK(ObjectTypeChecker<PrimExpr>::Check(ptr))
       << "Expect type " << ObjectTypeChecker<PrimExpr>::TypeName()
index b6221e0ba8a5b2be33c9f7f603d80833b06b4811..385d6453fae9efa89b65fa4e0102f997833eed14 100644 (file)
@@ -24,8 +24,8 @@
 #ifndef TVM_RELAY_OP_ATTR_TYPES_H_
 #define TVM_RELAY_OP_ATTR_TYPES_H_
 
-#include <tvm/tensor.h>
-#include <tvm/schedule.h>
+#include <tvm/top/tensor.h>
+#include <tvm/top/schedule.h>
 #include <tvm/build_module.h>
 #include <tvm/relay/type.h>
 #include <tvm/relay/expr.h>
@@ -99,10 +99,10 @@ using TShapeDataDependant = bool;
  * \return The output compute description of the operator.
  */
 using FTVMCompute = runtime::TypedPackedFunc<
-  Array<Tensor>(const Attrs& attrs,
-                const Array<Tensor>& inputs,
-                const Type& out_type,
-                const Target& target)>;
+  Array<top::Tensor>(const Attrs& attrs,
+                     const Array<top::Tensor>& inputs,
+                     const Type& out_type,
+                     const Target& target)>;
 
 /*!
  * \brief Build the computation schedule for
@@ -114,9 +114,9 @@ using FTVMCompute = runtime::TypedPackedFunc<
  * \return schedule The computation schedule.
  */
 using FTVMSchedule = runtime::TypedPackedFunc<
-  Schedule(const Attrs& attrs,
-           const Array<Tensor>& outs,
-           const Target& target)>;
+  top::Schedule(const Attrs& attrs,
+                const Array<top::Tensor>& outs,
+                const Target& target)>;
 
 /*!
  * \brief Alternate the layout of operators or replace the
@@ -131,7 +131,7 @@ using FTVMSchedule = runtime::TypedPackedFunc<
 using FTVMAlterOpLayout = runtime::TypedPackedFunc<
   Expr(const Attrs& attrs,
        const Array<Expr>& args,
-       const Array<Tensor>& tinfos)>;
+       const Array<top::Tensor>& tinfos)>;
 
 /*!
  * \brief Convert the layout of operators or replace the
@@ -147,7 +147,7 @@ using FTVMAlterOpLayout = runtime::TypedPackedFunc<
 using FTVMConvertOpLayout = runtime::TypedPackedFunc<
   Expr(const Attrs& attrs,
        const Array<Expr>& args,
-       const Array<Tensor>& tinfos,
+       const Array<top::Tensor>& tinfos,
        const std::string& desired_layout)>;
 /*!
  * \brief Legalizes an expression with another expression. This function will be
@@ -206,9 +206,9 @@ enum AnyCodegenStrategy {
 using Shape = Array<IndexExpr>;
 
 using FShapeFunc = runtime::TypedPackedFunc<
-  Array<Tensor>(const Attrs& attrs,
-                const Array<Tensor>& inputs,
-                const Array<IndexExpr>& out_ndims)>;
+  Array<top::Tensor>(const Attrs& attrs,
+                     const Array<top::Tensor>& inputs,
+                     const Array<IndexExpr>& out_ndims)>;
 
 }  // namespace relay
 }  // namespace tvm
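
For a sense of how the retyped hooks are consumed, a hedged sketch of a relay
op registration against the new FTVMCompute signature (the op name and the
pass-through body are invented; the set_attr pattern follows the usual relay
C++ registration style):

    // Hypothetical registration using the updated FTVMCompute alias.
    RELAY_REGISTER_OP("my.identity")
    .set_attr<FTVMCompute>("FTVMCompute",
      [](const Attrs& attrs,
         const Array<top::Tensor>& inputs,
         const Type& out_type,
         const Target& target) -> Array<top::Tensor> {
        // Pass the inputs through unchanged -- placeholder compute logic.
        return inputs;
      });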
diff --git a/include/tvm/schedule.h b/include/tvm/schedule.h
deleted file mode 100644 (file)
index 3115b0a..0000000
+++ /dev/null
@@ -1,768 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file tvm/schedule.h
- * \brief Define a schedule.
- */
-// Acknowledgement: Many schedule primitives originate from Halide and Loopy.
-#ifndef TVM_SCHEDULE_H_
-#define TVM_SCHEDULE_H_
-
-#include <string>
-#include <unordered_map>
-#include "expr.h"
-#include "tensor.h"
-#include "tensor_intrin.h"
-
-namespace tvm {
-
-// Node container for Stage
-class StageNode;
-// Node container for Schedule
-class ScheduleNode;
-// Node container for IterVarRelation
-class IterVarRelationNode;
-// Attribute of itervar.
-class IterVarAttrNode;
-
-/*! \brief the attachment type */
-enum AttachType : int {
-  kGroupRoot = 1,
-  kInline = 2,
-  kInlinedAlready = 3,
-  kScope = 4,
-  kScanUpdate = 5
-};
-
-/*! \brief Stage, contains scheduling for a stage of computation. */
-class Stage : public ObjectRef {
- public:
-  Stage() {}
-  explicit Stage(ObjectPtr<Object> n) : ObjectRef(n) {}
-  /*!
-   * \brief create a new stage for op.
-   * \param op The operator in the schedule
-   */
-  explicit Stage(Operation op);
-  /*!
-   * \brief access the internal node container
-   * \return the pointer to the internal node container
-   */
-  inline const StageNode* operator->() const;
-  /*!
-   * \brief access the internal node container
-   * \return the pointer to the internal node container
-   */
-  inline StageNode* operator->();
-  /*!
-   * \brief set the memory scope of the stage
-   * \param scope The memory scope.
-   */
-  TVM_DLL Stage& set_scope(std::string scope);  // NOLINT(*)
-  /*!
-   * \brief specify the schedule to be computed at the parent schedule's scope.
-   * \param parent The parent schedule.
-   * \param scope The iteration point to carry the schedule.
-   * \return reference to self.
-   */
-  TVM_DLL Stage& compute_at(Stage parent, IterVar scope);   // NOLINT(*)
-  /*!
-   * \brief Compute the function inline.
-   * \return reference to self.
-   */
-  TVM_DLL Stage& compute_inline();   // NOLINT(*)
-  /*!
-   * \brief Compute the function at group root.
-   * \return reference to self.
-   */
-  TVM_DLL Stage& compute_root();  // NOLINT(*)
-  /*!
-   * \brief Bind the IterVar to thread index.
-   *
-   * \param ivar The IterVar to be bound.
-   * \param thread_ivar The thread axis to be bound.
-   * \return reference to self.
-   */
-  TVM_DLL Stage& bind(IterVar ivar, IterVar thread_ivar);
-  /*!
-   * \brief Set the predicate to determine whether a store to the array should be performed.
-   *  Use this when there are multiple threads performing the same store and we only
-   *  need one of them to do the store.
-   *
-   * \note This is a dangerous scheduling primitive that can change the behavior of the program.
-   *    Only do this when we are certain that there are duplicated stores.
-   * \param predicate The condition to be checked.
-   * \return reference to self.
-   */
-  TVM_DLL Stage& set_store_predicate(PrimExpr predicate);
-  /*!
-   * \brief Specify environment threads that are launched around the group's scope.
-   *  This can only be used in a group stage.
-   * \param threads The threads to be launched around the scope.
-   * \note Each thread can only appear in one env_threads.
-   *    This is a beta feature.
-   * \return reference to self.
-   */
-  TVM_DLL Stage& env_threads(Array<IterVar> threads);
-  /*!
-   * \brief Split the parent domain by factor, generating an outer and an inner domain.
-   * \param parent The parent iteration domain.
-   * \param factor The split factor of the loop.
-   * \param p_outer The result outer domain
-   * \param p_inner The result inner domain.
-   * \return reference to self.
-   */
-  TVM_DLL Stage& split(IterVar parent, PrimExpr factor, IterVar* p_outer, IterVar* p_inner);  // NOLINT(*)
-  /*!
-   * \brief Split the iteration with given number of parts.
-   *
-   * \param parent The parent domain.
-   * \param nparts The number of parts in the outer domain.
-   * \param p_outer The result outer domain.
-   * \param p_inner The result inner domain.
-   * \return reference to self.
-   */
-  TVM_DLL Stage& split_by_nparts(IterVar parent, PrimExpr nparts, IterVar* p_outer, IterVar* p_inner);   // NOLINT(*)
-  /*!
-   * \brief Fuse the outer and inner domains into the target.
-   * \param outer The outer domain to be fused.
-   * \param inner The inner domain to be fused
-   * \param p_target The result target domain.
-   * \return reference to self.
-   */
-  TVM_DLL Stage& fuse(IterVar outer, IterVar inner, IterVar* p_target);  // NOLINT(*)
-  /*!
-   * \brief Fuse all the axes together into a single axis.
-   *
-   * \param axes All the axes to be fused.
-   * \param p_target The result target domain.
-   *
-   * \note axes can be an empty array;
-   *       in that case, a singleton IterVar is created and
-   *       inserted as the outermost loop.
-   *       Fusing an empty array is used to support zero-dimensional tensors.
-   *
-   * \return reference to self.
-   */
-  TVM_DLL Stage& fuse(const Array<IterVar>& axes, IterVar* p_target);  // NOLINT(*)
-  /*!
-   * \brief Reorder the iteration
-   * \param order The order of iteration variable.
-   * \return reference to self.
-   */
-  TVM_DLL Stage& reorder(const Array<IterVar>& order);   // NOLINT(*)
-  /*!
-   * \brief Perform tiling on two dimensions
-   *  The final loop order from outermost to innermost is
-   *  [x_outer, y_outer, x_inner, y_inner]
-   *
-   * \param x_parent The original x dimension
-   * \param y_parent The original y dimension
-   * \param x_factor The stride factor on x axis
-   * \param y_factor The stride factor on y axis
-   * \param p_x_outer Outer axis of x dimension
-   * \param p_y_outer Outer axis of y dimension
-   * \param p_x_inner Inner axis of x dimension
-   * \param p_y_inner Inner axis of y dimension
-   * \return reference to self.
-   */
-  TVM_DLL Stage& tile(IterVar x_parent, IterVar y_parent,   // NOLINT(*)
-                     PrimExpr x_factor, PrimExpr y_factor,
-                     IterVar* p_x_outer, IterVar* p_y_outer,
-                     IterVar* p_x_inner, IterVar* p_y_inner);
-  /*!
-   * \brief Vectorize iteration.
-   * \param var The axis to be vectorized.
-   * \return reference to self.
-   */
-  TVM_DLL Stage& vectorize(IterVar var);   // NOLINT(*)
-  /*!
-   * \brief Replace computation of the current stage by tensor intrinsic f.
-   * \param var The axis that marks the beginning of tensorization.
-   *  Every operation inside the axis (including the axis itself) is tensorized.
-   * \param f The Tensor compute intrinsics.
-   * \return reference to self.
-   */
-  TVM_DLL Stage& tensorize(IterVar var, TensorIntrin f);   // NOLINT(*)
-  /*!
-   * \brief Unroll iteration.
-   * \param var The axis to be unrolled.
-   * \return reference to self.
-   */
-  TVM_DLL Stage& unroll(IterVar var);   // NOLINT(*)
-  /*!
-   * \brief Parallelize iteration.
-   * \param var The axis to be parallelized.
-   * \return reference to self.
-   */
-  TVM_DLL Stage& parallel(IterVar var);   // NOLINT(*)
-  /*!
-   * \brief Annotate the iteration with pragma
-   *
-   * \param var The axis to be parallelized.
-   * \param pragma_type The pragma type.
-   * \param pragma_value The pragma value
-   *
-   * \return reference to self.
-   */
-  TVM_DLL Stage& pragma(IterVar var,
-                       const std::string& pragma_type,
-                       const PrimExpr& pragma_value = PrimExpr());   // NOLINT(*)
-  /*!
-   * \brief Fetch data in advance.
-   * \param domain the tensor to be prefetched
-   * \param var the iteration point at which to apply prefetching
-   * \param offset the number of iterations to be fetched in advance
-   * \return reference to self
-   */
-  TVM_DLL Stage& prefetch(const Tensor &domain, IterVar var, PrimExpr offset); //NOLINT(*)
-  /*!
-   * \brief Set alignment requirement for specific dimension.
-   *
-   *  Such that stride[axis] == k * factor + offset for some k.
-   *
-   * \param axis The dimension to be specified for alignment.
-   * \param factor The factor multiple of alignment
-   * \param offset The required offset factor.
-   * \return reference to self
-   */
-  TVM_DLL Stage& storage_align(IterVar axis, int factor, int offset); //NOLINT(*)
-  /*!
-   * \brief Compute current stage with double buffering.
-   * \return reference to self.
-   */
-  TVM_DLL Stage& double_buffer();   // NOLINT(*)
-  /*!
-   * \brief Schedule for OpenGL fragment shader.
-   * \return reference to self.
-   */
-  Stage& opengl(); // NOLINT(*)
-  /*!
-   * \brief whether the stage has been scheduled.
-   * \return whether the stage has been scheduled.
-   */
-  bool is_scheduled() const;
-  /*!
-   * \brief Get attachment spec of current stage.
-   *  If the stage is computed at group root, this function
-   *  will traverse the group to get the
-   *  final spec from the group.
-   * \return A stage representing the attach spec of the group.
-   */
-  Stage GetAttachSpec() const;
-  // declare container type
-  using ContainerType = StageNode;
-};
-
-/*!
- * \brief Global schedule container
- *  For operations and all the operations they depend on.
- *  The schedule per Operation is called a stage.
- */
-class Schedule : public ObjectRef {
- public:
-  Schedule() {}
-  explicit Schedule(ObjectPtr<Object> n) : ObjectRef(n) {}
-  /*!
-   * \brief Get a copy of current schedule.
-   * \return The copied schedule.
-   */
-  Schedule copy() const;
-  /*!
-   * \brief Get the stage corresponding to the op.
-   * \param op The operation.
-   */
-  TVM_DLL Stage operator[](const Operation& op);
-  /*!
-   * \brief Short hand for getting the stage of tensor's operation.
-   * \param tensor The tensor
-   * \return The stage corresponding to the tensor's op
-   */
-  TVM_DLL Stage operator[](const Tensor& tensor) {
-    return this->operator[](tensor->op);
-  }
-  /*!
-   * \brief Create a new stage group for all intermediate
-   *  operations between inputs and outputs.
-   *
-   * \param outputs The output boundary of the group.
-   * \param inputs The input boundary of the group.
-   * \param include_inputs Whether include inputs if they are reachable from outputs.
-   * \return The new grouped stage.
-   */
-  TVM_DLL Stage create_group(const Array<Tensor>& outputs,
-                     const Array<Tensor>& inputs,
-                     bool include_inputs = false);
-  /*!
-   * \brief create a cache read of original tensor for readers.
-   *  This will mutate the body of the readers.
-   *  A new stage will be created for the tensor.
-   * \param tensor The tensor cached.
-   * \param scope The scope of the cache.
-   * \param readers The readers to redirect to the tensor.
-   * \return The created tensor.
-   */
-  TVM_DLL Tensor cache_read(const Tensor& tensor,
-                    const std::string& scope,
-                    const Array<Operation>& readers);
-  /*!
-   * \brief Create a cache write tensor for producing tensor.
-   *  The new tensor will take over the body of the original tensor op.
-   *
-   *  This function can be used to do data layout transformation.
-   *  If there is a split/fuse/reorder on the data parallel axis of the tensor
-   *  before cache_write is called, the intermediate cache stores
-   *  the data in the layout given by the iteration order of the leaf axes.
-   *  The data will be transformed back to the original layout in the original tensor.
-   *  Users can further call compute_inline to inline the original layout and keep
-   *  the data stored in the transformed layout.
-   *
-   * \param tensor The tensors to be produced.
-   * \param scope The scope of the storage.
-   * \return The created tensor.
-   */
-  TVM_DLL Array<Tensor> cache_write(const Array<Tensor>& tensor, const std::string& scope);
-  /*!
-   * \brief Create a cache write tensor for producing tensor.
-   *  The new tensor will take over the body of the original tensor op.
-   *
-   *  This function can be used to do data layout transformation.
-   *  If there is a split/fuse/reorder on the data parallel axis of the tensor
-   *  before cache_write is called, the intermediate cache stores
-   *  the data in the layout given by the iteration order of the leaf axes.
-   *  The data will be transformed back to the original layout in the original tensor.
-   *  Users can further call compute_inline to inline the original layout and keep
-   *  the data stored in the transformed layout.
-   *
-   * \param tensor The tensor to be produced.
-   * \param scope The scope of the storage.
-   * \return The created tensor.
-   */
-  TVM_DLL Tensor cache_write(const Tensor& tensor, const std::string& scope);
-  /*!
-   * \brief Factor a reduction axis in tensor's schedule to be an explicit axis.
-   * This will create a new stage that generates the new tensor with axis
-   * as the first dimension. The tensor's body will be rewritten as a reduction
-   * over the factored tensor.
-   *
-   *  P. Suriana, A. Adams and S. Kamil. Parallel associative reductions in halide. CGO'17
-   *
-   * \param tensor The tensor to be factored.
-   * \param axis The reduction axis in tensor's schedule to be factored.
-   * \param factor_axis The position where the new axis is placed.
-   * \return The created factored tensors.
-   */
-  TVM_DLL Array<Tensor> rfactor(const Tensor& tensor,
-                        const IterVar& axis,
-                        int factor_axis = 0);
-  /*!
-   * \brief Normalize the schedule.
-   *  This is needed before bound inference.
-   *  Insert necessary RebaseNode to make sure all leaf_iter_vars
-   *  are in the form [0, extent)
-   *
-   * \return A normalized schedule, can be same as current one.
-   */
-  Schedule normalize();
-  /*!
-   * \brief access the internal node container
-   * \return the pointer to the internal node container
-   */
-  inline const ScheduleNode* operator->() const;
-  /*!
-   * \brief access the internal node container
-   * \return the pointer to the internal node container
-   */
-  inline ScheduleNode* operator->();
-  // declare container type
-  using ContainerType = ScheduleNode;
-};
-
-/*!
- * \brief The schedule relation between IterVars
- *  can be Split, Fuse.
- */
-class IterVarRelation : public ObjectRef {
- public:
-  IterVarRelation() {}
-  explicit IterVarRelation(ObjectPtr<Object> n) : ObjectRef(n) {}
-  /*!
-   * \brief access the internal node container
-   * \return the pointer to the internal node container
-   */
-  inline const IterVarRelationNode* operator->() const;
-};
-
-/*!
- * \brief Additional schedulable attributes about IterVar.
- */
-class IterVarAttr : public ObjectRef {
- public:
-  IterVarAttr() {}
-  explicit IterVarAttr(ObjectPtr<Object> n) : ObjectRef(n) {}
-  /*!
-   * \brief access the internal node container
-   * \return the pointer to the internal node container
-   */
-  inline const IterVarAttrNode* operator->() const;
-};
-
-/*!
- * \brief represents a stage.
- *
- *  Relations form a directed acyclic hypergraph in a bipartite manner.
- *  Each node is represented by an IterVar,
- *  and each hyper-edge is represented by an IterVarRelation.
- *  The relations connect the IterVars in the graph.
- *
- *  Besides the typical stages that correspond to operations,
- *  there are also group stages, which group stages together.
- *  Each stage's group (given by the group field) represents a constraint:
- *  the stage can only be attached to stages within the group.
- *
- *  The group stage node can be attached to IterVars as in a normal stage.
- */
-class StageNode : public Object {
- public:
-  /*!
-   * \brief The operation of stage, can be different from original op.
-   *  If it is null, then this stage is a group stage.
-   */
-  Operation op;
-  /*!
-   * \brief The original operator.
-   *  The op field can change during schedule to alternate the dataflow,
-   *  while origin_op remains fixed.
-   */
-  Operation origin_op;
-  /*! \brief All the nodes in the iter var */
-  Array<IterVar> all_iter_vars;
-  /*! \brief The current active leaf iter vars in the stage. */
-  Array<IterVar> leaf_iter_vars;
-  /*!
-   * \brief Specify threads to be launched at the stage.
-   *  This is only valid for composite ops such as Scan.
-   * \note Experimental primitive: used for thread persistence.
-   */
-  Array<IterVar> env_threads;
-  /*!
-   * \brief The predicate under which store can happen
-   *  Use this when there can be duplicated threads doing the same store.
-   * \note Experimental primitive: used by cross thread-reduction.
-   */
-  PrimExpr store_predicate;
-  /*! \brief The relations between IterVars */
-  Array<IterVarRelation> relations;
-  /*! \brief additional attributes about iter var. */
-  Map<IterVar, IterVarAttr> iter_var_attrs;
-  /*! \brief The attachment type of the schedule */
-  AttachType attach_type{kGroupRoot};
-  /*! \brief The attach point of this schedule. */
-  IterVar attach_ivar;
-  /*! \brief The stage this node attaches to */
-  Stage attach_stage;
-  /*! \brief The thread storage scope level of the stage */
-  std::string scope;
-  /*! \brief Whether this is an output stage */
-  bool is_output{false};
-  /*! \brief Whether this is an OpenGL stage */
-  bool is_opengl{false};
-  /*! \brief Whether apply double buffer optimization to this stage */
-  bool double_buffer{false};
-  /*!
-   * \brief The parent group of the current stage.
-   *  The stage cannot be assigned to stages outside the group.
-   */
-  Stage group;
-  /*! \brief Number of direct child stages, only used for group stage.*/
-  int num_child_stages{0};
-
-  void VisitAttrs(AttrVisitor* v) {
-    v->Visit("op", &op);
-    v->Visit("origin_op", &origin_op);
-    v->Visit("all_iter_vars", &all_iter_vars);
-    v->Visit("leaf_iter_vars", &leaf_iter_vars);
-    v->Visit("env_threads", &env_threads);
-    v->Visit("relations", &relations);
-    v->Visit("iter_var_attrs", &iter_var_attrs);
-    v->Visit("attach_type", &attach_type);
-    v->Visit("attach_ivar", &attach_ivar);
-    v->Visit("attach_stage", &attach_stage);
-    v->Visit("scope", &scope);
-    v->Visit("is_output", &is_output);
-    v->Visit("is_opengl", &is_opengl);
-    v->Visit("double_buffer", &double_buffer);
-    v->Visit("group", &group);
-    v->Visit("num_child_stages", &num_child_stages);
-  }
-
-  static constexpr const char* _type_key = "Stage";
-  TVM_DECLARE_FINAL_OBJECT_INFO(StageNode, Object);
-};
-
-/*! \brief node container for schedule */
-class ScheduleNode : public Object {
- public:
-  /*! \brief The output operations in the original data flow graph */
-  Array<Operation> outputs;
-  /*!
-   * \brief list of all stages for ops.
-   * The stages are sorted in dependency order.
-   */
-  Array<Stage> stages;
-  /*!
-   * \brief List of all stage groups.
-   */
-  Array<Stage> groups;
-  /*! \brief map of original operation to the stages */
-  Map<Operation, Stage> stage_map;
-  /*!
-   * \brief Internal stage map to map internal ops to stages.
-   *  This is created on demand and can be invalidated.
-   */
-  std::unordered_map<const Object*, Stage> op2stage_cache_;
-
-  void VisitAttrs(AttrVisitor* v) {
-    v->Visit("outputs", &outputs);
-    v->Visit("stages", &stages);
-    v->Visit("groups", &groups);
-    v->Visit("stage_map", &stage_map);
-  }
-
-  /*! \brief Initialize temp cache. */
-  void InitCache();
-  /*! \brief Invalidate temp cache. */
-  void InvalidateCache();
-
-  /*!
-   * \brief Check if the schedule contains an Operation.
-   * \param op The candidate Operation.
-   * \return true if the schedule has the Operation. Otherwise, false.
-   */
-  TVM_DLL bool Contain(const Operation& op) const;
-
-  /*!
-   * \brief Check if the schedule contains a Tensor.
-   * \param tensor The candidate tensor.
-   * \return true if the schedule has the tensor. Otherwise, false.
-   */
-  TVM_DLL bool Contain(const Tensor& tensor) const {
-    return Contain(tensor->op);
-  }
-
-  /*!
-   * \brief Create a schedule for an array of ops (and their dependencies).
-   * \param ops The ops to be scheduled.
-   * \return sch The created Schedule.
-   */
-  TVM_DLL static Schedule make(Array<Operation> ops);
-
-  static constexpr const char* _type_key = "Schedule";
-  TVM_DECLARE_FINAL_OBJECT_INFO(ScheduleNode, Object);
-};
-
-/*!
- * \brief Create a schedule for an array of ops (and their dependencies).
- * \param ops The ops to be scheduled.
- * \return sch The created Schedule.
- */
-inline Schedule create_schedule(Array<Operation> ops) {
-  return ScheduleNode::make(ops);
-}
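
For reference, a minimal usage sketch of create_schedule and normalize (the identifiers `n`, `A`, and `B` are illustrative; placeholder and compute come from the pre-refactor tvm/operation.h):

    Var n("n");
    Tensor A = placeholder({n}, DataType::Float(32), "A");
    Tensor B = compute({n}, [&](Var i) { return A(i) + 1.0f; }, "B");
    Schedule s = create_schedule({B->op});  // one Stage per reachable op, in dependency order
    s = s.normalize();                      // inserts RebaseNodes so leaf itervars start at 0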
-
-/*! \brief node container for IterVar attr */
-class IterVarAttrNode : public Object {
- public:
-  /*! \brief The iteration type. */
-  IterVarType iter_type{kDataPar};
-  /*! \brief The thread this IterVar binds to; can be null */
-  IterVar bind_thread;
-  /*! \brief List of tensors to be prefetched in this loop */
-  Array<Tensor> prefetch_data;
-  /*! \brief The offset used in each prefetch */
-  Array<PrimExpr> prefetch_offset;
-  /*!
-   * \brief Tensor intrinsic used in tensorization,
-   *   when the axis is marked as Tensorized
-   */
-  TensorIntrin tensor_intrin;
-  /*! \brief Alignment factor of buffer dimension */
-  int dim_align_factor{0};
-  /*! \brief Alignment offset of buffer dimension */
-  int dim_align_offset{0};
-  /*!
-   * \brief Additional pragma keys, array of StringImm
-   */
-  Array<PrimExpr> pragma_keys;
-  /*!
-   * \brief Additional values of pragma, if any
-   */
-  Array<PrimExpr> pragma_values;
-
-  void VisitAttrs(AttrVisitor* v) {
-    v->Visit("iter_type", &iter_type);
-    v->Visit("bind_thread", &bind_thread);
-    v->Visit("prefetch_data", &prefetch_data);
-    v->Visit("prefetch_offset", &prefetch_offset);
-    v->Visit("tensor_intrin", &tensor_intrin);
-    v->Visit("dim_align_factor", &dim_align_factor);
-    v->Visit("dim_align_offset", &dim_align_offset);
-    v->Visit("pragma_keys", &pragma_keys);
-    v->Visit("pragma_values", &pragma_values);
-  }
-
-  static constexpr const char* _type_key = "IterVarAttr";
-  TVM_DECLARE_FINAL_OBJECT_INFO(IterVarAttrNode, Object);
-};
-
-/*! \brief base node of iter var relations */
-class IterVarRelationNode : public Object {
- public:
-  static constexpr const char* _type_key = "IterVarRelation";
-  TVM_DECLARE_BASE_OBJECT_INFO(IterVarRelationNode, Object);
-};
-
-/*!
- * \brief Split the parent domain into the product of
- *  the outer and inner domains.
- */
-class SplitNode : public IterVarRelationNode {
- public:
-  /*! \brief The parent domain */
-  IterVar parent;
-  /*! \brief The outer domain */
-  IterVar outer;
-  /*! \brief The inner domain */
-  IterVar inner;
-  /*! \brief The split factor */
-  PrimExpr factor;
-  /*! \brief Number of parts; only one of factor or nparts can be given */
-  PrimExpr nparts;
-
-  void VisitAttrs(AttrVisitor* v) {
-    v->Visit("parent", &parent);
-    v->Visit("outer", &outer);
-    v->Visit("inner", &inner);
-    v->Visit("factor", &factor);
-    v->Visit("nparts", &nparts);
-  }
-
-  static IterVarRelation make(IterVar parent,
-                              IterVar outer,
-                              IterVar inner,
-                              PrimExpr factor,
-                              PrimExpr nparts);
-
-  static constexpr const char* _type_key = "Split";
-  TVM_DECLARE_FINAL_OBJECT_INFO(SplitNode, IterVarRelationNode);
-};
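
Concretely, for a parent domain of extent 10 split with factor 3 (values illustrative), the relation encodes:

    outer extent = ceil(10 / 3) = 4,  inner extent = 3
    parent = outer * 3 + inner,  guarded by (outer * 3 + inner < 10)

The guard is only needed when the factor does not evenly divide the parent extent; with nparts, the outer extent is instead fixed to nparts and the inner extent becomes ceil(extent / nparts).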
-
-/*!
- * \brief Fuse two domains into one domain.
- */
-class FuseNode : public IterVarRelationNode {
- public:
-  /*! \brief The outer domain */
-  IterVar outer;
-  /*! \brief The inner domain */
-  IterVar inner;
-  /*! \brief The target domain */
-  IterVar fused;
-
-  void VisitAttrs(AttrVisitor* v) {
-    v->Visit("outer", &outer);
-    v->Visit("inner", &inner);
-    v->Visit("fused", &fused);
-  }
-
-  static IterVarRelation make(
-      IterVar outer, IterVar inner, IterVar fused);
-
-  static constexpr const char* _type_key = "Fuse";
-  TVM_DECLARE_FINAL_OBJECT_INFO(FuseNode, IterVarRelationNode);
-};
-
-/*!
- * \brief Rebase the iteration to make the min become 0.
- *  This is useful for normalizing the Schedule
- *  so that every leaf variable's min is 0.
- */
-class RebaseNode : public IterVarRelationNode {
- public:
-  /*! \brief The parent domain */
-  IterVar parent;
-  /*! \brief The rebased domain */
-  IterVar rebased;
-
-  void VisitAttrs(AttrVisitor* v) {
-    v->Visit("parent", &parent);
-    v->Visit("rebased", &rebased);
-  }
-
-  static IterVarRelation make(IterVar parent, IterVar rebased);
-
-  static constexpr const char* _type_key = "Rebase";
-  TVM_DECLARE_FINAL_OBJECT_INFO(RebaseNode, IterVarRelationNode);
-};
-
-
-/*!
- * \brief Singleton iterator [0, 1)
- */
-class SingletonNode : public IterVarRelationNode {
- public:
-  /*! \brief The singleton iterator */
-  IterVar iter;
-
-  void VisitAttrs(AttrVisitor* v) {
-    v->Visit("iter", &iter);
-  }
-
-  static IterVarRelation make(IterVar iter);
-
-  static constexpr const char* _type_key = "Singleton";
-  TVM_DECLARE_FINAL_OBJECT_INFO(SingletonNode, IterVarRelationNode);
-};
-
-
-// implementations
-inline const StageNode* Stage::operator->() const {
-  return static_cast<const StageNode*>(get());
-}
-inline StageNode* Stage::operator->() {
-  return static_cast<StageNode*>(get_mutable());
-}
-
-inline const ScheduleNode* Schedule::operator->() const {
-  return static_cast<const ScheduleNode*>(get());
-}
-inline ScheduleNode* Schedule::operator->() {
-  return static_cast<ScheduleNode*>(get_mutable());
-}
-
-inline const IterVarRelationNode* IterVarRelation::operator->() const {
-  return static_cast<const IterVarRelationNode*>(get());
-}
-
-inline const IterVarAttrNode* IterVarAttr::operator->() const {
-  return static_cast<const IterVarAttrNode*>(get());
-}
-}  // namespace tvm
-#endif  // TVM_SCHEDULE_H_
diff --git a/include/tvm/schedule_pass.h b/include/tvm/schedule_pass.h
deleted file mode 100644 (file)
index af2459b..0000000
+++ /dev/null
@@ -1,76 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file tvm/schedule_pass.h
- * \brief Collection of Schedule pass functions.
- *
- *  These passes work on the schedule hyper-graph
- *  and infer information such as bounds, check conditions,
- *  and read/write dependencies between the IterVars.
- */
-#ifndef TVM_SCHEDULE_PASS_H_
-#define TVM_SCHEDULE_PASS_H_
-
-#include "schedule.h"
-
-namespace tvm {
-namespace schedule {
-
-/*!
- * \brief Infer the bounds of all iteration variables related to the schedule.
- *
- * \param sch The root schedule to infer all the bounds.
- * \return the resulting bounds of the iteration variables
- */
-Map<IterVar, Range> InferBound(const Schedule& sch);
-
-/*!
- * \brief Schedule s's dependent operations and produce the resulting statement.
- *
- * \param s The schedule to be realized
- * \param dom_map The domain of each iter var.
- * \param debug_keep_trivial_loop Whether to keep trivial loops with extent of 1 during lowering.
- *                                This is a debug feature for dataflow/axis analysis.
- *                                Note: If this is true, the lowered IR may be incorrect,
- *                                because the init part of reductions will also be deleted.
- * \return the result Stmt
- */
-Stmt ScheduleOps(Schedule s, Map<IterVar, Range> dom_map, bool debug_keep_trivial_loop);
-
-/*!
- * \brief To automatically inline the element-wise operations.
- *
- * \param sch The schedule to be inlined.
- */
-void AutoInlineElemWise(Schedule sch);
-
-/*!
- * \brief To automatically inline operations with injective writes
- *   (i.e. writes without reduction or sequential loops). Note
- *   that in this case, guarantees about contiguity, transpose, stride,
- *   alignment and memory footprint in general do not hold.
- *
- * \param sch The schedule to be inlined.
- */
-TVM_DLL void AutoInlineInjective(Schedule sch);
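
Together these functions form the classic lowering sequence; a sketch from within this namespace (the schedule `s` over an output tensor `C` is illustrative):

    Schedule s = create_schedule({C->op}).normalize();
    Map<IterVar, Range> bounds = InferBound(s);  // a Range for every IterVar
    Stmt stmt = ScheduleOps(s, bounds, /*debug_keep_trivial_loop=*/false);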
-
-}  // namespace schedule
-}  // namespace tvm
-#endif  // TVM_SCHEDULE_PASS_H_
diff --git a/include/tvm/tensor.h b/include/tvm/tensor.h
deleted file mode 100644 (file)
index 91c0c96..0000000
+++ /dev/null
@@ -1,268 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file tvm/tensor.h
- * \brief Dataflow tensor object
- */
-#ifndef TVM_TENSOR_H_
-#define TVM_TENSOR_H_
-
-#include <tvm/node/container.h>
-#include <tvm/arith/bound.h>
-
-#include <string>
-#include <vector>
-#include <utility>
-#include <type_traits>
-
-#include "expr.h"
-#include "expr_operator.h"
-
-namespace tvm {
-
-// Internal node container of Tensor
-class TensorNode;
-// internal node container for Operation
-class OperationNode;
-
-/*!
- * \brief Tensor structure representing a possible input,
- *  or intermediate computation result.
- */
-class Tensor : public ObjectRef {
- public:
-  /*! \brief default constructor, used internally */
-  Tensor() {}
-  explicit Tensor(ObjectPtr<Object> n) : ObjectRef(n) {}
-  /*!
-   * \brief access the internal node container
-   * \return the pointer to the internal node container
-   */
-  inline const TensorNode* operator->() const;
-  /*!
-   * \brief check if two tensors equals each other.
-   * \param other tensor to be checked.
-   * \return whether the two tensors equals each other.
-   */
-  inline bool operator==(const Tensor& other) const;
-  /*!
-   * \brief check if two tensors are different.
-   * \param other tensor to be checked.
-   * \return whether the two tensors are different.
-   */
-  inline bool operator!=(const Tensor& other) const;
-  /*! \return The dimension of the tensor */
-  inline size_t ndim() const;
-  /*!
-   * \brief Take elements from the tensor
-   * \param args The indices
-   * \return the result expression representing tensor read.
-   */
-  template<typename... Args>
-  inline PrimExpr operator()(Args&& ...args) const {
-    Array<PrimExpr> indices{std::forward<Args>(args)...};
-    return operator()(indices);
-  }
-  /*!
-   * \brief Take elements from the tensor
-   * \param indices the indices.
-   * \return the result expression representing tensor read.
-   */
-  TVM_DLL PrimExpr operator()(Array<PrimExpr> indices) const;
-  /*!
-   * \brief Take elements from the tensor
-   * \param indices the indices.
-   * \return the result expression representing tensor read.
-   */
-  TVM_DLL PrimExpr operator()(Array<Var> indices) const;
-  /*!
-   * \brief data structure to represent a slice that fixes the first k coordinates.
-   *  This is used to enable the syntax sugar of Tensor[x][y][z] to get the element.
-   */
-  class Slice {
-   public:
-    // construct via tensor and indices
-    Slice(const Tensor& tensor, std::vector<PrimExpr> indices)
-        : tensor_(tensor), indices_(indices) {}
-    /*!
-     * \brief get i-th slice from the current slice.
-     * \param i the index of the coordinate
-     * \return the subsequent slice.
-     */
-    inline Slice operator[](PrimExpr i) {
-      std::vector<PrimExpr> other = indices_;
-      other.emplace_back(i);
-      return Slice(tensor_, other);
-    }
-    /*!
-     * \brief Convert slice to expression.
-     *  This is only valid when all the coordinates are fully specified.
-     * \return the corresponding expression of this slice.
-     */
-    inline operator PrimExpr() const {
-      return tensor_(indices_);
-    }
-
-   private:
-    const Tensor& tensor_;
-    std::vector<PrimExpr> indices_;
-  };
-  /*!
-   * \brief get i-th slice from the current Tensor.
-   * \param i the index of the coordinate
-   * \return the subsequent slice.
-   */
-  inline Slice operator[](PrimExpr i) const {
-    return Slice(*this, {i});
-  }
-  /*! \brief specify container node */
-  using ContainerType = TensorNode;
-};
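
Once all coordinates are supplied, the operator() and operator[] forms are interchangeable; a small sketch (`A` a 2-D tensor, `i` and `j` Vars, all illustrative):

    PrimExpr x = A(i, j);      // direct read via operator()
    PrimExpr y = A[i][j];      // builds a Slice chain, converted to A(i, j) once fully indexed
    PrimExpr z = A[i][j] + 1;  // mixed forms work via the slice operator overloads below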
-
-/*! \brief Operation that produces tensors */
-class Operation : public ir::FunctionRef {
- public:
-  /*! \brief default constructor  */
-  Operation() {}
-  explicit Operation(ObjectPtr<Object> n) : FunctionRef(n) {}
-  /*!
-   * \brief access the internal node container
-   * \return the pointer to the internal node container
-   */
-  inline const OperationNode* operator->() const;
-  /*!
-   * \brief get the i-th output of the operation.
-   * \param i the output index.
-   * \return The i-th output.
-   */
-  TVM_DLL Tensor output(size_t i) const;
-  /*! \brief specify container node */
-  using ContainerType = OperationNode;
-};
-
-/*! \brief Node to represent a tensor */
-class TensorNode : public Object {
- public:
-  /*! \brief The shape of the tensor */
-  Array<PrimExpr> shape;
-  /*! \brief data type in the content of the tensor */
-  DataType dtype;
-  /*! \brief the source operation, can be None */
-  Operation op;
-  /*! \brief the output index from source operation */
-  int value_index{0};
-  /*! \brief constructor */
-  TensorNode() {}
-
-  void VisitAttrs(AttrVisitor* v) {
-    v->Visit("shape", &shape);
-    v->Visit("dtype", &dtype);
-    v->Visit("op", &op);
-    v->Visit("value_index", &value_index);
-  }
-  TVM_DLL static Tensor make(Array<PrimExpr> shape,
-                             DataType dtype,
-                             Operation op,
-                             int value_index);
-
-  static constexpr const char* _type_key = "Tensor";
-  TVM_DECLARE_FINAL_OBJECT_INFO(TensorNode, Object);
-};
-
-
-// Implementations of inline functions
-inline const TensorNode* Tensor::operator->() const {
-  return static_cast<const TensorNode*>(get());
-}
-
-inline size_t Tensor::ndim() const {
-  return (*this)->shape.size();
-}
-
-inline bool Tensor::operator==(const Tensor& other) const {
-  if (get() == other.get()) return true;
-  if (get() == nullptr || other.get() == nullptr) return false;
-  if ((*this)->op.defined() || other->op.defined()) {
-    return (*this)->op == other->op &&
-        (*this)->value_index == other->value_index;
-  } else {
-    return false;
-  }
-}
-
-inline bool Tensor::operator!=(const Tensor& other) const {
-  return !(*this == other);
-}
-
-// macros to turn every operation on a slice into an expression
-#define DEFINE_OVERLOAD_SLICE_UNARY_OP(Op)                              \
-  inline PrimExpr operator Op (const Tensor::Slice& a) {           \
-    return Op a.operator PrimExpr() ;                              \
-  }                                                                     \
-
-#define DEFINE_OVERLOAD_SLICE_BINARY_OP(Op)                             \
-  template<typename T>                                                  \
-  inline PrimExpr operator Op (const Tensor::Slice& a, const T& b) { \
-    return a.operator PrimExpr() Op b;                             \
-  }                                                                     \
-  template<typename T>                                                  \
-  inline PrimExpr operator Op (const T& a, const Tensor::Slice& b) {  \
-    return a Op b.operator PrimExpr();                                \
-  }                                                                        \
-  inline PrimExpr operator Op (const Tensor::Slice& a, const Tensor::Slice& b) { \
-    return a.operator PrimExpr() Op b.operator PrimExpr();                  \
-  }
-
-DEFINE_OVERLOAD_SLICE_UNARY_OP(!);
-DEFINE_OVERLOAD_SLICE_UNARY_OP(-);
-DEFINE_OVERLOAD_SLICE_BINARY_OP(+);
-DEFINE_OVERLOAD_SLICE_BINARY_OP(-);
-DEFINE_OVERLOAD_SLICE_BINARY_OP(*);
-DEFINE_OVERLOAD_SLICE_BINARY_OP(==);
-DEFINE_OVERLOAD_SLICE_BINARY_OP(<=);
-DEFINE_OVERLOAD_SLICE_BINARY_OP(>=);
-DEFINE_OVERLOAD_SLICE_BINARY_OP(!=);
-DEFINE_OVERLOAD_SLICE_BINARY_OP(&&);
-DEFINE_OVERLOAD_SLICE_BINARY_OP(||);
-DEFINE_OVERLOAD_SLICE_BINARY_OP(>>);
-DEFINE_OVERLOAD_SLICE_BINARY_OP(<<);
-DEFINE_OVERLOAD_SLICE_BINARY_OP(>);  // NOLINT(*)
-DEFINE_OVERLOAD_SLICE_BINARY_OP(<);  // NOLINT(*)
-
-}  // namespace tvm
-
-namespace std {
-template <>
-struct hash<::tvm::Operation> : public ::tvm::ObjectHash {
-};
-
-template <>
-struct hash<::tvm::Tensor> {
-  std::size_t operator()(const ::tvm::Tensor& k) const {
-    ::tvm::ObjectHash hasher;
-    if (k.defined() && k->op.defined()) {
-      return hasher(k->op);
-    } else {
-      return hasher(k);
-    }
-  }
-};
-}  // namespace std
-#endif  // TVM_TENSOR_H_
diff --git a/include/tvm/tensor_intrin.h b/include/tvm/tensor_intrin.h
deleted file mode 100644 (file)
index 879e206..0000000
+++ /dev/null
@@ -1,177 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file tvm/tensor_intrin.h
- * \brief Tensor intrinsic operations.
- */
-#ifndef TVM_TENSOR_INTRIN_H_
-#define TVM_TENSOR_INTRIN_H_
-
-#include <string>
-#include "tensor.h"
-#include "buffer.h"
-
-namespace tvm {
-
-// Internal node container of tensor intrinsics.
-class TensorIntrinNode;
-
-/*! \brief Tensor intrinsic node. */
-class TensorIntrin : public ObjectRef {
- public:
-  TensorIntrin() {}
-  explicit TensorIntrin(ObjectPtr<Object> n) : ObjectRef(n) {}
-  /*!
-   * \brief access the internal node container
-   * \return the pointer to the internal node container
-   */
-  inline const TensorIntrinNode* operator->() const;
-
-  /*! \brief specify container node */
-  using ContainerType = TensorIntrinNode;
-};
-
-/*! \brief Node to represent a Tensor intrinsic operator */
-class TensorIntrinNode : public Object {
- public:
-  /*! \brief The name of the intrinsic */
-  std::string name;
-  /*! \brief The operation this intrinsics is carrying out */
-  Operation op;
-  /*! \brief List of inputs of the operator; placeholders in post-DFS order */
-  Array<Tensor> inputs;
-  /*!
-   * \brief Symbolic buffers of each output/input tensor
-   *  buffers[0:len(inputs)] are buffers of the inputs.
-   *  buffers[len(inputs):] are buffers of each output.
-   *
-   * \note When a field in a Buffer is a Var, it means we can be flexible
-   *  with respect to that field, and the Var can occur in the body.
-   *  When it is a constant, it means we can only take data of that shape.
-   */
-  Array<Buffer> buffers;
-  /*! \brief List of scalar variables, used in body. These placeholders
-   *  will be bound to expressions passed in when the TensorIntrin is called
-   * from a TensorComputeOp.
-   */
-  Array<Var> scalar_params;
-  /*! \brief The normal statement to execute the intrinsic */
-  Stmt body;
-  /*!
-   * \brief Special statement for reduction ops; can be None.
-   *  Resets the value of the output buffer to the identity value.
-   */
-  Stmt reduce_init;
-  /*!
-   * \brief Special statement for reduction ops; can be None.
-   *  Reduce: do a reduction of the current output buffer with the result.
-   */
-  Stmt reduce_update;
-  /*! \brief constructor */
-  TensorIntrinNode() {}
-
-  void VisitAttrs(AttrVisitor* v) {
-    v->Visit("name", &name);
-    v->Visit("op", &op);
-    v->Visit("inputs", &inputs);
-    v->Visit("buffers", &buffers);
-    v->Visit("scalar_params", &scalar_params);
-    v->Visit("body", &body);
-    v->Visit("reduce_init", &reduce_init);
-    v->Visit("reduce_update", &reduce_update);
-  }
-
-  TVM_DLL static TensorIntrin make(std::string name,
-                                   Operation op,
-                                   Array<Tensor> inputs,
-                                   Array<Buffer> buffers,
-                                   Array<Var> scalar_params,
-                                   Stmt body,
-                                   Stmt reduce_init,
-                                   Stmt reduce_update);
-
-  static constexpr const char* _type_key = "TensorIntrin";
-  TVM_DECLARE_FINAL_OBJECT_INFO(TensorIntrinNode, Object);
-};
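
A hedged sketch of calling this make function (the operands `op`, `inputs`, `buffers`, and `body` are assumed to be constructed elsewhere; a real intrinsic must supply a matching operation, tensors, and buffer layouts):

    TensorIntrin gemv = TensorIntrinNode::make(
        "gemv", op, inputs, buffers, /*scalar_params=*/{},
        body, /*reduce_init=*/Stmt(), /*reduce_update=*/Stmt());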
-
-inline const TensorIntrinNode* TensorIntrin::operator->() const {
-  return static_cast<const TensorIntrinNode*>(get());
-}
-
-// Internal node container of tensor intrinsic calling.
-class TensorIntrinCallNode;
-
-/*! \brief Tensor intrinsic calling node. */
-class TensorIntrinCall : public ObjectRef {
- public:
-  TensorIntrinCall() {}
-  explicit TensorIntrinCall(ObjectPtr<Object> n) : ObjectRef(n) {}
-  /*!
-   * \brief access the internal node container
-   * \return the pointer to the internal node container
-   */
-  inline const TensorIntrinCallNode* operator->() const;
-
-  /*! \brief specify container node */
-  using ContainerType = TensorIntrinCallNode;
-};
-
-class TensorIntrinCallNode : public Object {
- public:
-  /*! \brief the tensor intrinsic */
-  TensorIntrin intrin;
-  /*! \brief input tensors of the intrinsic */
-  Array<Tensor> tensors;
-  /*! \brief regions of input tensors */
-  Array<Region> regions;
-
-
-  /*!
-   * \brief IterVar on each reduction axis, if the
-   * intrinsic uses the reduce axis
-   */
-  Array<IterVar> reduce_axis;
-
-  /*! \brief scalar expression inputs */
-  Array<PrimExpr> scalar_inputs;
-
-  void VisitAttrs(AttrVisitor* v) {
-    v->Visit("intrin", &intrin);
-    v->Visit("tensors", &tensors);
-    v->Visit("regions", &regions);
-    v->Visit("reduce_axis", &reduce_axis);
-    v->Visit("scalar_inputs", &scalar_inputs);
-  }
-  static TensorIntrinCall make(TensorIntrin intrin,
-                               Array<Tensor> tensors,
-                               Array<Region> regions,
-                               Array<IterVar> reduce_axis,
-                               Array<PrimExpr> scalar_inputs);
-
-  static constexpr const char* _type_key = "TensorIntrinCall";
-  TVM_DECLARE_FINAL_OBJECT_INFO(TensorIntrinCallNode, Object);
-};
-
-inline const TensorIntrinCallNode* TensorIntrinCall::operator->() const {
-  return static_cast<const TensorIntrinCallNode*>(get());
-}
-
-}  // namespace tvm
-#endif  // TVM_TENSOR_INTRIN_H_
diff --git a/include/tvm/top/operation.h b/include/tvm/top/operation.h
new file mode 100644 (file)
index 0000000..2cee218
--- /dev/null
@@ -0,0 +1,663 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/top/operation.h
+ * \brief An Operation node can generate one or multiple Tensors.
+ */
+#ifndef TVM_TOP_OPERATION_H_
+#define TVM_TOP_OPERATION_H_
+
+#include <tvm/arith/analyzer.h>
+#include <tvm/top/tensor.h>
+#include <tvm/top/schedule.h>
+
+#include <tvm/expr.h>
+#include <tvm/expr_operator.h>
+#include <tvm/buffer.h>
+
+#include <string>
+#include <vector>
+#include <unordered_map>
+
+
+
+namespace tvm {
+namespace top {
+
+using arith::IntSet;
+
+/*!
+ * \brief Temporary data structure to store the union
+ *  of the bounds of each axis of a Tensor.
+ */
+struct TensorDom {
+  // constructor
+  explicit TensorDom(int ndim)
+      : data(ndim) {}
+  /*! \brief The domain data */
+  std::vector<std::vector<IntSet> > data;
+};
+
+/*!
+ * \brief Base class of all operation nodes
+ */
+class OperationNode : public ir::FunctionBaseNode {
+ public:
+  /*! \brief optional name of the operation */
+  std::string name;
+  /*! \brief optional tag of the operation */
+  std::string tag;
+  /*! \brief additional attributes of the operation*/
+  Map<std::string, ObjectRef> attrs;
+  /*! \return name of the operation */
+  const std::string& func_name() const final {
+    return name;
+  }
+  /*!
+   * \return The list of iteration variables at the root.
+   * \note root_iter_vars decides the shape of the outputs.
+   */
+  virtual Array<IterVar> root_iter_vars() const = 0;
+  /*!
+   * \brief Get the data type of the i-th output tensor.
+   * \param i The output index.
+   * \return type of i-th output.
+   */
+  virtual DataType output_dtype(size_t i) const = 0;
+  /*!
+   * \brief Get shape of i-th output tensor.
+   * \param i The output index.
+   * \return shape of i-th output.
+   */
+  virtual Array<PrimExpr> output_shape(size_t i) const = 0;
+  /*!
+   * \brief List all the input Tensors.
+   * \return List of input tensors.
+   */
+  virtual Array<Tensor> InputTensors() const = 0;
+  /*!
+   * \brief Replace the input of the operation by pattern specified by rmap.
+   *
+   * \param self The reference to self.
+   * \param rmap The replacement map.
+   * \return self if nothing is replaced, otherwise return replaced op.
+   */
+  virtual Operation ReplaceInputs(
+      const Operation& self,
+      const std::unordered_map<Tensor, Tensor>& rmap) const = 0;
+  /*!
+   * \brief Propagate the bounds to inputs
+   * \param self The reference to self.
+   * \param analyzer The analyzer to be used in the function.
+   * \param dom_map the domain map of variables (corresponding to root_iter_vars)
+   * \param out_dom_map The output domain.
+   *  The function is only asked to fill the bounds for Tensors that
+   *  are already in the out_dom_map.
+   */
+  virtual void PropBoundToInputs(
+      const Operation& self,
+      arith::Analyzer* analyzer,
+      const std::unordered_map<const VarNode*, IntSet>& dom_map,
+      std::unordered_map<Tensor, TensorDom>* out_dom_map) const = 0;
+  /*!
+   * \brief Gather the bound from the output tensor.
+   *  Set the range of each root_iter_var of the op in out_dom_map.
+   *
+   * \param self The reference to self.
+   * \param tensor_dom Domain map of Tensor->access set of each dimension.
+   * \param out_dom_map The output domain map of each IterVar to be set.
+   */
+  virtual void GatherBound(
+      const Operation& self,
+      const std::unordered_map<Tensor, TensorDom>& tensor_dom,
+      std::unordered_map<IterVar, Range>* out_dom_map) const = 0;
+  /*!
+   * \brief Build the Realize statement that realizes
+   *   the op's output tensors.
+   * \param stage the op's stage.
+   * \param realize_map The realization domain map of the operators.
+   * \param body The body that is going to get wrapped.
+   * \return A realization statement that wraps body.
+   */
+  virtual Stmt BuildRealize(
+      const Stage& stage,
+      const std::unordered_map<IterVar, Range>& realize_map,
+      const Stmt& body) const = 0;
+  /*!
+   * \brief Build the statement that provides the output tensors.
+   * \param stage The schedule stage of the op.
+   * \param dom_map The domain map of all iteration domains.
+   * \param debug_keep_trivial_loop Whether to keep trivial loops with extent of 1
+   * \return A statement that adds production and wraps the consumer.
+   */
+  virtual Stmt BuildProvide(
+      const Stage& stage,
+      const std::unordered_map<IterVar, Range>& dom_map,
+      bool debug_keep_trivial_loop) const = 0;
+
+  static constexpr const char* _type_key = "Operation";
+
+  TVM_DECLARE_BASE_OBJECT_INFO(OperationNode, Object);
+};
+
+/*!
+ * \brief A placeholder op represents an input placeholder.
+ */
+class PlaceholderOpNode : public OperationNode {
+ public:
+  /*! \brief The shape of the input */
+  Array<PrimExpr> shape;
+  /*! \brief The data type of the input. */
+  DataType dtype;
+  // override behavior.
+  int num_outputs() const final;
+  Array<IterVar> root_iter_vars() const final;
+  DataType output_dtype(size_t i) const final;
+  Array<PrimExpr> output_shape(size_t i) const final;
+  Array<Tensor> InputTensors() const final;
+  Operation ReplaceInputs(
+      const Operation& self,
+      const std::unordered_map<Tensor, Tensor>& rmap) const final;
+  void PropBoundToInputs(
+      const Operation& self,
+      arith::Analyzer* analyzer,
+      const std::unordered_map<const VarNode*, IntSet>& dom_map,
+      std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
+  void GatherBound(
+      const Operation& self,
+      const std::unordered_map<Tensor, TensorDom>& tensor_dom,
+      std::unordered_map<IterVar, Range>* out_dom_map) const final;
+  Stmt BuildRealize(
+      const Stage& stage,
+      const std::unordered_map<IterVar, Range>& realize_map,
+      const Stmt& body) const final;
+  Stmt BuildProvide(
+      const Stage& stage,
+      const std::unordered_map<IterVar, Range>& dom_map,
+      bool debug_keep_trivial_loop) const final;
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("name", &name);
+    v->Visit("tag", &tag);
+    v->Visit("attrs", &attrs);
+    v->Visit("shape", &shape);
+    v->Visit("dtype", &dtype);
+  }
+  static Operation make(std::string name,
+                        Array<PrimExpr> shape,
+                        DataType dtype);
+
+  static constexpr const char* _type_key = "PlaceholderOp";
+  TVM_DECLARE_FINAL_OBJECT_INFO(PlaceholderOpNode, OperationNode);
+};
+
+/*!
+ * \brief A Compute op that computes a tensor on a certain domain.
+ * This is the base class for ComputeOp (operating on a scalar at a time) and
+ * TensorComputeOp (operating on a TensorSlice at a time)
+ */
+class TVM_DLL BaseComputeOpNode : public OperationNode {
+ public:
+  /*! \brief IterVar on each axis */
+  Array<IterVar> axis;
+  /*! \brief IterVar on each reduction axis, if the body is a Reduce */
+  Array<IterVar> reduce_axis;
+  // override functions
+  Array<IterVar> root_iter_vars() const final;
+  Array<PrimExpr> output_shape(size_t idx) const final;
+  void GatherBound(
+          const Operation& self,
+          const std::unordered_map<Tensor, TensorDom>& tensor_dom,
+          std::unordered_map<IterVar, Range>* out_dom_map) const final;
+  Stmt BuildRealize(
+          const Stage& stage,
+          const std::unordered_map<IterVar, Range>& realize_map,
+          const Stmt& body) const final;
+  virtual size_t num_schedulable_dims() const = 0;
+
+  static constexpr const char* _type_key = "BaseComputeOp";
+  TVM_DECLARE_BASE_OBJECT_INFO(BaseComputeOpNode, OperationNode);
+};
+
+
+/*!
+ * \brief A Compute op that computes a tensor on a certain domain.
+ */
+class TVM_DLL ComputeOpNode : public BaseComputeOpNode {
+ public:
+  /*! \brief the compute expression */
+  Array<PrimExpr> body;
+  /*! \brief constructor */
+  ComputeOpNode() {}
+  // override functions
+  int num_outputs() const final;
+  DataType output_dtype(size_t i) const final;
+  Array<Tensor> InputTensors() const final;
+  Operation ReplaceInputs(
+      const Operation& self,
+      const std::unordered_map<Tensor, Tensor>& rmap) const final;
+  void PropBoundToInputs(
+      const Operation& self,
+      arith::Analyzer* analyzer,
+      const std::unordered_map<const VarNode*, IntSet>& dom_map,
+      std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
+  Stmt BuildProvide(
+      const Stage& stage,
+      const std::unordered_map<IterVar, Range>& dom_map,
+      bool debug_keep_trivial_loop) const final;
+  size_t num_schedulable_dims() const final;
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("name", &name);
+    v->Visit("tag", &tag);
+    v->Visit("attrs", &attrs);
+    v->Visit("axis", &axis);
+    v->Visit("reduce_axis", &reduce_axis);
+    v->Visit("body", &body);
+  }
+  static Operation make(std::string name,
+                        std::string tag,
+                        Map<std::string, ObjectRef> attrs,
+                        Array<IterVar> axis,
+                        Array<PrimExpr> body);
+
+  static constexpr const char* _type_key = "ComputeOp";
+  TVM_DECLARE_FINAL_OBJECT_INFO(ComputeOpNode, BaseComputeOpNode);
+};
+
+/*!
+ * \brief A TensorCompute op that computes a tensor with a tensor intrinsic.
+ */
+class TensorComputeOpNode : public BaseComputeOpNode {
+ public:
+  /*! \brief number of axes that can be scheduled */
+  int schedulable_ndim;
+  /*! \brief TensorIntrin used to compute */
+  TensorIntrin intrin;
+  /*! \brief input tensors of intrin */
+  Array<Tensor> inputs;
+  /*! \brief region of input tensors */
+  Array<Region> input_regions;
+  /*! \brief scalar expression inputs */
+  Array<PrimExpr> scalar_inputs;
+  /*! \brief constructor */
+  TensorComputeOpNode() {}
+  // override functions
+  int num_outputs() const final;
+  DataType output_dtype(size_t i) const final;
+  Array<Tensor> InputTensors() const final;
+  Operation ReplaceInputs(
+      const Operation& self,
+      const std::unordered_map<Tensor, Tensor>& rmap) const final;
+  void PropBoundToInputs(
+      const Operation& self,
+      arith::Analyzer* analyzer,
+      const std::unordered_map<const VarNode*, IntSet>& dom_map,
+      std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
+  Stmt BuildProvide(
+      const Stage& stage,
+      const std::unordered_map<IterVar, Range>& dom_map,
+      bool debug_keep_trivial_loop) const final;
+  size_t num_schedulable_dims() const final;
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("name", &name);
+    v->Visit("tag", &tag);
+    v->Visit("axis", &axis);
+    v->Visit("reduce_axis", &reduce_axis);
+    v->Visit("schedulable_ndim", &schedulable_ndim);
+    v->Visit("intrin", &intrin);
+    v->Visit("inputs", &inputs);
+    v->Visit("input_regions", &input_regions);
+    v->Visit("scalar_inputs", &scalar_inputs);
+  }
+  static Operation make(std::string name,
+                        std::string tag,
+                        Array<IterVar> axis,
+                        Array<IterVar> reduce_axis,
+                        int schedulable_ndim,
+                        TensorIntrin intrin,
+                        Array<Tensor> tensors,
+                        Array<Region> regions,
+                        Array<PrimExpr> scalar_inputs);
+
+  static constexpr const char* _type_key = "TensorComputeOp";
+  TVM_DECLARE_FINAL_OBJECT_INFO(TensorComputeOpNode, BaseComputeOpNode);
+};
+
+/*!
+ * \brief Symbolic scan.
+ */
+class ScanOpNode : public OperationNode {
+ public:
+  /*! \brief IterVar to scan over */
+  IterVar scan_axis;
+  /*! \brief the initialization tensors */
+  Array<Tensor> init;
+  /*! \brief the update function represented by tensor */
+  Array<Tensor> update;
+  /*! \brief The placeholder to refer as states in update. */
+  Array<Tensor> state_placeholder;
+  /*!
+   * \brief the inputs to the scan. These are optionally provided,
+   *  but they can be helpful as hints to speed up getting the scan body.
+   */
+  Array<Tensor> inputs;
+  /*!
+   * \brief Spatial axis to indicate the spatial dimension of each output.
+   *  They correspond to the flattened spatial axes of the outputs.
+   *
+   *  [output[0].axis[1], output[0].axis[2]... output[k].axis[j]...]
+   *  These are auxiliary data structures for storing the result of bound inference.
+   *  They do not correspond to splittable iterations, thus the name comes
+   *  with an underscore.
+   */
+  Array<IterVar> spatial_axis_;
+  /*! \brief constructor */
+  ScanOpNode() {}
+  // override behavior.
+  int num_outputs() const final;
+  Array<IterVar> root_iter_vars() const final;
+  DataType output_dtype(size_t i) const final;
+  Array<PrimExpr> output_shape(size_t i) const final;
+  Array<Tensor> InputTensors() const final;
+  Operation ReplaceInputs(
+      const Operation& self,
+      const std::unordered_map<Tensor, Tensor>& rmap) const final;
+  void PropBoundToInputs(
+      const Operation& self,
+      arith::Analyzer* analyzer,
+      const std::unordered_map<const VarNode*, IntSet>& dom_map,
+      std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
+  void GatherBound(
+      const Operation& self,
+      const std::unordered_map<Tensor, TensorDom>& tensor_dom,
+      std::unordered_map<IterVar, Range>* out_dom_map) const final;
+  Stmt BuildRealize(
+      const Stage& stage,
+      const std::unordered_map<IterVar, Range>& realize_map,
+      const Stmt& body) const final;
+  Stmt BuildProvide(
+      const Stage& stage,
+      const std::unordered_map<IterVar, Range>& dom_map,
+      bool debug_keep_trivial_loop) const final;
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("name", &name);
+    v->Visit("tag", &tag);
+    v->Visit("attrs", &attrs);
+    v->Visit("scan_axis", &scan_axis);
+    v->Visit("init", &init);
+    v->Visit("update", &update);
+    v->Visit("state_placeholder", &state_placeholder);
+    v->Visit("inputs", &inputs);
+    v->Visit("spatial_axis_", &spatial_axis_);
+  }
+  static Operation make(std::string name,
+                        std::string tag,
+                        Map<std::string, ObjectRef> attrs,
+                        IterVar axis,
+                        Array<Tensor> init,
+                        Array<Tensor> update,
+                        Array<Tensor> state_placeholder,
+                        Array<Tensor> input);
+
+  static constexpr const char* _type_key = "ScanOp";
+  TVM_DECLARE_FINAL_OBJECT_INFO(ScanOpNode, OperationNode);
+};
+
+/*!
+ * \brief External computation that cannot be split.
+ */
+class ExternOpNode : public OperationNode {
+ public:
+  /*! \brief The input tensors */
+  Array<Tensor> inputs;
+  /*! \brief Symbolic placeholder representation of inputs */
+  Array<Buffer> input_placeholders;
+  /*! \brief Symbolic placeholder representation of outputs */
+  Array<Buffer> output_placeholders;
+  /*! \brief the statement that generates the computation. */
+  Stmt body;
+
+  /*! \brief constructor */
+  ExternOpNode() {}
+  // override functions
+  int num_outputs() const final;
+  Array<IterVar> root_iter_vars() const final;
+  DataType output_dtype(size_t i) const final;
+  Array<PrimExpr> output_shape(size_t i) const final;
+  Array<Tensor> InputTensors() const final;
+  Operation ReplaceInputs(
+      const Operation& self,
+      const std::unordered_map<Tensor, Tensor>& rmap) const final;
+  void PropBoundToInputs(
+      const Operation& self,
+      arith::Analyzer* analyzer,
+      const std::unordered_map<const VarNode*, IntSet>& dom_map,
+      std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
+  void GatherBound(
+      const Operation& self,
+      const std::unordered_map<Tensor, TensorDom>& tensor_dom,
+      std::unordered_map<IterVar, Range>* out_dom_map) const final;
+  Stmt BuildRealize(
+      const Stage& stage,
+      const std::unordered_map<IterVar, Range>& realize_map,
+      const Stmt& body) const final;
+  Stmt BuildProvide(
+      const Stage& stage,
+      const std::unordered_map<IterVar, Range>& dom_map,
+      bool debug_keep_trivial_loop) const final;
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("name", &name);
+    v->Visit("tag", &tag);
+    v->Visit("attrs", &attrs);
+    v->Visit("inputs", &inputs);
+    v->Visit("input_placeholders", &input_placeholders);
+    v->Visit("output_placeholders", &output_placeholders);
+    v->Visit("body", &body);
+  }
+  TVM_DLL static Operation make(std::string name,
+                               std::string tag,
+                               Map<std::string, ObjectRef> attrs,
+                               Array<Tensor> inputs,
+                               Array<Buffer> input_placeholders,
+                               Array<Buffer> output_placeholders,
+                               Stmt body);
+
+  static constexpr const char* _type_key = "ExternOp";
+  TVM_DECLARE_FINAL_OBJECT_INFO(ExternOpNode, OperationNode);
+};
+
+/*!
+ * \brief A computation operator generated by hybrid script.
+ */
+class HybridOpNode : public OperationNode {
+ public:
+  /*! \brief The input tensors */
+  Array<Tensor> inputs;
+  /*! \brief Symbolic placeholder representation of outputs */
+  Array<Tensor> outputs;
+  /*! \brief The axis of iterations */
+  Array<IterVar> axis;
+  /*! \brief the statement that generates the computation. This is
+   * slightly different from the body in ExternOpNode. All the output
+   * tensors keep their own names as specified by users in the script.
+   * However, during compilation, these tensors will be replaced by the
+   * actual output tensors. */
+  Stmt body;
+
+  /*! \brief constructor */
+  HybridOpNode() {}
+  // override functions
+  int num_outputs() const final;
+  Array<IterVar> root_iter_vars() const final;
+  DataType output_dtype(size_t i) const final;
+  Array<PrimExpr> output_shape(size_t i) const final;
+  Array<Tensor> InputTensors() const final;
+  Operation ReplaceInputs(
+      const Operation& self,
+      const std::unordered_map<Tensor, Tensor>& rmap) const final;
+  void PropBoundToInputs(
+      const Operation& self,
+      arith::Analyzer* analyzer,
+      const std::unordered_map<const VarNode*, IntSet>& dom_map,
+      std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
+  void GatherBound(
+      const Operation& self,
+      const std::unordered_map<Tensor, TensorDom>& tensor_dom,
+      std::unordered_map<IterVar, Range>* out_dom_map) const final;
+  Stmt BuildRealize(
+      const Stage& stage,
+      const std::unordered_map<IterVar, Range>& realize_map,
+      const Stmt& body) const final;
+  Stmt BuildProvide(
+      const Stage& stage,
+      const std::unordered_map<IterVar, Range>& dom_map,
+      bool debug_keep_trivial_loop) const final;
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("name", &name);
+    v->Visit("tag", &tag);
+    v->Visit("attrs", &attrs);
+    v->Visit("inputs", &inputs);
+    v->Visit("outputs", &outputs);
+    v->Visit("axis", &axis);
+    v->Visit("body", &body);
+  }
+  TVM_DLL static Operation make(std::string name,
+                                std::string tag,
+                                Map<std::string, ObjectRef> attrs,
+                                Array<Tensor> inputs,
+                                Array<Tensor> outputs,
+                                Stmt body);
+
+  static constexpr const char* _type_key = "HybridOp";
+  TVM_DECLARE_FINAL_OBJECT_INFO(HybridOpNode, OperationNode);
+};
+
+/*! \brief The compute function to specify the input source of a Tensor */
+using FCompute = std::function<PrimExpr (const Array<Var>& i)>;
+
+/*! \brief The compute function to specify the inputs source of Tensors */
+using FBatchCompute = std::function<Array<PrimExpr> (const Array<Var>& i)>;
+
+/*!
+ * \brief Create a placeholder tensor.
+ * \param shape The shape of the tensor.
+ * \param dtype the data type of the tensor.
+ * \param name The name of the Tensor.
+ */
+TVM_DLL Tensor placeholder(Array<PrimExpr> shape,
+                           DataType dtype = DataType::Float(32),
+                           std::string name = "placeholder");
+
+/*!
+ * \brief Construct a new tensor by computing over shape,
+ *  using the computation rule: result_tensor[axis] = fcompute(axis)
+ * \param shape Shape of the tensor.
+ * \param fcompute The compute function to create the tensor.
+ * \param name The optional name of the tensor.
+ * \param tag The optional tag of the tensor.
+ * \param attrs Optional additional attributes of the compute.
+ */
+TVM_DLL Tensor compute(Array<PrimExpr> shape,
+                       FCompute fcompute,
+                       std::string name = "tensor",
+                       std::string tag = "",
+                       Map<std::string, ObjectRef> attrs = {});
+
+/*!
+ * \brief Construct a new tensor by computing over shape,
+ *  using the computation rule: result_tensor[axis] = fcompute(axis)
+ * \param shape Shape of the tensor.
+ * \param fcompute The compute function to create the tensors.
+ * \param name The optional name of the tensor.
+ * \param tag The optional tag of the tensor.
+ * \param attrs Optional additional attributes of the compute.
+ */
+TVM_DLL Array<Tensor> compute(Array<PrimExpr> shape,
+                              FBatchCompute fcompute,
+                              std::string name = "tensor",
+                              std::string tag = "",
+                              Map<std::string, ObjectRef> attrs = {});
+
+/*!
+ * \brief Construct new tensors by scan.
+ *
+ * \param init The initialization tensors for the first K steps.
+ * \param update The update tensors indicating the updated result after each time step.
+ * \param state_placeholder The placeholder for the states.
+ * \param inputs The inputs to the scan body; this is optional,
+ *    but recommended to provide concrete information about the scan body.
+ * \param name The optional name of the tensor.
+ * \param tag The optional tag of the tensor.
+ * \param attrs Optional additional attributes of the compute.
+ */
+TVM_DLL Array<Tensor> scan(Array<Tensor> init,
+                           Array<Tensor> update,
+                           Array<Tensor> state_placeholder,
+                           Array<Tensor> inputs = Array<Tensor>(),
+                           std::string name = "scan",
+                           std::string tag = "",
+                           Map<std::string, ObjectRef> attrs = {});
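
A sketch of a cumulative sum along the first axis phrased as a scan (the identifiers `m`, `n`, and `X` are illustrative):

    Var m("m"), n("n");
    Tensor X = placeholder({m, n}, DataType::Float(32), "X");
    Tensor state = placeholder({m, n}, DataType::Float(32), "state");
    Tensor init = compute({1, n}, [&](Var t, Var j) { return X(0, j); }, "init");
    Tensor update = compute({m, n},
        [&](Var t, Var j) { return state(t - 1, j) + X(t, j); }, "update");
    Tensor Y = scan({init}, {update}, {state})[0];  // Y(t, j) == sum of X(0..t, j)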
+
+// same as compute, specialized for different fcompute function arities
+inline Tensor compute(Array<PrimExpr> shape,
+                      std::function<PrimExpr(Var)> f,
+                      std::string name = "tensor",
+                      std::string tag = "",
+                      Map<std::string, ObjectRef> attrs = {}) {
+  FCompute fc = [f] (const Array<Var>& i) { return f(i[0]); };
+  return compute(shape, fc, name, tag, attrs);
+}
+inline Tensor compute(Array<PrimExpr> shape,
+                      std::function<PrimExpr(Var, Var)> f,
+                      std::string name = "tensor",
+                      std::string tag = "",
+                      Map<std::string, ObjectRef> attrs = {}) {
+  FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1]); };
+  return compute(shape, fc, name, tag, attrs);
+}
+inline Tensor compute(Array<PrimExpr> shape,
+                      std::function<PrimExpr(Var, Var, Var)> f,
+                      std::string name = "tensor",
+                      std::string tag = "",
+                      Map<std::string, ObjectRef> attrs = {}) {
+  FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1], i[2]); };
+  return compute(shape, fc, name, tag, attrs);
+}
+inline Tensor compute(Array<PrimExpr> shape,
+                      std::function<PrimExpr(Var, Var, Var, Var)> f,
+                      std::string name = "tensor",
+                      std::string tag = "",
+                      Map<std::string, ObjectRef> attrs = {}) {
+  FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1], i[2], i[3]); };
+  return compute(shape, fc, name, tag, attrs);
+}
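
These overloads simply let the lambda arity match the tensor rank; e.g. a 2-D elementwise add (with `A`, `B`, `m`, and `n` assumed defined as in the sketch above):

    Tensor C = compute({m, n},
                       [&](Var i, Var j) { return A(i, j) + B(i, j); }, "C");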
+
+// inline function.
+inline const OperationNode* Operation::operator->() const {
+  return static_cast<const OperationNode*>(get());
+}
+}  // namespace top
+}  // namespace tvm
+#endif  // TVM_TOP_OPERATION_H_
diff --git a/include/tvm/top/schedule.h b/include/tvm/top/schedule.h
new file mode 100644 (file)
index 0000000..2adaa13
--- /dev/null
@@ -0,0 +1,772 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/top/schedule.h
+ * \brief Define a schedule.
+ */
+// Acknowledgement: Many schedule primitives originate from Halide and Loopy.
+#ifndef TVM_TOP_SCHEDULE_H_
+#define TVM_TOP_SCHEDULE_H_
+
+#include <tvm/expr.h>
+#include <tvm/top/tensor.h>
+#include <tvm/top/tensor_intrin.h>
+
+
+#include <string>
+#include <unordered_map>
+
+
+namespace tvm {
+namespace top {
+// Node container for Stage
+class StageNode;
+// Node container for Schedule
+class ScheduleNode;
+// Node container for IterVarRelation
+class IterVarRelationNode;
+// Attribute of itervar.
+class IterVarAttrNode;
+
+/*! \brief the attachment type */
+enum AttachType : int {
+  kGroupRoot = 1,
+  kInline = 2,
+  kInlinedAlready = 3,
+  kScope = 4,
+  kScanUpdate = 5
+};
+
+/*! \brief Stage, contains scheduling for a stage of computation. */
+class Stage : public ObjectRef {
+ public:
+  Stage() {}
+  explicit Stage(ObjectPtr<Object> n) : ObjectRef(n) {}
+  /*!
+   * \brief create a new stage for the given op.
+   * \param op The operation scheduled by this stage.
+   */
+  explicit Stage(Operation op);
+  /*!
+   * \brief access the internal node container
+   * \return the pointer to the internal node container
+   */
+  inline const StageNode* operator->() const;
+  /*!
+   * \brief access the internal node container
+   * \return the pointer to the internal node container
+   */
+  inline StageNode* operator->();
+  /*!
+   * \brief set the memory scope of the stage
+   * \param scope The memory scope.
+   */
+  TVM_DLL Stage& set_scope(std::string scope);  // NOLINT(*)
+  /*!
+   * \brief specify that this stage is computed at the parent stage's scope.
+   * \param parent The parent stage.
+   * \param scope The iteration point at which to carry out the computation.
+   * \return reference to self.
+   */
+  TVM_DLL Stage& compute_at(Stage parent, IterVar scope);   // NOLINT(*)
+  /*!
+   * \brief Compute the function inline.
+   * \return reference to self.
+   */
+  TVM_DLL Stage& compute_inline();   // NOLINT(*)
+  /*!
+   * \brief Compute the function at group root.
+   * \return reference to self.
+   */
+  TVM_DLL Stage& compute_root();  // NOLINT(*)
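
A sketch of how the three placements differ for a producer B consumed by C (illustrative; `s` is a Schedule and `s[C]` looks up C's stage):

    IterVar xo, xi;
    s[C].split(C->op.as<ComputeOpNode>()->axis[0], 32, &xo, &xi);
    s[B].compute_at(s[C], xo);  // materialize B inside C's outer loop
    // s[B].compute_inline();   // alternatively, substitute B's expression into C
    // s[B].compute_root();     // alternatively, compute all of B before C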
+  /*!
+   * \brief Bind the IterVar to thread index.
+   *
+   * \param ivar The IterVar to be bound.
+   * \param thread_ivar The thread axis to be bound.
+   * \return reference to self.
+   */
+  TVM_DLL Stage& bind(IterVar ivar, IterVar thread_ivar);
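
A GPU-flavored sketch of bind (this assumes the thread_axis helper declared in tvm/expr.h; the other identifiers are illustrative):

    IterVar bx, tx;
    s[C].split(C->op.as<ComputeOpNode>()->axis[0], 64, &bx, &tx);
    s[C].bind(bx, thread_axis(Range(), "blockIdx.x"));   // blocks over the outer loop
    s[C].bind(tx, thread_axis(Range(), "threadIdx.x"));  // threads over the inner loop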
+  /*!
+   * \brief Set the predicate to determine whether a store to the array should be performed.
+   *  Use this when there are multiple threads performing the same store and we only
+   *  need one of them to do the store.
+   *
+   * \note This is a dangerous scheduling primitive that can change the behavior of the program.
+   *    Only use it when we are certain that there are duplicated stores.
+   * \param predicate The condition to be checked.
+   * \return reference to self.
+   */
+  TVM_DLL Stage& set_store_predicate(PrimExpr predicate);
+  /*!
+   * \brief Specify environment threads that are launched around the group's scope.
+   *  This can only be used in a group stage.
+   * \param threads The threads to be launched around the scope.
+   * \note Each thread can only appear in one env_threads.
+   *    This is a beta feature.
+   * \return reference to self.
+   */
+  TVM_DLL Stage& env_threads(Array<IterVar> threads);
+  /*!
+   * \brief Split the parent domain by factor, generating an outer and an inner iteration.
+   * \param parent The parent iteration domain.
+   * \param factor The split factor of the loop.
+   * \param p_outer The result outer domain
+   * \param p_inner The result inner domain.
+   * \return reference to self.
+   */
+  TVM_DLL Stage& split(IterVar parent, PrimExpr factor, IterVar* p_outer, IterVar* p_inner);  // NOLINT(*)
+  /*!
+   * \brief Split the iteration with the given number of parts.
+   *
+   * \param parent The parent domain.
+   * \param nparts The number of parts in the outer domain.
+   * \param p_outer The result outer domain.
+   * \param p_inner The result inner domain.
+   * \return reference to self.
+   */
+  TVM_DLL Stage& split_by_nparts(IterVar parent, PrimExpr nparts, IterVar* p_outer, IterVar* p_inner);   // NOLINT(*)
+  /*!
+   * \brief Fuse the outer and inner domains into the target domain.
+   * \param outer The outer domain to be fused.
+   * \param inner The inner domain to be fused.
+   * \param p_target The result target domain.
+   * \return reference to self.
+   */
+  TVM_DLL Stage& fuse(IterVar outer, IterVar inner, IterVar* p_target);  // NOLINT(*)
+  /*!
+   * \brief Fuse all the axes together into a single axis.
+   *
+   * \param axes All the axes to be fused.
+   * \param p_target The result target domain.
+   *
+   * \note axes can be an empty array;
+   *       in that case, a singleton IterVar is created and
+   *       inserted as the outermost loop.
+   *       Fusing an empty array is used to support zero-dimensional tensors.
+   *
+   * \return reference to self.
+   */
+  TVM_DLL Stage& fuse(const Array<IterVar>& axes, IterVar* p_target);  // NOLINT(*)
+  /*!
+   * \brief Reorder the iteration
+   * \param order The order of iteration variable.
+   * \return reference to self.
+   */
+  TVM_DLL Stage& reorder(const Array<IterVar>& order);   // NOLINT(*)
+  /*!
+   * \brief Perform tiling on two dimensions.
+   *  The final loop order, from outermost to innermost, is
+   *  [x_outer, y_outer, x_inner, y_inner].
+   *
+   * \param x_parent The original x dimension
+   * \param y_parent The original y dimension
+   * \param x_factor The stride factor on x axis
+   * \param y_factor The stride factor on y axis
+   * \param p_x_outer Outer axis of x dimension
+   * \param p_y_outer Outer axis of y dimension
+   * \param p_x_inner Inner axis of x dimension
+   * \param p_y_inner Inner axis of y dimension
+   * \return reference to self.
+   */
+  TVM_DLL Stage& tile(IterVar x_parent, IterVar y_parent,   // NOLINT(*)
+                     PrimExpr x_factor, PrimExpr y_factor,
+                     IterVar* p_x_outer, IterVar* p_y_outer,
+                     IterVar* p_x_inner, IterVar* p_y_inner);
+  /*!
+   * \brief Vectorize iteration.
+   * \param var The axis to be vectorized.
+   * \return reference to self.
+   */
+  TVM_DLL Stage& vectorize(IterVar var);   // NOLINT(*)
+  /*!
+   * \brief Replace computation of the current stage by tensor intrinsic f.
+   * \param var The axis that marks the beginning of tensorization.
+   *  Every operation inside the axis (including the axis itself) is tensorized.
+   * \param f The Tensor compute intrinsics.
+   * \return reference to self.
+   */
+  TVM_DLL Stage& tensorize(IterVar var, TensorIntrin f);   // NOLINT(*)
+  /*!
+   * \brief Unroll iteration.
+   * \param var The axis to be unrolled.
+   * \return reference to self.
+   */
+  TVM_DLL Stage& unroll(IterVar var);   // NOLINT(*)
+  /*!
+   * \brief Parallelize iteration.
+   * \param var The axis to be parallelized.
+   * \return reference to self.
+   */
+  TVM_DLL Stage& parallel(IterVar var);   // NOLINT(*)
+  /*!
+   * \brief Annotate the iteration with pragma
+   *
+   * \param var The axis to be annotated.
+   * \param pragma_type The pragma type.
+   * \param pragma_value The pragma value.
+   *
+   * \return reference to self.
+   */
+  TVM_DLL Stage& pragma(IterVar var,
+                       const std::string& pragma_type,
+                       const PrimExpr& pragma_value = PrimExpr());   // NOLINT(*)
+  /*!
+   * \brief Fetch data in advance.
+   * \param domain The tensor to be prefetched.
+   * \param var The iteration point at which to apply prefetching.
+   * \param offset The number of iterations to be fetched in advance.
+   * \return reference to self
+   */
+  TVM_DLL Stage& prefetch(const Tensor &domain, IterVar var, PrimExpr offset); //NOLINT(*)
+  /*!
+   * \brief Set alignment requirement for a specific dimension,
+   *  such that stride[axis] == k * factor + offset for some k.
+   *
+   * \param axis The dimension to be specified for alignment.
+   * \param factor The factor multiple of alignment.
+   * \param offset The required offset factor.
+   * \return reference to self
+   */
+  TVM_DLL Stage& storage_align(IterVar axis, int factor, int offset); //NOLINT(*)
+  /*!
+   * \brief Compute current stage with double buffering.
+   * \return reference to self.
+   */
+  TVM_DLL Stage& double_buffer();   // NOLINT(*)
+  /*!
+   * \brief Schedule for OpenGL fragment shader.
+   * \return reference to self.
+   */
+  Stage& opengl(); // NOLINT(*)
+  /*!
+   * \brief whether the stage has been scheduled.
+   * \return whether the stage has been scheduled.
+   */
+  bool is_scheduled() const;
+  /*!
+   * \brief Get the attachment spec of the current stage.
+   *  If the stage computes at group root, this function
+   *  will traverse the group to get the
+   *  final attach spec from the group.
+   * \return A stage representing the attach spec of the group.
+   */
+  Stage GetAttachSpec() const;
+  // declare container type
+  using ContainerType = StageNode;
+};
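+
+// A minimal usage sketch of the primitives above (illustrative only; assumes
+// the placeholder/compute helpers from tvm/top/operation.h and a size
+// variable n defined elsewhere):
+//
+//   Tensor A = placeholder({n}, DataType::Float(32), "A");
+//   Tensor C = compute({n}, [&](Var i) { return A(i) + 1; }, "C");
+//   Schedule s = create_schedule({C->op});
+//   IterVar xo, xi;
+//   s[C].split(C->op.as<ComputeOpNode>()->axis[0], 32, &xo, &xi);
+//   s[C].parallel(xo);
+//   s[C].vectorize(xi);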
+
+/*!
+ * \brief Global schedule container
+ *  For operations and all the operations they depend on.
+ *  The schedule per Operation is named as stage.
+ */
+class Schedule : public ObjectRef {
+ public:
+  Schedule() {}
+  explicit Schedule(ObjectPtr<Object> n) : ObjectRef(n) {}
+  /*!
+   * \brief Get a copy of current schedule.
+   * \return The copied schedule.
+   */
+  Schedule copy() const;
+  /*!
+   * \brief Get the stage that corresponds to the op.
+   * \param op The operation.
+   * \return The stage corresponding to the op.
+   */
+  TVM_DLL Stage operator[](const Operation& op);
+  /*!
+   * \brief Shorthand for getting the stage of the tensor's operation.
+   * \param tensor The tensor
+   * \return The stage corresponding to the tensor's op
+   */
+  TVM_DLL Stage operator[](const Tensor& tensor) {
+    return this->operator[](tensor->op);
+  }
+  /*!
+   * \brief Create a new stage group for all intermediate
+   *  operations between inputs and outputs.
+   *
+   * \param outputs The output boundary of the group.
+   * \param inputs The input boundary of the group.
+   * \param include_inputs Whether to include inputs if they are reachable from outputs.
+   * \return The new grouped stage.
+   */
+  TVM_DLL Stage create_group(const Array<Tensor>& outputs,
+                     const Array<Tensor>& inputs,
+                     bool include_inputs = false);
+  /*!
+   * \brief Create a cache read of the original tensor for readers.
+   *  This will mutate the body of the readers.
+   *  A new stage will be created for the tensor.
+   * \param tensor The tensor to be cached.
+   * \param scope The scope of the cache.
+   * \param readers The readers to redirect to the tensor.
+   * \return The created tensor.
+   */
+  TVM_DLL Tensor cache_read(const Tensor& tensor,
+                    const std::string& scope,
+                    const Array<Operation>& readers);
+  /*!
+   * \brief Create cache write tensors for the producing tensors.
+   *  The cache tensors will take over the bodies of the original tensor ops.
+   *
+   *  This function can be used to do data layout transformation.
+   *  If there is a split/fuse/reorder on the data parallel axis of a tensor
+   *  before cache_write is called, the intermediate cache stores
+   *  the data in the layout given by the iteration order of the leaf axes.
+   *  The data will be transformed back to the original layout in the original tensor.
+   *  The user can further call compute_inline to inline the original layout and keep
+   *  the data stored in the transformed layout.
+   *
+   * \param tensor The tensors to be produced.
+   * \param scope The scope of the storage.
+   * \return The created tensor.
+   */
+  TVM_DLL Array<Tensor> cache_write(const Array<Tensor>& tensor, const std::string& scope);
+  /*!
+   * \brief Create a cache write tensor for the producing tensor.
+   *  The cache tensor will take over the body of the original tensor op.
+   *
+   *  This function can be used to do data layout transformation.
+   *  If there is a split/fuse/reorder on the data parallel axis of the tensor
+   *  before cache_write is called, the intermediate cache stores
+   *  the data in the layout given by the iteration order of the leaf axes.
+   *  The data will be transformed back to the original layout in the original tensor.
+   *  The user can further call compute_inline to inline the original layout and keep
+   *  the data stored in the transformed layout.
+   *
+   * \param tensor The tensor to be produced.
+   * \param scope The scope of the storage.
+   * \return The created tensor.
+   */
+  TVM_DLL Tensor cache_write(const Tensor& tensor, const std::string& scope);
+  /*!
+   * \brief Factor a reduction axis in the tensor's schedule to be an explicit axis.
+   * This will create a new stage that generates the new tensor with axis
+   * as the first dimension. The tensor's body will be rewritten as a reduction
+   * over the factored tensor.
+   *
+   *  P. Suriana, A. Adams and S. Kamil. Parallel associative reductions in halide. CGO'17
+   *
+   * \param tensor The tensor to be factored.
+   * \param axis The reduction axis in tensor's schedule to be factored.
+   * \param factor_axis The position where the new axis is placed.
+   * \return The created factored tensors.
+   */
+  TVM_DLL Array<Tensor> rfactor(const Tensor& tensor,
+                        const IterVar& axis,
+                        int factor_axis = 0);
+  /*!
+   * \brief Normalize the schedule.
+   *  This is needed before bound inference.
+   *  Inserts the necessary RebaseNodes to make sure all leaf_iter_vars
+   *  are in the form [0, extent).
+   *
+   * \return A normalized schedule, which can be the same as the current one.
+   */
+  Schedule normalize();
+  /*!
+   * \brief access the internal node container
+   * \return the pointer to the internal node container
+   */
+  inline const ScheduleNode* operator->() const;
+  /*!
+   * \brief access the internal node container
+   * \return the pointer to the internal node container
+   */
+  inline ScheduleNode* operator->();
+  // declare container type
+  using ContainerType = ScheduleNode;
+};
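+
+// A sketch of the caching primitives (continuing the example above; the
+// split is redone after cache_write, since caching rewrites C's stage):
+//
+//   Tensor CC = s.cache_write(C, "local");
+//   IterVar yo, yi;
+//   s[C].split(s[C]->leaf_iter_vars[0], 32, &yo, &yi);
+//   s[CC].compute_at(s[C], yo);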
+
+/*!
+ * \brief The schedule relation between IterVars,
+ *  such as Split and Fuse.
+ */
+class IterVarRelation : public ObjectRef {
+ public:
+  IterVarRelation() {}
+  explicit IterVarRelation(ObjectPtr<Object> n) : ObjectRef(n) {}
+  /*!
+   * \brief access the internal node container
+   * \return the pointer to the internal node container
+   */
+  inline const IterVarRelationNode* operator->() const;
+};
+
+/*!
+ * \brief Additional schedulable attributes of an IterVar.
+ */
+class IterVarAttr : public ObjectRef {
+ public:
+  IterVarAttr() {}
+  explicit IterVarAttr(ObjectPtr<Object> n) : ObjectRef(n) {}
+  /*!
+   * \brief access the internal node container
+   * \return the pointer to the internal node container
+   */
+  inline const IterVarAttrNode* operator->() const;
+};
+
+/*!
+ * \brief Represents a stage.
+ *
+ *  The relations form a directed acyclic hypergraph in a bipartite manner,
+ *  where each node is represented by an IterVar,
+ *  and each hyper-edge is represented by an IterVarRelation.
+ *  The relations connect the IterVars in the graph.
+ *
+ *  Besides the typical stages that correspond to operations,
+ *  there are also group stages, which group stages together.
+ *  Each stage's group (given by the group field) represents a constraint:
+ *  the stage can only be attached to stages within the group.
+ *
+ *  A group stage node can be attached to IterVars as in a normal stage.
+ */
+class StageNode : public Object {
+ public:
+  /*!
+   * \brief The operation of the stage; it can be different from the original op.
+   *  If it is null, then this stage is a group stage.
+   */
+  Operation op;
+  /*!
+   * \brief The original operator.
+   *  The op field can change during scheduling to alter the dataflow,
+   *  while origin_op remains fixed.
+   */
+  Operation origin_op;
+  /*! \brief All the iteration variables in the stage. */
+  Array<IterVar> all_iter_vars;
+  /*! \brief The current active leaf iter vars in the stage. */
+  Array<IterVar> leaf_iter_vars;
+  /*!
+   * \brief Specify threads to be launched at the stage.
+   *  This is only valid for composite ops such as Scan.
+   * \note Experimental primitive: used for thread persistence.
+   */
+  Array<IterVar> env_threads;
+  /*!
+   * \brief The predicate under which a store can happen.
+   *  Use this when there can be duplicated threads doing the same store.
+   * \note Experimental primitive: used by cross thread-reduction.
+   */
+  PrimExpr store_predicate;
+  /*! \brief The relations between IterVars. */
+  Array<IterVarRelation> relations;
+  /*! \brief Additional attributes of the iteration variables. */
+  Map<IterVar, IterVarAttr> iter_var_attrs;
+  /*! \brief The attachment type of the schedule */
+  AttachType attach_type{kGroupRoot};
+  /*! \brief The attach point of this schedule. */
+  IterVar attach_ivar;
+  /*! \brief The stage this node attaches to */
+  Stage attach_stage;
+  /*! \brief The thread storage scope level of the stage */
+  std::string scope;
+  /*! \brief Whether this is an output stage */
+  bool is_output{false};
+  /*! \brief Whether this is an OpenGL stage */
+  bool is_opengl{false};
+  /*! \brief Whether apply double buffer optimization to this stage */
+  bool double_buffer{false};
+  /*!
+   * \brief The parent group of the current stage.
+   *  The stage cannot be assigned to stages outside the group.
+   */
+  Stage group;
+  /*! \brief Number of direct child stages; only used for group stages. */
+  int num_child_stages{0};
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("op", &op);
+    v->Visit("origin_op", &origin_op);
+    v->Visit("all_iter_vars", &all_iter_vars);
+    v->Visit("leaf_iter_vars", &leaf_iter_vars);
+    v->Visit("env_threads", &env_threads);
+    v->Visit("relations", &relations);
+    v->Visit("iter_var_attrs", &iter_var_attrs);
+    v->Visit("attach_type", &attach_type);
+    v->Visit("attach_ivar", &attach_ivar);
+    v->Visit("attach_stage", &attach_stage);
+    v->Visit("scope", &scope);
+    v->Visit("is_output", &is_output);
+    v->Visit("is_opengl", &is_opengl);
+    v->Visit("double_buffer", &double_buffer);
+    v->Visit("group", &group);
+    v->Visit("num_child_stages", &num_child_stages);
+  }
+
+  static constexpr const char* _type_key = "Stage";
+  TVM_DECLARE_FINAL_OBJECT_INFO(StageNode, Object);
+};
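+
+// For reference, a sketch of how the attach fields interact: a call such as
+// s[B].compute_at(s[C], xo) sets attach_type to kScope, attach_ivar to xo,
+// and attach_stage to s[C] on B's StageNode.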
+
+/*! \brief node container for schedule */
+class ScheduleNode : public Object {
+ public:
+  /*! \brief The output operations in original data flow graph */
+  Array<Operation> outputs;
+  /*!
+   * \brief list of all stages for ops.
+   * The stages are sorted in dependency order.
+   */
+  Array<Stage> stages;
+  /*!
+   * \brief List of all stage groups.
+   */
+  Array<Stage> groups;
+  /*! \brief map of original operation to the stages */
+  Map<Operation, Stage> stage_map;
+  /*!
+   * \brief Internal stage map to map internal ops to stages.
+   *  This is created on demand and can be invalidated.
+   */
+  std::unordered_map<const Object*, Stage> op2stage_cache_;
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("outputs", &outputs);
+    v->Visit("stages", &stages);
+    v->Visit("groups", &groups);
+    v->Visit("stage_map", &stage_map);
+  }
+
+  /*! \brief Initialize temp cache. */
+  void InitCache();
+  /*! \brief Invalidate temp cache. */
+  void InvalidateCache();
+
+  /*!
+   * \brief Check if the schedule contains an Operation.
+   * \param op The candidate Operation.
+   * \return true if the schedule has the Operation. Otherwise, false.
+   */
+  TVM_DLL bool Contain(const Operation& op) const;
+
+  /*!
+   * \brief Check if the schedule contains a Tensor.
+   * \param tensor The candidate tensor.
+   * \return true if the schedule has the tensor. Otherwise, false.
+   */
+  TVM_DLL bool Contain(const Tensor& tensor) const {
+    return Contain(tensor->op);
+  }
+
+  /*!
+   * \brief Create a schedule for an array of ops (and their dependencies).
+   * \param ops The ops to be scheduled.
+   * \return The created schedule.
+   */
+  TVM_DLL static Schedule make(Array<Operation> ops);
+
+  static constexpr const char* _type_key = "Schedule";
+  TVM_DECLARE_FINAL_OBJECT_INFO(ScheduleNode, Object);
+};
+
+/*!
+ * \brief Create a schedule for an array of ops (and their dependencies).
+ * \param ops The ops to be scheduled.
+ * \return The created schedule.
+ */
+inline Schedule create_schedule(Array<Operation> ops) {
+  return ScheduleNode::make(ops);
+}
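+
+// A sketch of rfactor on a reduction (assumes a reduction Tensor R with one
+// reduce axis and a Schedule s created as above):
+//
+//   IterVar ko, ki;
+//   s[R].split(R->op.as<ComputeOpNode>()->reduce_axis[0], 16, &ko, &ki);
+//   Array<Tensor> RF = s.rfactor(R, ki);
+//   // ki becomes an explicit axis of RF; R is rewritten to reduce over RF.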
+
+/*! \brief node container for IterVar attr */
+class IterVarAttrNode : public Object {
+ public:
+  /*! \brief The iteration type. */
+  IterVarType iter_type{kDataPar};
+  /*! \brief The thread this IterVar binds to; can be null. */
+  IterVar bind_thread;
+  /*! \brief List of tensors to be prefetched in this loop. */
+  Array<Tensor> prefetch_data;
+  /*! \brief The offset used in each prefetch */
+  Array<PrimExpr> prefetch_offset;
+  /*!
+   * \brief Tensor intrinsic used in tensorization,
+   *   when the axis is marked as Tensorized
+   */
+  TensorIntrin tensor_intrin;
+  /*! \brief Alignment factor of buffer dimension */
+  int dim_align_factor{0};
+  /*! \brief Alignment offset of buffer dimension */
+  int dim_align_offset{0};
+  /*!
+   * \brief Additional pragma keys, array of StringImm
+   */
+  Array<PrimExpr> pragma_keys;
+  /*!
+   * \brief Additional values of pragma, if any
+   */
+  Array<PrimExpr> pragma_values;
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("iter_type", &iter_type);
+    v->Visit("bind_thread", &bind_thread);
+    v->Visit("prefetch_data", &prefetch_data);
+    v->Visit("prefetch_offset", &prefetch_offset);
+    v->Visit("tensor_intrin", &tensor_intrin);
+    v->Visit("dim_align_factor", &dim_align_factor);
+    v->Visit("dim_align_offset", &dim_align_offset);
+    v->Visit("pragma_keys", &pragma_keys);
+    v->Visit("pragma_values", &pragma_values);
+  }
+
+  static constexpr const char* _type_key = "IterVarAttr";
+  TVM_DECLARE_FINAL_OBJECT_INFO(IterVarAttrNode, Object);
+};
+
+/*! \brief Base node of relations between iteration variables. */
+class IterVarRelationNode : public Object {
+ public:
+  static constexpr const char* _type_key = "IterVarRelation";
+  TVM_DECLARE_BASE_OBJECT_INFO(IterVarRelationNode, Object);
+};
+
+/*!
+ * \brief Split the parent domain into the product of
+ *  the outer and inner domains.
+ */
+class SplitNode : public IterVarRelationNode {
+ public:
+  /*! \brief The parent domain */
+  IterVar parent;
+  /*! \brief The outer domain */
+  IterVar outer;
+  /*! \brief The inner domain */
+  IterVar inner;
+  /*! \brief The split factor */
+  PrimExpr factor;
+  /*! \brief Number of parts; only one of factor or nparts can be given. */
+  PrimExpr nparts;
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("parent", &parent);
+    v->Visit("outer", &outer);
+    v->Visit("inner", &inner);
+    v->Visit("factor", &factor);
+    v->Visit("nparts", &nparts);
+  }
+
+  static IterVarRelation make(IterVar parent,
+                              IterVar outer,
+                              IterVar inner,
+                              PrimExpr factor,
+                              PrimExpr nparts);
+
+  static constexpr const char* _type_key = "Split";
+  TVM_DECLARE_FINAL_OBJECT_INFO(SplitNode, IterVarRelationNode);
+};
+
+/*!
+ * \brief Fuse two domains into one domain.
+ */
+class FuseNode : public IterVarRelationNode {
+ public:
+  /*! \brief The outer domain */
+  IterVar outer;
+  /*! \brief The inner domain */
+  IterVar inner;
+  /*! \brief The target domain */
+  IterVar fused;
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("outer", &outer);
+    v->Visit("inner", &inner);
+    v->Visit("fused", &fused);
+  }
+
+  static IterVarRelation make(
+      IterVar outer, IterVar inner, IterVar fused);
+
+  static constexpr const char* _type_key = "Fuse";
+  TVM_DECLARE_FINAL_OBJECT_INFO(FuseNode, IterVarRelationNode);
+};
+
+/*!
+ * \brief Rebase the iteration to make the min 0.
+ *  This is useful for normalizing the Schedule
+ *  so that every leaf variable's min is 0.
+ */
+class RebaseNode : public IterVarRelationNode {
+ public:
+  /*! \brief The parent domain */
+  IterVar parent;
+  /*! \brief The inner domain */
+  IterVar rebased;
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("parent", &parent);
+    v->Visit("rebased", &rebased);
+  }
+
+  static IterVarRelation make(IterVar parent, IterVar rebased);
+
+  static constexpr const char* _type_key = "Rebase";
+  TVM_DECLARE_FINAL_OBJECT_INFO(RebaseNode, IterVarRelationNode);
+};
+
+
+/*!
+ * \brief Singleton iterator [0, 1)
+ */
+class SingletonNode : public IterVarRelationNode {
+ public:
+  /*! \brief The singleton iterator */
+  IterVar iter;
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("iter", &iter);
+  }
+
+  static IterVarRelation make(IterVar iter);
+
+  static constexpr const char* _type_key = "Singleton";
+  TVM_DECLARE_FINAL_OBJECT_INFO(SingletonNode, IterVarRelationNode);
+};
+
+
+// implementations
+inline const StageNode* Stage::operator->() const {
+  return static_cast<const StageNode*>(get());
+}
+inline StageNode* Stage::operator->() {
+  return static_cast<StageNode*>(get_mutable());
+}
+
+inline const ScheduleNode* Schedule::operator->() const {
+  return static_cast<const ScheduleNode*>(get());
+}
+inline ScheduleNode* Schedule::operator->() {
+  return static_cast<ScheduleNode*>(get_mutable());
+}
+
+inline const IterVarRelationNode* IterVarRelation::operator->() const {
+  return static_cast<const IterVarRelationNode*>(get());
+}
+
+inline const IterVarAttrNode* IterVarAttr::operator->() const {
+  return static_cast<const IterVarAttrNode*>(get());
+}
+}  // namespace top
+}  // namespace tvm
+#endif  // TVM_TOP_SCHEDULE_H_
diff --git a/include/tvm/top/schedule_pass.h b/include/tvm/top/schedule_pass.h
new file mode 100644 (file)
index 0000000..eacc9cd
--- /dev/null
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/top/schedule_pass.h
+ * \brief Collection of Schedule pass functions.
+ *
+ *  These passes work on the schedule hypergraph
+ *  and infer information such as bounds, check conditions,
+ *  and read/write dependencies between the IterVars.
+ */
+#ifndef TVM_TOP_SCHEDULE_PASS_H_
+#define TVM_TOP_SCHEDULE_PASS_H_
+
+#include <tvm/top/schedule.h>
+
+namespace tvm {
+namespace top {
+
+/*!
+ * \brief Infer the bounds of all iteration variables related to the schedule.
+ *
+ * \param sch The root schedule to infer all the bounds for.
+ * \return The resulting bounds of the iteration variables.
+ */
+Map<IterVar, Range> InferBound(const Schedule& sch);
+
+/*!
+ * \brief Schedule s' dependent operations.
+ *
+ * \param s The schedule to be realized.
+ * \param dom_map The domain of each iter var.
+ * \param debug_keep_trivial_loop Whether to keep trivial loops with extent of 1 during lowering.
+ *                                This is a debug feature for dataflow/axis analysis.
+ *                                Note: if this is true, the lowered IR may be incorrect,
+ *                                because we will also delete the init part of reductions.
+ * \return The resulting Stmt.
+ */
+Stmt ScheduleOps(Schedule s, Map<IterVar, Range> dom_map, bool debug_keep_trivial_loop);
+
+/*!
+ * \brief Automatically inline element-wise operations.
+ *
+ * \param sch The schedule to be inlined.
+ */
+void AutoInlineElemWise(Schedule sch);
+
+/*!
+ * \brief Automatically inline operations with injective writes
+ *   (i.e. writes without reduction or sequential loops). Note
+ *   that in this case, guarantees about contiguity, transpose, stride,
+ *   alignment and memory footprint in general do not hold.
+ *
+ * \param sch The schedule to be inlined.
+ */
+TVM_DLL void AutoInlineInjective(Schedule sch);
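+
+// A sketch of the typical flow over these passes (mirroring BuildStmt in
+// src/codegen/build_module.cc; sch is a Schedule):
+//
+//   sch = sch.normalize();
+//   Map<IterVar, Range> bounds = InferBound(sch);
+//   Stmt stmt = ScheduleOps(sch, bounds, /*debug_keep_trivial_loop=*/false);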
+
+}  // namespace top
+}  // namespace tvm
+#endif  // TVM_TOP_SCHEDULE_PASS_H_
diff --git a/include/tvm/top/tensor.h b/include/tvm/top/tensor.h
new file mode 100644 (file)
index 0000000..bdfbbeb
--- /dev/null
@@ -0,0 +1,271 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/top/tensor.h
+ * \brief Dataflow tensor object
+ */
+#ifndef TVM_TOP_TENSOR_H_
+#define TVM_TOP_TENSOR_H_
+
+#include <tvm/node/container.h>
+#include <tvm/arith/bound.h>
+#include <tvm/expr.h>
+#include <tvm/expr_operator.h>
+
+#include <string>
+#include <vector>
+#include <utility>
+#include <type_traits>
+
+namespace tvm {
+namespace top {
+
+// Internal node container of Tensor
+class TensorNode;
+// internal node container for Operation
+class OperationNode;
+
+/*!
+ * \brief Tensor structure representing a possible input,
+ *  or intermediate computation result.
+ */
+class Tensor : public ObjectRef {
+ public:
+  /*! \brief default constructor, used internally */
+  Tensor() {}
+  explicit Tensor(ObjectPtr<Object> n) : ObjectRef(n) {}
+  /*!
+   * \brief access the internal node container
+   * \return the pointer to the internal node container
+   */
+  inline const TensorNode* operator->() const;
+  /*!
+   * \brief check if two tensors are equal to each other.
+   * \param other The tensor to be checked.
+   * \return whether the two tensors are equal to each other.
+   */
+  inline bool operator==(const Tensor& other) const;
+  /*!
+   * \brief check if two tensors are different.
+   * \param other The tensor to be checked.
+   * \return whether the two tensors are different.
+   */
+  inline bool operator!=(const Tensor& other) const;
+  /*! \return The dimension of the tensor */
+  inline size_t ndim() const;
+  /*!
+   * \brief Take elements from the tensor
+   * \param args The indices
+   * \return the result expression representing tensor read.
+   */
+  template<typename... Args>
+  inline PrimExpr operator()(Args&& ...args) const {
+    Array<PrimExpr> indices{std::forward<Args>(args)...};
+    return operator()(indices);
+  }
+  /*!
+   * \brief Take elements from the tensor
+   * \param indices the indices.
+   * \return the result expression representing tensor read.
+   */
+  TVM_DLL PrimExpr operator()(Array<PrimExpr> indices) const;
+  /*!
+   * \brief Take elements from the tensor
+   * \param indices the indices.
+   * \return the result expression representing tensor read.
+   */
+  TVM_DLL PrimExpr operator()(Array<Var> indices) const;
+  /*!
+   * \brief data structure to represent a slice that fixes the first k coordinates.
+   *  This is used to enable the syntax sugar of Tensor[x][y][z] to get the element.
+   */
+  class Slice {
+   public:
+    // construct via tensor and indices
+    Slice(const Tensor& tensor, std::vector<PrimExpr> indices)
+        : tensor_(tensor), indices_(indices) {}
+    /*!
+     * \brief get i-th slice from the current slice.
+     * \param i the index of the coordinate
+     * \return the subsequent slice.
+     */
+    inline Slice operator[](PrimExpr i) {
+      std::vector<PrimExpr> other = indices_;
+      other.emplace_back(i);
+      return Slice(tensor_, other);
+    }
+    /*!
+     * \brief Convert slice to expression.
+     *  This is only valid when all the coordinates are fully specified.
+     * \return the corresponding expression of this slice.
+     */
+    inline operator PrimExpr() const {
+      return tensor_(indices_);
+    }
+
+   private:
+    const Tensor& tensor_;
+    std::vector<PrimExpr> indices_;
+  };
+  /*!
+   * \brief get i-th slice from the current Tensor.
+   * \param i the index of the coordinate
+   * \return the subsequent slice.
+   */
+  inline Slice operator[](PrimExpr i) const {
+    return Slice(*this, {i});
+  }
+  /*! \brief specify container node */
+  using ContainerType = TensorNode;
+};
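+
+// A sketch of the two element-access styles above (assumes a 2-D Tensor A
+// and index expressions i and j):
+//
+//   PrimExpr v1 = A(i, j);   // variadic operator()
+//   PrimExpr v2 = A[i][j];   // Slice chain, implicitly converted to PrimExpr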
+
+/*! \brief Operation that produces tensors */
+class Operation : public ir::FunctionRef {
+ public:
+  /*! \brief default constructor  */
+  Operation() {}
+  explicit Operation(ObjectPtr<Object> n) : FunctionRef(n) {}
+  /*!
+   * \brief access the internal node container
+   * \return the pointer to the internal node container
+   */
+  inline const OperationNode* operator->() const;
+  /*!
+   * \brief get the i-th output of the operation.
+   * \param i the output index.
+   * \return The i-th output.
+   */
+  TVM_DLL Tensor output(size_t i) const;
+  /*! \brief specify container node */
+  using ContainerType = OperationNode;
+};
+
+/*! \brief Node to represent a tensor */
+class TensorNode : public Object {
+ public:
+  /*! \brief The shape of the tensor */
+  Array<PrimExpr> shape;
+  /*! \brief data type in the content of the tensor */
+  DataType dtype;
+  /*! \brief the source operation, can be None */
+  Operation op;
+  /*! \brief the output index from source operation */
+  int value_index{0};
+  /*! \brief constructor */
+  TensorNode() {}
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("shape", &shape);
+    v->Visit("dtype", &dtype);
+    v->Visit("op", &op);
+    v->Visit("value_index", &value_index);
+  }
+  TVM_DLL static Tensor make(Array<PrimExpr> shape,
+                             DataType dtype,
+                             Operation op,
+                             int value_index);
+
+  static constexpr const char* _type_key = "Tensor";
+  TVM_DECLARE_FINAL_OBJECT_INFO(TensorNode, Object);
+};
+
+
+// Implementations of inline functions
+inline const TensorNode* Tensor::operator->() const {
+  return static_cast<const TensorNode*>(get());
+}
+
+inline size_t Tensor::ndim() const {
+  return (*this)->shape.size();
+}
+
+inline bool Tensor::operator==(const Tensor& other) const {
+  if (get() == other.get()) return true;
+  if (get() == nullptr || other.get() == nullptr) return false;
+  if ((*this)->op.defined() || other->op.defined()) {
+    return (*this)->op == other->op &&
+        (*this)->value_index == other->value_index;
+  } else {
+    return false;
+  }
+}
+
+inline bool Tensor::operator!=(const Tensor& other) const {
+  return !(*this == other);
+}
+
+// macros to turn every operation on a Slice into an expression
+#define DEFINE_OVERLOAD_SLICE_UNARY_OP(Op)                              \
+  inline PrimExpr operator Op (const Tensor::Slice& a) {           \
+    return Op a.operator PrimExpr() ;                              \
+  }                                                                     \
+
+#define DEFINE_OVERLOAD_SLICE_BINARY_OP(Op)                             \
+  template<typename T>                                                  \
+  inline PrimExpr operator Op (const Tensor::Slice& a, const T& b) { \
+    return a.operator PrimExpr() Op b;                             \
+  }                                                                     \
+  template<typename T>                                                  \
+  inline PrimExpr operator Op (const T& a, const Tensor::Slice& b) {  \
+    return a Op b.operator PrimExpr();                                \
+  }                                                                        \
+  inline PrimExpr operator Op (const Tensor::Slice& a, const Tensor::Slice& b) { \
+    return a.operator PrimExpr() Op b.operator PrimExpr();                  \
+  }
+
+DEFINE_OVERLOAD_SLICE_UNARY_OP(!);
+DEFINE_OVERLOAD_SLICE_UNARY_OP(-);
+DEFINE_OVERLOAD_SLICE_BINARY_OP(+);
+DEFINE_OVERLOAD_SLICE_BINARY_OP(-);
+DEFINE_OVERLOAD_SLICE_BINARY_OP(*);
+DEFINE_OVERLOAD_SLICE_BINARY_OP(==);
+DEFINE_OVERLOAD_SLICE_BINARY_OP(<=);
+DEFINE_OVERLOAD_SLICE_BINARY_OP(>=);
+DEFINE_OVERLOAD_SLICE_BINARY_OP(!=);
+DEFINE_OVERLOAD_SLICE_BINARY_OP(&&);
+DEFINE_OVERLOAD_SLICE_BINARY_OP(||);
+DEFINE_OVERLOAD_SLICE_BINARY_OP(>>);
+DEFINE_OVERLOAD_SLICE_BINARY_OP(<<);
+DEFINE_OVERLOAD_SLICE_BINARY_OP(>);  // NOLINT(*)
+DEFINE_OVERLOAD_SLICE_BINARY_OP(<);  // NOLINT(*)
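+
+// These overloads let Slice values participate directly in expressions,
+// for example (a sketch, with 1-D Tensors A and B and an index i):
+//
+//   PrimExpr sum = A[i] + B[i];  // both slices convert to PrimExpr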
+
+}  // namespace top
+}  // namespace tvm
+
+namespace std {
+template <>
+struct hash<::tvm::top::Operation> : public ::tvm::ObjectHash {
+};
+
+template <>
+struct hash<::tvm::top::Tensor> {
+  std::size_t operator()(const ::tvm::top::Tensor& k) const {
+    ::tvm::ObjectHash hasher;
+    if (k.defined() && k->op.defined()) {
+      return hasher(k->op);
+    } else {
+      return hasher(k);
+    }
+  }
+};
+}  // namespace std
+#endif  // TVM_TOP_TENSOR_H_
diff --git a/include/tvm/top/tensor_intrin.h b/include/tvm/top/tensor_intrin.h
new file mode 100644 (file)
index 0000000..99eb885
--- /dev/null
@@ -0,0 +1,181 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/top/tensor_intrin.h
+ * \brief Tensor intrinsic operations.
+ */
+#ifndef TVM_TOP_TENSOR_INTRIN_H_
+#define TVM_TOP_TENSOR_INTRIN_H_
+
+#include <tvm/top/tensor.h>
+#include <tvm/buffer.h>
+
+#include <string>
+
+namespace tvm {
+namespace top {
+
+// Internal node container of tensor intrinsics.
+class TensorIntrinNode;
+
+/*! \brief Tensor intrinsic node. */
+class TensorIntrin : public ObjectRef {
+ public:
+  TensorIntrin() {}
+  explicit TensorIntrin(ObjectPtr<Object> n) : ObjectRef(n) {}
+  /*!
+   * \brief access the internal node container
+   * \return the pointer to the internal node container
+   */
+  inline const TensorIntrinNode* operator->() const;
+
+  /*! \brief specify container node */
+  using ContainerType = TensorIntrinNode;
+};
+
+/*! \brief Node to represent a Tensor intrinsic operator */
+class TensorIntrinNode : public Object {
+ public:
+  /*! \brief The name of the intrinsic */
+  std::string name;
+  /*! \brief The operation this intrinsic is carrying out. */
+  Operation op;
+  /*! \brief List of inputs of the operator, placeholders in post-DFS order. */
+  Array<Tensor> inputs;
+  /*!
+   * \brief Symbolic buffers of each output/input tensor.
+   *  buffers[0:len(inputs)] are buffers of the inputs.
+   *  buffers[len(inputs):] are buffers of each output.
+   *
+   * \note When a field in a Buffer is a Var, it means we can be flexible
+   *  with respect to that field, and the Var can occur in the body.
+   *  When it is a constant, it means we can only take data in that shape.
+   */
+  Array<Buffer> buffers;
+  /*! \brief List of scalar variables, used in the body. These placeholders
+   *  will be bound to expressions passed in when the TensorIntrin is called
+   *  from a TensorComputeOp.
+   */
+  Array<Var> scalar_params;
+  /*! \brief The normal statement to execute the intrinsic */
+  Stmt body;
+  /*!
+   * \brief Special statement for reduction op; can be None.
+   *  Resets the value of the output buffer to the identity value.
+   */
+  Stmt reduce_init;
+  /*!
+   * \brief Special statement for reduction op; can be None.
+   *  Reduce: do a reduction of the current output buffer with the result.
+   */
+  Stmt reduce_update;
+  /*! \brief constructor */
+  TensorIntrinNode() {}
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("name", &name);
+    v->Visit("op", &op);
+    v->Visit("inputs", &inputs);
+    v->Visit("buffers", &buffers);
+    v->Visit("scalar_params", &scalar_params);
+    v->Visit("body", &body);
+    v->Visit("reduce_init", &reduce_init);
+    v->Visit("reduce_update", &reduce_update);
+  }
+
+  TVM_DLL static TensorIntrin make(std::string name,
+                                   Operation op,
+                                   Array<Tensor> inputs,
+                                   Array<Buffer> buffers,
+                                   Array<Var> scalar_params,
+                                   Stmt body,
+                                   Stmt reduce_init,
+                                   Stmt reduce_update);
+
+  static constexpr const char* _type_key = "TensorIntrin";
+  TVM_DECLARE_FINAL_OBJECT_INFO(TensorIntrinNode, Object);
+};
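+
+// A sketch of declaring an intrinsic (assumes op, inputs, and buffers were
+// built with the usual compute/decl_buffer helpers, and body is a Stmt;
+// reduce_init/reduce_update are left as None here):
+//
+//   TensorIntrin intrin = TensorIntrinNode::make(
+//       "dot_update", op, inputs, buffers, {}, body, Stmt(), Stmt());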
+
+inline const TensorIntrinNode* TensorIntrin::operator->() const {
+  return static_cast<const TensorIntrinNode*>(get());
+}
+
+// Internal node container of tensor intrinsic calling.
+class TensorIntrinCallNode;
+
+/*! \brief Tensor intrinsic calling node. */
+class TensorIntrinCall : public ObjectRef {
+ public:
+  TensorIntrinCall() {}
+  explicit TensorIntrinCall(ObjectPtr<Object> n) : ObjectRef(n) {}
+  /*!
+   * \brief access the internal node container
+   * \return the pointer to the internal node container
+   */
+  inline const TensorIntrinCallNode* operator->() const;
+
+  /*! \brief specify container node */
+  using ContainerType = TensorIntrinCallNode;
+};
+
+class TensorIntrinCallNode : public Object {
+ public:
+  /*! \brief the tensor intrinsic */
+  TensorIntrin intrin;
+  /*! \brief input tensors of the intrinsic */
+  Array<Tensor> tensors;
+  /*! \brief regions of input tensors */
+  Array<Region> regions;
+
+  /*!
+   * \brief IterVar on each reduction axis, if the
+   *  intrinsic uses the reduce axes.
+   */
+  Array<IterVar> reduce_axis;
+
+  /*! \brief scalar expression inputs */
+  Array<PrimExpr> scalar_inputs;
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("intrin", &intrin);
+    v->Visit("tensors", &tensors);
+    v->Visit("regions", &regions);
+    v->Visit("reduce_axis", &reduce_axis);
+    v->Visit("scalar_inputs", &scalar_inputs);
+  }
+  static TensorIntrinCall make(TensorIntrin intrin,
+                               Array<Tensor> tensors,
+                               Array<Region> regions,
+                               Array<IterVar> reduce_axis,
+                               Array<PrimExpr> scalar_inputs);
+
+  static constexpr const char* _type_key = "TensorIntrinCall";
+  TVM_DECLARE_FINAL_OBJECT_INFO(TensorIntrinCallNode, Object);
+};
+
+inline const TensorIntrinCallNode* TensorIntrinCall::operator->() const {
+  return static_cast<const TensorIntrinCallNode*>(get());
+}
+
+}  // namespace top
+}  // namespace tvm
+#endif  // TVM_TOP_TENSOR_INTRIN_H_
index b0172b665eb8f6207770b3ce851d1d3b9bdef3ea..2de81416cdb053796d8dcf7faa4e100cc29923e6 100644 (file)
@@ -24,13 +24,12 @@ There can be internal header files within each module that sit in src.
 - support: Internal support utilities.
 - runtime: Minimum runtime related codes.
 - node: base infra for IR/AST nodes that is dialect independent.
-- api: API function registration.
-- lang: The definition of DSL related data structure.
 - arith: Arithmetic expression and set simplification.
-- op: The detail implementations about each operation(compute, scan, placeholder).
-- schedule: The operations on the schedule graph before converting to IR.
+- top: Tensor operation DSL for compute and schedule.
+- relay: Implementation of Relay. The second generation of NNVM, a new IR for deep learning frameworks.
 - pass: The optimization pass on the IR structure.
 - codegen: The code generator.
 - autotvm: The auto-tuning module.
-- relay: Implementation of Relay. The second generation of NNVM, a new IR for deep learning frameworks.
 - contrib: Contrib extension libraries.
+- api: API function registration.
+- lang: The definition of DSL related data structure.
index b4020eb944ea3d14905caec9532b2e39e360ddf7..0c28d08a7f28eba4438ab2615b8b789d333b9c93 100644 (file)
@@ -31,7 +31,7 @@
 #include <tvm/runtime/registry.h>
 #include <tvm/packed_func_ext.h>
 
-#include <tvm/tensor.h>
+#include <tvm/top/tensor.h>
 
 namespace tvm {
 namespace arith {
index 131ac363bb8b283e7e90436b73b592648448781b..4b74d02cdf9e6e90967cde40dbfe0595c8840494 100644 (file)
@@ -23,7 +23,7 @@
  */
 #include <dmlc/memory_io.h>
 #include <tvm/expr.h>
-#include <tvm/tensor.h>
+#include <tvm/top/tensor.h>
 #include <tvm/runtime/registry.h>
 #include <tvm/packed_func_ext.h>
 
index 2c7eb3cc0270d50bc0908a402096dc806059eb76..89c2c53105073847f220ed28e04c7e014473813d 100644 (file)
  */
 #include <tvm/expr.h>
 #include <tvm/ir.h>
-#include <tvm/tensor.h>
-#include <tvm/operation.h>
+#include <tvm/top/tensor.h>
+#include <tvm/top/operation.h>
 #include <tvm/buffer.h>
-#include <tvm/schedule.h>
+#include <tvm/top/schedule.h>
 #include <tvm/runtime/registry.h>
 #include <tvm/packed_func_ext.h>
 
@@ -274,6 +274,7 @@ TVM_REGISTER_GLOBAL("_BijectiveLayoutForwardShape")
 TVM_REGISTER_GLOBAL("_BijectiveLayoutBackwardShape")
 .set_body_method(&BijectiveLayout::BackwardShape);
 
+namespace top {
 TVM_REGISTER_GLOBAL("_Tensor")
 .set_body_typed(TensorNode::make);
 
@@ -441,6 +442,7 @@ TVM_REGISTER_GLOBAL("_ScheduleCacheWrite")
 
 TVM_REGISTER_GLOBAL("_ScheduleRFactor")
 .set_body_method(&Schedule::rfactor);
+}  // namespace top
 
 TVM_REGISTER_GLOBAL("_CommReducerCombine")
 .set_body_method<ir::CommReducer>(&ir::CommReducerNode::operator());
index 639855c2a441c86bf6edad7c2f13567221024ea2..a822cc1402179f544e0cd97e2dd4b22889377d6a 100644 (file)
@@ -96,7 +96,9 @@ TVM_REGISTER_GLOBAL("ir_pass.StorageFlatten")
 
 TVM_REGISTER_GLOBAL("ir_pass.RewriteForTensorCore")
 .set_body_typed
-  ([](const Stmt& stmt, const Schedule& schedule, const Map<Tensor, Buffer>& extern_buffer) {
+  ([](const Stmt& stmt,
+      const top::Schedule& schedule,
+      const Map<top::Tensor, Buffer>& extern_buffer) {
       return RewriteForTensorCore(stmt, schedule, extern_buffer);
   });
 
index a7c27e4c8e474c0bc5459299dd52a05b9f2ff814..7aa305fe0f4ae9713aed1d3c62dff170e9ac2235 100644 (file)
  * \file api_schedule.cc
  */
 #include <tvm/expr.h>
-#include <tvm/tensor.h>
-#include <tvm/schedule.h>
-#include <tvm/schedule_pass.h>
+#include <tvm/top/tensor.h>
+#include <tvm/top/schedule.h>
+#include <tvm/top/schedule_pass.h>
 #include <tvm/runtime/registry.h>
 #include <tvm/packed_func_ext.h>
 
-#include "../schedule/graph.h"
+#include "../top/schedule/graph.h"
 
 namespace tvm {
-namespace schedule {
+namespace top {
 
 TVM_REGISTER_GLOBAL("schedule.AutoInlineElemWise")
 .set_body_typed(AutoInlineElemWise);
@@ -60,5 +60,5 @@ REGISTER_SCHEDULE_PASS(CreateAttachPath);
 REGISTER_SCHEDULE_PASS(ScanGetBody);
 REGISTER_SCHEDULE_PASS(ScanFixPointAnalysis);
 
-}  // namespace schedule
+}  // namespace top
 }  // namespace tvm
index 0bc83ea148a4308d20e5562e6ac48475dba85676..957a034d427f2fe3ae28e5f49cd8191916d46795 100644 (file)
@@ -22,7 +22,7 @@
  * \file api_test.cc
  */
 #include <tvm/expr.h>
-#include <tvm/tensor.h>
+#include <tvm/top/tensor.h>
 #include <tvm/ir/attrs.h>
 #include <tvm/runtime/registry.h>
 #include <tvm/ir/env_func.h>
index 3889cd2cf9182119814b1d8304d3351a82d4e745..6e665c8bb72b53297164f2136fe00af704502166 100644 (file)
@@ -24,7 +24,7 @@
 #include <tvm/expr.h>
 #include <tvm/ir_pass.h>
 #include <tvm/ir_functor_ext.h>
-#include <tvm/tensor.h>
+#include <tvm/top/tensor.h>
 #include <tvm/runtime/registry.h>
 #include <tvm/packed_func_ext.h>
 
@@ -40,7 +40,7 @@ using namespace ir;
 // Find Read region of the tensor in the stmt.
 class FuncTouchedDomain final : public StmtExprVisitor {
  public:
-  FuncTouchedDomain(const Tensor &tensor, bool consider_calls, bool consider_provides)
+  FuncTouchedDomain(const top::Tensor &tensor, bool consider_calls, bool consider_provides)
     : tensor_(tensor), consider_calls_(consider_calls), consider_provides_(consider_provides)  {}
 
   Domain Find(const Stmt& stmt) {
@@ -108,13 +108,16 @@ class FuncTouchedDomain final : public StmtExprVisitor {
     }
   }
 
-  const Tensor &tensor_;
+  const top::Tensor &tensor_;
   bool consider_calls_, consider_provides_;
   std::vector<std::vector<IntSet> > bounds_;
   std::unordered_map<const VarNode*, IntSet> dom_map_;
 };
 
-Domain DomainTouched(Stmt stmt, const Tensor &tensor, bool consider_calls, bool consider_provides) {
+Domain DomainTouched(Stmt stmt,
+                     const top::Tensor &tensor,
+                     bool consider_calls,
+                     bool consider_provides) {
   return FuncTouchedDomain(tensor, consider_calls, consider_provides).Find(stmt);
 }
 
index 771583b66a74b010a2f084fb13b0db72de0e0328..cfb75c4e68efa54d1aed1924f73717eb2c7d47c5 100644 (file)
@@ -23,7 +23,7 @@
  */
 #include <dmlc/thread_local.h>
 #include <tvm/build_module.h>
-#include <tvm/operation.h>
+#include <tvm/top/operation.h>
 #include <tvm/ir_pass.h>
 #include <tvm/codegen.h>
 #include <tvm/runtime/registry.h>
@@ -87,10 +87,10 @@ Buffer BufferWithOffsetAlignment(Array<PrimExpr> shape,
     data_alignment, offset_factor, buffer_type);
 }
 
-void GetBinds(const Array<Tensor>& args,
+void GetBinds(const Array<top::Tensor>& args,
               bool compact,
-              const std::unordered_map<Tensor, Buffer>& binds,
-              Map<Tensor, Buffer>* out_binds,
+              const std::unordered_map<top::Tensor, Buffer>& binds,
+              Map<top::Tensor, Buffer>* out_binds,
               Array<ObjectRef>* out_arg_list,
               const BuildConfig& config) {
   *out_binds = binds;
@@ -117,21 +117,21 @@ void GetBinds(const Array<Tensor>& args,
 * \param config The build configuration.
 * \return The built Stmt.
 */
-Stmt BuildStmt(Schedule sch,
-               const Array<Tensor>& args,
-               const std::unordered_map<Tensor, Buffer>& binds,
+Stmt BuildStmt(top::Schedule sch,
+               const Array<top::Tensor>& args,
+               const std::unordered_map<top::Tensor, Buffer>& binds,
                bool loop_partition,
                Array<ObjectRef> *out_arg_list,
                const BuildConfig& config) {
   sch = sch.normalize();
 
   // Phase 0
-  auto bounds = schedule::InferBound(sch);
-  auto stmt = schedule::ScheduleOps(sch, bounds, false);
+  auto bounds = top::InferBound(sch);
+  auto stmt = top::ScheduleOps(sch, bounds, false);
   stmt = ir::InjectPrefetch(stmt);
 
   bool compact = ir::VerifyCompactBuffer(stmt);
-  Map<Tensor, Buffer> out_binds;
+  Map<top::Tensor, Buffer> out_binds;
   GetBinds(args, compact, binds, &out_binds, out_arg_list, config);
 
   // Phase 1
@@ -165,10 +165,10 @@ Stmt BuildStmt(Schedule sch,
   return stmt;
 }
 
-Array<LoweredFunc> lower(Schedule sch,
-                         const Array<Tensor>& args,
+Array<LoweredFunc> lower(top::Schedule sch,
+                         const Array<top::Tensor>& args,
                          const std::string& name,
-                         const std::unordered_map<Tensor, Buffer>& binds,
+                         const std::unordered_map<top::Tensor, Buffer>& binds,
                          const BuildConfig& config) {
   Array<ObjectRef> out_arg_list;
   auto stmt = BuildStmt(sch, args, binds, true, &out_arg_list, config);
index 33bd0efae8a450c4a61469024c068ad7c2a11d75..f5ba9abf12447fd23b93045d4bd3dfb3d2183037 100644 (file)
@@ -28,7 +28,7 @@
 #include <tvm/ir_functor_ext.h>
 #include <tvm/codegen.h>
 #include <tvm/lowered_func.h>
-#include <tvm/schedule.h>
+#include <tvm/top/schedule.h>
 #include <map>
 #include <string>
 #include <unordered_map>
@@ -38,6 +38,7 @@
 namespace tvm {
 namespace contrib {
 
+using namespace top;
 using namespace ir;
 /*!
  * \brief A base class to generate Hybrid Script.
diff --git a/src/lang/tensor.cc b/src/lang/tensor.cc
deleted file mode 100644 (file)
index 35b4029..0000000
+++ /dev/null
@@ -1,135 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file tensor.cc
- */
-#include <tvm/tensor.h>
-#include <tvm/operation.h>
-#include <tvm/tensor_intrin.h>
-#include <memory>
-
-namespace tvm {
-
-// Tensor
-PrimExpr Tensor::operator()(Array<Var> indices) const {
-  Array<PrimExpr> arr(indices.begin(), indices.end());
-  return operator()(arr);
-}
-
-PrimExpr Tensor::operator()(Array<PrimExpr> indices) const {
-  using ir::CallNode;
-  if (ndim() != 0) {
-    CHECK_EQ(ndim(), indices.size())
-        << "Tensor dimension mismatch in read"
-        << "ndim = " << ndim() << ", indices.size=" << indices.size();
-  }
-  auto n = CallNode::make(
-      (*this)->dtype, (*this)->op->name, indices, CallNode::Halide,
-      (*this)->op, (*this)->value_index);
-  return n;
-}
-
-Tensor Operation::output(size_t i) const {
-  auto node = make_object<TensorNode>();
-  node->op = *this;
-  node->value_index = i;
-  node->dtype = (*this)->output_dtype(i);
-  node->shape = (*this)->output_shape(i);
-  return Tensor(node);
-}
-
-Tensor TensorNode::make(Array<PrimExpr> shape,
-                        DataType dtype,
-                        Operation op,
-                        int value_index) {
-  auto n = make_object<TensorNode>();
-  n->shape = std::move(shape);
-  n->dtype = dtype;
-  n->op = op;
-  n->value_index = value_index;
-  return Tensor(n);
-}
-
-TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<TensorNode>([](const ObjectRef& node, NodePrinter* p) {
-    auto* t = static_cast<const TensorNode*>(node.get());
-    p->stream << "Tensor(shape=" << t->shape
-              << ", op.name=" << t->op->name << ')';
-  });
-
-TVM_REGISTER_NODE_TYPE(TensorNode);
-
-
-// TensorIntrin
-
-TensorIntrin TensorIntrinNode::make(std::string name,
-                                    Operation op,
-                                    Array<Tensor> inputs,
-                                    Array<Buffer> buffers,
-                                    Array<Var> scalar_params,
-                                    Stmt body,
-                                    Stmt reduce_init,
-                                    Stmt reduce_update) {
-  auto n = make_object<TensorIntrinNode>();
-  n->name = std::move(name);
-  n->op = std::move(op);
-  n->inputs = std::move(inputs);
-  n->buffers = std::move(buffers);
-  n->scalar_params = std::move(scalar_params);
-  n->body = std::move(body);
-  n->reduce_init = std::move(reduce_init);
-  n->reduce_update = std::move(reduce_update);
-  return TensorIntrin(n);
-}
-
-TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<TensorIntrinNode>([](const ObjectRef& node, NodePrinter* p) {
-    auto* op = static_cast<const TensorIntrinNode*>(node.get());
-    p->stream << "TensorIntrin(name=" << op->name << ", " << op << ")";
-  });
-
-TVM_REGISTER_NODE_TYPE(TensorIntrinNode);
-
-
-// TensorIntrinCall
-
-TensorIntrinCall TensorIntrinCallNode::make(TensorIntrin intrin,
-                                            Array<Tensor> tensors,
-                                            Array<Region> regions,
-                                            Array<IterVar> reduce_axis,
-                                            Array<PrimExpr> scalar_inputs) {
-  auto n = make_object<TensorIntrinCallNode>();
-  n->intrin = std::move(intrin);
-  n->tensors = std::move(tensors);
-  n->regions = std::move(regions);
-  n->reduce_axis = std::move(reduce_axis);
-  n->scalar_inputs = std::move(scalar_inputs);
-  return TensorIntrinCall(n);
-}
-
-TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<TensorIntrinCallNode>([](const ObjectRef& node, NodePrinter* p) {
-    auto* n = static_cast<const TensorIntrinCallNode*>(node.get());
-    p->stream << "TensorIntrinCall(intrin=" << n->intrin << ", " << n << ")";
-  });
-
-TVM_REGISTER_NODE_TYPE(TensorIntrinCallNode);
-
-}  // namespace tvm
diff --git a/src/op/compute_op.cc b/src/op/compute_op.cc
deleted file mode 100644 (file)
index b82bab5..0000000
+++ /dev/null
@@ -1,640 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \brief Compute Op.
- * \file compute_op.cc
- */
-#include <tvm/operation.h>
-#include <tvm/arith/analyzer.h>
-#include <tvm/ir.h>
-#include <tvm/ir_pass.h>
-#include <tvm/ir_functor_ext.h>
-#include <unordered_set>
-#include <string>
-#include <utility>
-#include "compute_op.h"
-#include "op_util.h"
-#include "../schedule/message_passing.h"
-#include "../arith/compute_expr.h"
-#include "../arith/interval_set.h"
-
-namespace tvm {
-
-using namespace ir;
-
-TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<ComputeOpNode>([](const ObjectRef& node, NodePrinter* p) {
-    auto* op = static_cast<const ComputeOpNode*>(node.get());
-    p->stream << "compute(" << op->name << ", " << op << ")";
-});
-
-TVM_REGISTER_NODE_TYPE(ComputeOpNode);
-
-/// Verify if ComputeOp is valid with respect to Reduce operations.
-static void VerifyComputeOp(const ComputeOpNode *op);
-
-inline bool ReduceEqual(const ir::ReduceNode* a, const ir::ReduceNode* b) {
-  return (a->combiner.same_as(b->combiner)) &&
-         (a->source.same_as(b->source)) &&
-         (a->axis.same_as(b->axis)) &&
-         (a->condition.same_as(b->condition));
-}
-
-int ComputeOpNode::num_outputs() const {
-  return body.size();
-}
-
-Array<IterVar> BaseComputeOpNode::root_iter_vars() const {
-  if (reduce_axis.size() == 0) return axis;
-  Array<IterVar> ret = axis;
-  for (IterVar iv : reduce_axis) {
-    ret.push_back(iv);
-  }
-  return ret;
-}
-
-DataType ComputeOpNode::output_dtype(size_t idx) const {
-  CHECK_LT(idx, num_outputs());
-  return body[idx].dtype();
-}
-
-Array<PrimExpr> BaseComputeOpNode::output_shape(size_t idx) const {
-  CHECK_LT(idx, num_outputs());
-  // for now, all outputs of a BaseComputeOp have the same shape
-  Array<PrimExpr> shape;
-  for (const auto& ivar : this->axis) {
-    const Range& r = ivar->dom;
-    shape.push_back(r->extent);
-  }
-  return shape;
-}
-
-Tensor compute(Array<PrimExpr> shape,
-               FCompute fcompute,
-               std::string name,
-               std::string tag,
-               Map<std::string, ObjectRef> attrs) {
-  auto op_node = make_object<ComputeOpNode>();
-  // compute dimension.
-  size_t ndim = shape.size();
-  std::vector<IterVar> axis;
-  std::vector<Var> args;
-  for (size_t i = 0; i < ndim; ++i) {
-    std::ostringstream os;
-    os << "ax" << i;
-    axis.emplace_back(IterVarNode::make(
-        Range(0, shape[i]), Var(os.str(), shape[i].dtype()), kDataPar));
-    args.push_back(axis.back()->var);
-  }
-
-  return ComputeOpNode::make(
-      name, tag, attrs, axis, {fcompute(args)}).output(0);
-}
-
-Array<Tensor> compute(Array<PrimExpr> shape,
-                      FBatchCompute fcompute,
-                      std::string name,
-                      std::string tag,
-                      Map<std::string, ObjectRef> attrs) {
-  auto op_node = make_object<ComputeOpNode>();
-  // compute dimension.
-  size_t ndim = shape.size();
-  std::vector<IterVar> axis;
-  std::vector<Var> args;
-  for (size_t i = 0; i < ndim; ++i) {
-    std::ostringstream os;
-    os << "ax" << i;
-    axis.emplace_back(IterVarNode::make(
-        Range(0, shape[i]), Var(os.str(), shape[i].dtype()), kDataPar));
-    args.push_back(axis.back()->var);
-  }
-
-  Operation op = ComputeOpNode::make(name, tag, attrs, axis, fcompute(args));
-  Array<Tensor> outputs;
-  for (int idx = 0; idx < op->num_outputs(); ++idx) {
-    outputs.push_back(op.output(idx));
-  }
-  return outputs;
-}
-
-Operation ComputeOpNode::make(std::string name,
-                              std::string tag,
-                              Map<std::string, ObjectRef> attrs,
-                              Array<IterVar> axis,
-                              Array<PrimExpr> body) {
-  if (!attrs.defined()) {
-    attrs = Map<std::string, ObjectRef>();
-  }
-  auto n = make_object<ComputeOpNode>();
-  n->name = std::move(name);
-  n->tag = std::move(tag);
-  n->attrs = std::move(attrs);
-  n->axis = std::move(axis);
-  n->body = std::move(body);
-  if (n->body[0]->IsInstance<ir::ReduceNode>()) {
-    const ir::ReduceNode* reduce = n->body[0].as<ir::ReduceNode>();
-    n->reduce_axis = reduce->axis;
-  }
-  VerifyComputeOp(n.get());
-  return Operation(n);
-}
-
-// The schedule-related logic
-Array<Tensor> ComputeOpNode::InputTensors() const {
-  Array<Tensor> ret;
-  std::unordered_set<Tensor> visited;
-  for (auto& e : body) {
-    ir::PostOrderVisit(e, [&ret, &visited](const ObjectRef& n) {
-        const ir::CallNode *call = n.as<ir::CallNode>();
-        if (call != nullptr && call->func.defined()) {
-          Tensor t = Downcast<Operation>(call->func).output(call->value_index);
-          if (!visited.count(t)) {
-            ret.push_back(t);
-            visited.insert(t);
-          }
-        }
-      });
-  }
-  return ret;
-}
-
-Operation ComputeOpNode::ReplaceInputs(
-    const Operation& self,
-    const std::unordered_map<Tensor, Tensor>& rmap) const {
-  CHECK_EQ(self.operator->(), this);
-  VerifyComputeOp(this);
-  Array<PrimExpr> arr;
-  if (this->body[0]->IsInstance<ir::ReduceNode>()) {
-    // Specially handle reduce so the replaced op
-    // still shares all the components
-    PrimExpr new_reduce = op::ReplaceTensor(this->body[0], rmap);
-    if (!new_reduce.same_as(this->body[0])) {
-      const ir::ReduceNode* r = new_reduce.as<ir::ReduceNode>();
-      for (size_t k = 0; k < this->body.size(); ++k) {
-        auto n = make_object<ir::ReduceNode>(*r);
-        n->value_index = static_cast<int>(k);
-        n->dtype = r->source[k].dtype();
-        arr.push_back(PrimExpr(n));
-      }
-    } else {
-      arr = this->body;
-    }
-  } else {
-    arr = UpdateArray(this->body, [&rmap] (const PrimExpr& e) {
-        return op::ReplaceTensor(e, rmap);
-      });
-  }
-  if (!arr.same_as(this->body)) {
-    return ComputeOpNode::make(
-        this->name, this->tag, this->attrs, this->axis, arr);
-  } else {
-    return self;
-  }
-}
-
-void ComputeOpNode::PropBoundToInputs(
-    const Operation& self,
-    arith::Analyzer* analyzer,
-    const std::unordered_map<const VarNode*, IntSet>& dom_map,
-    std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
-  CHECK_EQ(self.operator->(), this);
-  auto fvisit = [&dom_map, out_dom_map, analyzer](const ObjectRef& n) {
-    auto *call = n.as<ir::CallNode>();
-    if (call != nullptr && call->func.defined()) {
-      Tensor t = Downcast<Operation>(call->func).output(call->value_index);
-      if (t->op.defined() && out_dom_map->count(t)) {
-        TensorDom& dom = out_dom_map->at(t);
-        for (size_t i = 0; i < t.ndim(); ++i) {
-          // We assume that the value of the argument cannot be out of bounds (otherwise it is
-          // undefined behaviour), so we can intersect the estimated set of the argument with the
-          // range expected by the tensor. However, intersection may result in overly complex
-          // expressions, so we perform a more relaxed form of intersection.
-          IntSet arg_intset = EvalSet(call->args[i], dom_map);
-          const arith::IntervalSetNode* arg_interval = arg_intset.as<arith::IntervalSetNode>();
-          if (arg_interval) {
-            PrimExpr shape_i_min_value = make_zero(t->shape[i].dtype());
-            PrimExpr shape_i_max_value = t->shape[i] - 1;
-            PrimExpr min_value = arg_interval->min_value;
-            PrimExpr max_value = arg_interval->max_value;
-            // Prefer the shape bounds only when we can prove they are tighter.
-            if (arith::is_neg_inf(min_value) ||
-                analyzer->CanProve(shape_i_min_value >= min_value)) {
-              min_value = shape_i_min_value;
-            }
-            if (arith::is_pos_inf(max_value) ||
-                analyzer->CanProve(shape_i_max_value <= max_value)) {
-              max_value = shape_i_max_value;
-            }
-            dom.data[i].push_back(IntSet::interval(min_value, max_value));
-          } else {
-            dom.data[i].push_back(arg_intset);
-          }
-        }
-      }
-    }
-  };
-  for (auto& e : body) ir::PostOrderVisit(e, fvisit);
-}
-
-void BaseComputeOpNode::GatherBound(
-    const Operation& self,
-    const std::unordered_map<Tensor, TensorDom>& tensor_dom,
-    std::unordered_map<IterVar, Range>* out_dom_map) const {
-  CHECK_EQ(self.operator->(), this);
-  const TensorDom& tdom = tensor_dom.at(self.output(0));
-  for (size_t i = 0; i < this->axis.size(); ++i) {
-    Range r = arith::Union(tdom.data.at(i)).cover_range(this->axis[i]->dom);
-    CHECK(!out_dom_map->count(this->axis[i]));
-    (*out_dom_map)[this->axis[i]] = r;
-  }
-  for (size_t i = 0; i < this->reduce_axis.size(); ++i) {
-    CHECK(!out_dom_map->count(this->reduce_axis[i]));
-    (*out_dom_map)[this->reduce_axis[i]] = this->reduce_axis[i]->dom;
-  }
-}
-
-Stmt BaseComputeOpNode::BuildRealize(
-    const Stage& stage,
-    const std::unordered_map<IterVar, Range>& realize_map,
-    const Stmt& body) const {
-  CHECK_EQ(stage->op.get(), this);
-  Region bounds;
-  for (IterVar iv : this->axis) {
-    bounds.push_back(realize_map.at(iv));
-  }
-  Stmt realize = body;
-  for (int i = this->num_outputs(); i > 0; --i) {
-    Tensor t = stage->op.output(i-1);
-    realize = ir::RealizeNode::make(t->op, t->value_index,
-      t->dtype, bounds, const_true(), realize);
-    // alignment requirement, only useful for compute
-    for (size_t i = 0; i < num_schedulable_dims(); ++i) {
-      auto it = stage->iter_var_attrs.find(this->axis[i]);
-      if (it != stage->iter_var_attrs.end()) {
-        IterVarAttr attr = (*it).second;
-        if (attr->dim_align_factor != 0) {
-          Array<PrimExpr> tuple = {static_cast<int>(i),
-                                   attr->dim_align_factor,
-                                   attr->dim_align_offset};
-          realize = ir::AttrStmtNode::make(
-              t, ir::attr::buffer_dim_align,
-              CallNode::make(DataType::Handle(),
-                             ir::intrinsic::tvm_tuple,
-                             tuple, CallNode::Intrinsic),
-              realize);
-        }
-      }
-    }
-  }
-  return realize;
-}
-
-size_t ComputeOpNode::num_schedulable_dims() const {
-  return axis.size();
-}
-
-// Build a reduction body.
-void MakeReduction(const ComputeOpNode* op,
-                   const Array<Tensor>& tensors,
-                   Stmt* init,
-                   Stmt* provide) {
-  Array<PrimExpr> args;
-  for (IterVar iv : op->axis) {
-    args.push_back(iv->var);
-  }
-  std::vector<Stmt> inits, provides;
-
-  size_t size = op->body.size();
-  const ReduceNode* reduce = op->body[0].as<ReduceNode>();
-  CHECK(reduce);
-  const CommReducerNode* combiner = reduce->combiner.as<CommReducerNode>();
-  CHECK(combiner);
-  Array<PrimExpr> lhs;
-  for (size_t i = 0; i < size; ++i) {
-    lhs.push_back(tensors[i](args));
-  }
-  Array<PrimExpr> init_value = combiner->identity_element;
-  Array<PrimExpr> update_value = (*combiner)(lhs, reduce->source);
-  for (size_t i = 0; i < size; ++i) {
-    Tensor t = tensors[i];
-    inits.emplace_back(ProvideNode::make(
-          t->op, t->value_index, init_value[i], args));
-    provides.emplace_back(ProvideNode::make(
-          t->op, t->value_index, update_value[i], args));
-  }
-  *init = SeqStmt::Flatten(inits);
-  *provide = SeqStmt::Flatten(provides);
-  if (!is_one(reduce->condition)) {
-    *provide = IfThenElseNode::make(reduce->condition, *provide);
-  }
-}
-
-// Normal computation.
-Stmt MakeProvide(const ComputeOpNode* op,
-                 const Tensor& t) {
-  Array<PrimExpr> args;
-  for (IterVar iv : op->axis) {
-    args.push_back(iv->var);
-  }
-  return ProvideNode::make(t->op, t->value_index, op->body[t->value_index], args);
-}
-
-Stmt MakeComputeStmt(const ComputeOpNode* self,
-                     const Stage& stage,
-                     const std::unordered_map<IterVar, Range>& dom_map,
-                     bool debug_keep_trivial_loop) {
-  // grab the nest structure
-  ComputeLoopNest n = ComputeLoopNest::make(self, stage, dom_map, debug_keep_trivial_loop);
-  // Normal loop structure
-  n.init_nest.emplace_back(op::MakeIfNest(n.init_predicates));
-  n.main_nest.emplace_back(op::MakeIfNest(n.main_predicates));
-  if (self->reduce_axis.size() != 0) {
-    // make reduction.
-    Stmt init, provide;
-    Array<Tensor> source;
-    for (size_t i = 0; i < self->body.size(); ++i) {
-      source.push_back(stage->op.output(i));
-    }
-    MakeReduction(self, source, &init, &provide);
-    init = MergeNest(n.init_nest, init);
-    init = op::Substitute(init, n.init_vmap);
-    // common nest
-    std::vector<std::vector<Stmt> > common(
-        n.main_nest.begin(), n.main_nest.begin() + n.num_common_loop + 1);
-    std::vector<std::vector<Stmt> > reduce(
-        n.main_nest.begin() + n.num_common_loop + 1, n.main_nest.end());
-    provide = MergeNest(reduce, provide);
-    if (debug_keep_trivial_loop) {
-      provide = MergeNest(common, provide);
-    } else {
-      provide = MergeNest(common, SeqStmt::Flatten(init, provide));
-    }
-    // run substitution on the full nest, because the loop condition
-    // could depend on outer loops.
-    return op::Substitute(provide, n.main_vmap);
-  } else {
-    std::vector<Stmt> provides;
-    for (size_t i = 0; i < self->body.size(); ++i) {
-      provides.emplace_back(MakeProvide(self, stage->op.output(i)));
-    }
-    Stmt provide = SeqStmt::Flatten(provides);
-    provide = MergeNest(n.main_nest, provide);
-    // run substitution on the full nest, because the loop condition
-    // could depend on outer loops.
-    return op::Substitute(provide, n.main_vmap);
-  }
-}
-
-enum class ComputeType {
-  kNormal,
-  kCrossThreadReduction,
-  kTensorize
-};
-
-ComputeType DetectComputeType(const ComputeOpNode* self,
-                              const Stage& stage) {
-  // Verify correctness of leaf nest.
-  int normal_red = 0, thread_red = 0, tensorize = 0;
-
-  for (IterVar iv : stage->leaf_iter_vars) {
-    IterVarAttr attr;
-    auto it = stage->iter_var_attrs.find(iv);
-    if (it != stage->iter_var_attrs.end()) {
-      attr = (*it).second;
-    }
-    if (attr.defined() && attr->iter_type == kTensorized) {
-      ++tensorize;
-    }
-    if (iv->iter_type == kCommReduce) {
-      if (attr.defined() && attr->bind_thread.defined()) {
-        ++thread_red;
-      } else {
-        ++normal_red;
-      }
-    } else {
-      CHECK_EQ(thread_red, 0)
-          << "Cross thread reduce cannot swap with normal data axis";
-    }
-  }
-  if (tensorize != 0) {
-    CHECK(thread_red == 0)
-        << "Cannot mix cross thread reduction with Tensorize";
-    return ComputeType::kTensorize;
-  }
-  CHECK(normal_red == 0 || thread_red == 0)
-      << "Cannot mix normal reduction with thread reduce";
-  if (thread_red != 0) {
-    return ComputeType::kCrossThreadReduction;
-  } else {
-    return ComputeType::kNormal;
-  }
-}
-
-// implement the provide utility.
-Stmt ComputeOpNode::BuildProvide(
-    const Stage& stage,
-    const std::unordered_map<IterVar, Range>& dom_map,
-    bool debug_keep_trivial_loop) const {
-  CHECK_EQ(stage->op.operator->(), this);
-  ComputeType ctype = DetectComputeType(this, stage);
-  if (ctype == ComputeType::kCrossThreadReduction) {
-    // specially handle cross thread reduction.
-    return MakeCrossThreadReduction(this, stage, dom_map, debug_keep_trivial_loop);
-  } else if (ctype == ComputeType::kTensorize) {
-    return MakeTensorize(this, stage, dom_map, debug_keep_trivial_loop);
-  } else {
-    return MakeComputeStmt(this, stage, dom_map, debug_keep_trivial_loop);
-  }
-}
-
-ComputeLoopNest ComputeLoopNest::make(
-    const BaseComputeOpNode* self,
-    const Stage& stage,
-    const std::unordered_map<IterVar, Range>& dom_map,
-    bool debug_keep_trivial_loop) {
-  CHECK_EQ(stage->op.operator->(), self);
-  ComputeLoopNest ret;
-  // make main loop nest
-  ret.main_nest = op::MakeLoopNest(
-      stage, dom_map, 0, false, std::unordered_set<IterVar>(), &ret.main_vmap,
-      debug_keep_trivial_loop);
-  ret.main_predicates = schedule::MakeBoundCheck(
-      stage, dom_map, ret.main_vmap, false,
-      std::unordered_set<IterVar>());
-  for (auto& e : ret.main_predicates) {
-    e = likely(e);
-  }
-  if (stage->store_predicate.defined()) {
-    ret.main_predicates.push_back(stage->store_predicate);
-  }
-  if (self->reduce_axis.size() != 0) {
-    // try to find the location to insert the initialization.
-    // Fuse the initialization and provide loop when possible.
-    std::unordered_map<IterVar, int> update_state;
-    for (IterVar iv : self->reduce_axis) {
-      update_state[iv] = 2;
-    }
-    for (size_t i = 0; i < self->num_schedulable_dims(); ++i) {
-      update_state[self->axis[i]] = 1;
-    }
-    // find which iter var is related to reduction and which is related to axis.
-    schedule::PassDownBitMaskOr(stage, &update_state);
-    auto leaf_iter_vars = stage->leaf_iter_vars;
-    // find the first loop that is related to the reduction.
-    size_t begin_loop = leaf_iter_vars.size();
-    for (size_t i = 0; i < leaf_iter_vars.size(); ++i) {
-      auto iv = leaf_iter_vars[i];
-      int flag = update_state.at(iv);
-      if ((flag & 2) != 0) {
-        begin_loop = i; break;
-      }
-      ret.init_vmap[iv] = ret.main_vmap.at(iv);
-    }
-    ret.num_common_loop = begin_loop;
-    // skip loops that are related to reduction and are unrelated to axis.
-    std::unordered_set<IterVar> skip_iter;
-    for (auto kv : update_state) {
-      int flag = kv.second;
-      if (flag == 2) skip_iter.insert(kv.first);
-    }
-    ret.init_nest = op::MakeLoopNest(
-        stage, dom_map, begin_loop, true,
-        skip_iter, &(ret.init_vmap), debug_keep_trivial_loop);
-    ret.init_predicates = schedule::MakeBoundCheck(
-        stage, dom_map, ret.init_vmap, true, skip_iter);
-    for (auto& e : ret.init_predicates) {
-      e = likely(e);
-    }
-  } else {
-    CHECK_EQ(ret.main_nest.size(), stage->leaf_iter_vars.size() + 1);
-    ret.num_common_loop = stage->leaf_iter_vars.size();
-  }
-  // copy elision here.
-  return ret;
-}
-
-namespace {
-/*!
- * \brief Verify if ComputeOp is valid with respect to Reduce operations.
- *
- *  The following two properties are verified:
- *  (1) All Reduce operations must exist at top level.
- *  (2) For a list of operations, if one is Reduce, then the others
- *      must be Reduce as well, and their inputs should have the
- *      same attributes except value_index.
- */
-class ComputeVerifier final : protected ir::ExprVisitor {
- public:
-  /// Special member functions
-  //@{
-  explicit ComputeVerifier(const ComputeOpNode* compute)
-      : compute_(compute), reduce_(compute->body[0].as<ir::ReduceNode>()) {}
-  virtual ~ComputeVerifier() = default;
-  ComputeVerifier(const ComputeVerifier&) = delete;
-  ComputeVerifier(ComputeVerifier&&) = delete;
-  ComputeVerifier& operator=(const ComputeVerifier&) = delete;
-  ComputeVerifier& operator=(ComputeVerifier&&) = delete;
-  //@}
-
-  /// Interface to perform compute verification
-  void Run() {
-    for (const PrimExpr e : compute_->body) {
-      // Check for consistency of top level reductions
-      const ir::ReduceNode* reduce = e.as<ir::ReduceNode>();
-      CHECK((reduce && reduce_) || (!reduce && !reduce_))
-          << "All ComputeOp should be consistent "
-          << "with being Reduce operation or not.";
-
-      if (reduce && reduce_) {
-        CHECK(ReduceEqual(reduce, reduce_))
-            << "The Reduce inputs of ComputeOp should "
-            << "have the same attribute except value_index";
-      }
-
-      level_ = 0;
-      ExprVisitor::VisitExpr(e);
-    }
-  }
-
- protected:
-  /// Visitor implementation
-  //@{
-  void VisitExpr(const PrimExpr& n) final {
-    ++level_;
-    ExprVisitor::VisitExpr(n);
-    --level_;
-  }
-
-  void VisitExpr_(const ir::ReduceNode* op) final {
-    // Check for non top level reductions
-    CHECK(0 == level_)
-        << "Reductions are only allowed at the top level of compute. "
-        << "Please create another tensor for further composition.";
-  }
-  //@}
-
- private:
-  const ComputeOpNode* compute_{nullptr};  ///< ComputeOpNode to verify
-  const ir::ReduceNode* reduce_{nullptr};      ///< Top level Reduce operation
-  int level_{0};                           ///< Level of op being processed
-};
-}  // namespace
-
-/// Verify if ComputeOp is valid with respect to Reduce operations.
-static void VerifyComputeOp(const ComputeOpNode* op) {
-  ComputeVerifier v(op);
-  v.Run();
-}
-
-Stmt TransformUpdate(const Stage& stage,
-                     const std::unordered_map<IterVar, Range>& dom_map,
-                     const ComputeLoopNest& n,
-                     Stmt body,
-                     Stmt update) {
-  Array<PrimExpr> conds;
-  std::unordered_set<const VarNode*> banned;
-  for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) {
-    IterVar iv = stage->leaf_iter_vars[i];
-    auto iit = stage->iter_var_attrs.find(iv);
-    if (iit != stage->iter_var_attrs.end()) {
-      const IterVarAttr& attr = (*iit).second;
-      if (attr->iter_type == kTensorized) {
-        break;
-      }
-    }
-    if (iv->iter_type == kCommReduce) {
-      auto vit = dom_map.find(iv);
-      CHECK(vit != dom_map.end());
-      const Range& vrange = vit->second;
-      conds.push_back(likely(iv->var > vrange->min));
-      banned.insert(iv->var.get());
-    }
-  }
-  for (const PrimExpr& pred : n.main_predicates) {
-    if (ir::ExprUseVar(pred, banned)) {
-      LOG(FATAL) << "Tensorize update transform failed, the condition "
-                 << pred << " has a conflict with the reset condition";
-    }
-  }
-
-  return IfThenElseNode::make(arith::ComputeReduce<ir::OrNode>(conds, const_true(1)),
-                              update, body);
-}
-}  // namespace tvm
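
For reference, the single-output compute() overload deleted above can be driven as in the following minimal sketch. It assumes the pre-refactor <tvm/operation.h> API shown in this hunk and that FCompute accepts a callable over the index Vars assembled in the loop above; the helper name AddOne is purely illustrative.

    #include <tvm/operation.h>

    // Builds B(i0, i1, ...) = A(i0, i1, ...) + 1 as a single ComputeOpNode
    // and returns its only output tensor (output(0), as in compute() above).
    tvm::Tensor AddOne(const tvm::Tensor& a) {
      return tvm::compute(
          a->shape,
          [&](const tvm::Array<tvm::Var>& i) { return a(i) + 1; },
          "add_one", /*tag=*/"", /*attrs=*/{});
    }
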
diff --git a/src/op/compute_op.h b/src/op/compute_op.h
deleted file mode 100644 (file)
index 3fe98e8..0000000
+++ /dev/null
@@ -1,111 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \brief Helper utilities to implement compute_op.
- * \file compute_op.h
- */
-#ifndef TVM_OP_COMPUTE_OP_H_
-#define TVM_OP_COMPUTE_OP_H_
-
-#include <tvm/ir.h>
-#include <tvm/expr.h>
-#include <tvm/operation.h>
-#include <vector>
-#include <unordered_map>
-
-namespace tvm {
-// loop nest structure for general compute
-// This is the loop nest structure used in compute.
-// Does not include the loop body.
-struct ComputeLoopNest {
-  // The common number of loops between init and main
-  size_t num_common_loop;
-  // predicates for the initialize loop
-  std::vector<PrimExpr> init_predicates;
-  // Initialization nest involved.
-  std::vector<std::vector<Stmt> > init_nest;
-  // Value map for the init code
-  std::unordered_map<IterVar, PrimExpr> init_vmap;
-  // Predicates for the main update loop
-  std::vector<PrimExpr> main_predicates;
-  // The general loop nest
-  std::vector<std::vector<Stmt> > main_nest;
-  // Value map for the IterVar.
-  std::unordered_map<IterVar, PrimExpr> main_vmap;
-
-  /*!
-   * \brief Constructor to build a ComputeLoopNest
-   * \param self The pointer to compute op.
-   * \param stage The schedule stage.
-   * \param dom_map The domain map.
-   * \param debug_keep_trivial_loop Whether to keep trivial loops with extent of 1
-   * \return The constructed loop nest
-   */
-  static ComputeLoopNest make(
-      const BaseComputeOpNode* self,
-      const Stage& stage,
-      const std::unordered_map<IterVar, Range>& dom_map,
-      bool debug_keep_trivial_loop);
-};
-
-/*!
- * \brief Build body of compute for cross thread reduction pattern.
- * \param self The pointer to ComputeOpNode
- * \param stage The schedule stage.
- * \param dom_map The domain map.
- * \param debug_keep_trivial_loop Whether to keep trivial loops with extent of 1
- * \return The created statement.
- */
-Stmt MakeCrossThreadReduction(
-    const ComputeOpNode* self,
-    const Stage& stage,
-    const std::unordered_map<IterVar, Range>& dom_map,
-    bool debug_keep_trivial_loop);
-
-/*!
- * \brief Build body of compute for tensorization.
- * \param self The pointer to ComputeOpNode
- * \param stage The schedule stage.
- * \param dom_map The domain map.
- * \param debug_keep_trivial_loop Whether to keep trivial loops with extent of 1
- * \return The created statement.
- */
-Stmt MakeTensorize(const ComputeOpNode* self,
-                   const Stage& stage,
-                   const std::unordered_map<IterVar, Range>& dom_map,
-                   bool debug_keep_trivial_loop);
-
-/*!
- * \brief Transform the update part when there is no init func in tensorizing
- * \param stage The stage for tensorizing.
- * \param dom_map The range of each iter var.
- * \param n The loop nest structure used in compute.
- * \param body The body func in tensorize intrin
- * \param update The update func in tensorize intrin
- * \return Transformed result.
- */
-Stmt TransformUpdate(const Stage& stage,
-                     const std::unordered_map<IterVar, Range>& dom_map,
-                     const ComputeLoopNest& n,
-                     Stmt body,
-                     Stmt update);
-}  // namespace tvm
-
-#endif  // TVM_OP_COMPUTE_OP_H_
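
The init/main split encoded by ComputeLoopNest is easiest to see on a concrete reduction. Below is a hand-written C++ picture (a sketch, not generated output) of the nest that MakeComputeStmt in compute_op.cc assembles for B[i] = sum_k A[i, k]: the data-parallel loop is the common loop shared by the init and main nests, and the initialization is fused into it at begin_loop.

    void SumRows(const float* A, float* B, int n, int m) {
      for (int i = 0; i < n; ++i) {      // common loop (num_common_loop == 1)
        B[i] = 0.0f;                     // init nest: the combiner's identity element
        for (int k = 0; k < m; ++k) {    // reduce_axis loop past the common prefix
          B[i] += A[i * m + k];          // main update (the Provide)
        }
      }
    }
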
diff --git a/src/op/cross_thread_reduction.cc b/src/op/cross_thread_reduction.cc
deleted file mode 100644 (file)
index 9de4bde..0000000
+++ /dev/null
@@ -1,116 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \brief Logic related to cross-thread reduction, used by ComputeOpNode.
- * \file cross_thread_reduction.cc
- */
-#include <tvm/ir_pass.h>
-#include "compute_op.h"
-#include "op_util.h"
-
-namespace tvm {
-using namespace ir;
-
-Stmt MakeCrossThreadReduction(
-    const ComputeOpNode* self,
-    const Stage& stage,
-    const std::unordered_map<IterVar, Range>& dom_map,
-    bool debug_keep_trivial_loop) {
-  Array<PrimExpr> args;
-  for (IterVar iv : self->axis) {
-    args.push_back(iv->var);
-  }
-  std::unordered_map<IterVar, PrimExpr> value_map;
-  auto nest = op::MakeLoopNest(
-      stage, dom_map, 0, false, std::unordered_set<IterVar>(), &value_map, debug_keep_trivial_loop);
-  auto conds = schedule::MakeBoundCheck(
-      stage, dom_map, value_map, false,
-      std::unordered_set<IterVar>());
-
-  size_t size = self->body.size();
-  CHECK_GT(size, 0);
-  std::vector<const ReduceNode*> reduces(size);
-  for (size_t i = 0; i < size; ++i) {
-    const ReduceNode* reduce = self->body[i].as<ReduceNode>();
-    CHECK(reduce);
-    reduces[i] = reduce;
-  }
-  PrimExpr cond = reduces[0]->condition;
-  for (PrimExpr v : conds) {
-    cond = cond && v;
-  }
-  Array<PrimExpr> freduce_args;
-  freduce_args.push_back(make_const(DataType::UInt(32), static_cast<uint32_t>(size)));
-  for (size_t i = 0; i < size; ++i) {
-    freduce_args.push_back(reduces[0]->source[i]);
-  }
-  freduce_args.push_back(cond);
-  std::vector<Var> res_handles(size);
-  for (size_t idx = 0; idx < size; ++idx) {
-    res_handles[idx] = Var("reduce_temp" + std::to_string(idx), DataType::Handle());
-    freduce_args.push_back(res_handles[idx]);
-  }
-
-  for (IterVar iv : stage->leaf_iter_vars) {
-    if (iv->iter_type == kCommReduce) {
-      auto it = stage->iter_var_attrs.find(iv);
-      if (it != stage->iter_var_attrs.end() &&
-          (*it).second->bind_thread.defined()) {
-        IterVar tv = (*it).second->bind_thread;
-        freduce_args.push_back(tv->var);
-      }
-    }
-  }
-  // Predicate checks for the thread that performs the store.
-  std::vector<PrimExpr> thread_head_check;
-  if (stage->store_predicate.defined()) {
-    thread_head_check.emplace_back(stage->store_predicate);
-  }
-
-  Stmt reduce_body = EvaluateNode::make(CallNode::make(
-      DataType::Handle(),
-      ir::intrinsic::tvm_thread_allreduce,
-      freduce_args, CallNode::Intrinsic));
-  reduce_body = AttrStmtNode::make(
-      reduces[0]->combiner,
-      attr::reduce_scope,
-      make_zero(DataType::Handle()),
-      reduce_body);
-  std::vector<Stmt> assigns(size);
-  for (size_t idx = 0; idx < size; ++idx) {
-    DataType t = reduces[idx]->dtype;
-    assigns[idx] = ProvideNode::make(
-      stage->op, idx,
-      LoadNode::make(t, res_handles[idx], 0, const_true(t.lanes())), args);
-  }
-  Stmt assign_body = SeqStmt::Flatten(assigns);
-  assign_body = MergeNest(op::MakeIfNest(thread_head_check), assign_body);
-  assign_body = MergeNest(op::MakeIfNest(conds), assign_body);
-  Stmt body = SeqStmt::Flatten(reduce_body, assign_body);
-  for (size_t idx = size; idx != 0; --idx) {
-    body = AllocateNode::make(
-      res_handles[idx - 1], reduces[idx - 1]->dtype, {1}, const_true(), body);
-    body = AttrStmtNode::make(
-      res_handles[idx - 1], attr::storage_scope, StringImmNode::make("local"), body);
-  }
-  body = op::Substitute(body, value_map);
-  return MergeNest(nest, body);
-}
-}  // namespace tvm
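
The statement assembled above has a simple shape once lowered; here it is as a loose C++ analogue rather than the real TIR. The function tvm_thread_allreduce below is a hypothetical stub standing in for the ir::intrinsic of the same name, which a later pass expands into actual inter-thread communication.

    #include <cstdint>

    // Hypothetical stub for the tvm_thread_allreduce intrinsic: combines each
    // participating thread's `partial` (where `in_bounds` holds) into *result.
    void tvm_thread_allreduce(uint32_t num_args, float partial, bool in_bounds,
                              float* result);

    void CrossThreadSum(float partial, bool in_bounds, bool is_store_thread,
                        float* out) {
      float reduce_temp0;  // AllocateNode of extent 1, storage_scope "local"
      tvm_thread_allreduce(1u, partial, in_bounds, &reduce_temp0);
      if (is_store_thread) {  // thread_head_check / store_predicate
        *out = reduce_temp0;  // the ProvideNode, reading back via a LoadNode
      }
    }
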
diff --git a/src/op/extern_op.cc b/src/op/extern_op.cc
deleted file mode 100644 (file)
index fb9f491..0000000
+++ /dev/null
@@ -1,184 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \brief External computation rule.
- * \file extern_op.cc
- */
-#include <tvm/operation.h>
-#include <tvm/arith/analyzer.h>
-#include <tvm/ir.h>
-#include <unordered_set>
-#include "op_util.h"
-
-namespace tvm {
-using namespace ir;
-// ExternOpNode
-TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<ExternOpNode>([](const ObjectRef& node, NodePrinter* p) {
-    auto* op = static_cast<const ExternOpNode*>(node.get());
-    p->stream << "extern(" << op->name << ", " << op << ")";
-  });
-
-TVM_REGISTER_NODE_TYPE(ExternOpNode);
-
-int ExternOpNode::num_outputs() const {
-  return static_cast<int>(output_placeholders.size());
-}
-
-Array<IterVar> ExternOpNode::root_iter_vars() const {
-  return {};
-}
-
-DataType ExternOpNode::output_dtype(size_t i) const {
-  return output_placeholders[i]->dtype;
-}
-
-Array<PrimExpr> ExternOpNode::output_shape(size_t i) const {
-  return output_placeholders[i]->shape;
-}
-
-
-Operation ExternOpNode::make(std::string name,
-                             std::string tag,
-                             Map<std::string, ObjectRef> attrs,
-                             Array<Tensor> inputs,
-                             Array<Buffer> input_placeholders,
-                             Array<Buffer> output_placeholders,
-                             Stmt body) {
-  if (!attrs.defined()) {
-    attrs = Map<std::string, ObjectRef>();
-  }
-  auto n = make_object<ExternOpNode>();
-  n->name = std::move(name);
-  n->tag = std::move(tag);
-  n->attrs = std::move(attrs);
-  CHECK_EQ(inputs.size(), input_placeholders.size());
-  for (size_t i = 0; i < inputs.size(); ++i) {
-    CHECK_EQ(inputs[i]->dtype, input_placeholders[i]->dtype);
-    CHECK_EQ(inputs[i]->shape.size(), input_placeholders[i]->shape.size());
-    for (size_t dim = 0; dim < inputs[i]->shape.size(); ++dim) {
-        CHECK(inputs[i]->shape[dim].same_as(input_placeholders[i]->shape[dim]));
-    }
-    CHECK_EQ(input_placeholders[i]->strides.size(), 0U);
-  }
-  n->inputs = std::move(inputs);
-  n->input_placeholders = std::move(input_placeholders);
-  n->output_placeholders = std::move(output_placeholders);
-  n->body = std::move(body);
-  return Operation(n);
-}
-
-Array<Tensor> ExternOpNode::InputTensors() const {
-  return inputs;
-}
-
-Operation ExternOpNode::ReplaceInputs(
-    const Operation& self,
-    const std::unordered_map<Tensor, Tensor>& rmap) const {
-  CHECK_EQ(self.operator->(), this);
-  auto n = make_object<ExternOpNode>(*this);
-  n->body = op::ReplaceTensor(this->body, rmap);
-  for (size_t i = 0; i < n->inputs.size(); ++i) {
-    Tensor t = n->inputs[i];
-    if (rmap.count(t)) {
-      n->inputs.Set(i, rmap.at(t));
-    }
-  }
-
-  if (body.same_as(n->body) &&
-      inputs.same_as(n->inputs)) {
-    return self;
-  } else {
-    return Operation(n);
-  }
-}
-
-void ExternOpNode::PropBoundToInputs(
-    const Operation& self,
-    arith::Analyzer* analyzer,
-    const std::unordered_map<const VarNode*, IntSet>& dom_map,
-    std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
-  for (Tensor t : this->inputs) {
-    auto it = out_dom_map->find(t);
-    if (it == out_dom_map->end()) continue;
-    TensorDom& dom = it->second;
-    for (size_t i = 0; i < t->shape.size(); ++i) {
-      dom.data[i].emplace_back(IntSet::range(
-          Range::make_by_min_extent(
-              make_const(t->shape[i].dtype(), 0), t->shape[i])));
-    }
-  }
-}
-
-void ExternOpNode::GatherBound(
-    const Operation& self,
-    const std::unordered_map<Tensor, TensorDom>& tensor_dom,
-    std::unordered_map<IterVar, Range>* out_dom_map) const {
-}
-
-Stmt ExternOpNode::BuildRealize(
-    const Stage& stage,
-    const std::unordered_map<IterVar, Range>& realize_map,
-    const Stmt& body) const {
-  CHECK_EQ(stage->op.get(), this);
-  Stmt realize_body = body;
-  for (int k = 0; k < num_outputs(); ++k) {
-    Tensor t = stage->op.output(k);
-    Region bounds;
-    for (size_t i = 0; i < t->shape.size(); ++i) {
-      bounds.push_back(
-          Range::make_by_min_extent(
-              make_const(t->shape[i].dtype(), 0), t->shape[i]));
-    }
-    realize_body = ir::RealizeNode::make(
-        t->op, t->value_index, t->dtype,
-        bounds, const_true(), realize_body);
-  }
-  return realize_body;
-}
-
-Stmt ExternOpNode::BuildProvide(
-    const Stage& stage,
-    const std::unordered_map<IterVar, Range>& dom_map,
-    bool debug_keep_trivial_loop) const {
-  CHECK_EQ(stage->op.operator->(), this);
-  Stmt ret = AttrStmtNode::make(make_zero(DataType::Int(32)), attr::extern_scope, 0, this->body);
-  auto f_push_bind = [&ret](Buffer buffer, Tensor tensor) {
-    Array<ObjectRef> bind_spec;
-    Array<PrimExpr> tuple;
-    bind_spec.push_back(buffer);
-    bind_spec.push_back(tensor);
-    for (size_t k = 0; k < buffer->shape.size(); ++k) {
-      tuple.push_back(make_const(buffer->shape[k].dtype(), 0));
-      tuple.push_back(buffer->shape[k]);
-    }
-    ret = AttrStmtNode::make(
-        bind_spec, attr::buffer_bind_scope,
-        CallNode::make(DataType::Handle(), intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), ret);
-  };
-  for (size_t i = output_placeholders.size(); i != 0; --i) {
-    f_push_bind(output_placeholders[i - 1], stage->op.output(i - 1));
-  }
-  for (size_t i = inputs.size(); i != 0; --i) {
-    f_push_bind(input_placeholders[i - 1], inputs[i - 1]);
-  }
-  return ret;
-}
-}  // namespace tvm
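
Constructing such an op looks roughly as follows; a minimal sketch assuming tvm::decl_buffer from the pre-refactor buffer API, with the names and the opaque body purely illustrative. BuildProvide above then wraps the body in attr::buffer_bind_scope attributes pairing each placeholder with its tensor.

    #include <tvm/operation.h>
    #include <tvm/buffer.h>

    tvm::Operation MakeExternStage(const tvm::Tensor& in, const tvm::Stmt& body) {
      tvm::Buffer in_buf  = tvm::decl_buffer(in->shape, in->dtype, "in_buf");
      tvm::Buffer out_buf = tvm::decl_buffer(in->shape, in->dtype, "out_buf");
      return tvm::ExternOpNode::make(
          "extern_stage", /*tag=*/"", /*attrs=*/{},
          {in}, {in_buf}, {out_buf}, body);
    }
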
diff --git a/src/op/hybrid_op.cc b/src/op/hybrid_op.cc
deleted file mode 100644 (file)
index 8687ad6..0000000
+++ /dev/null
@@ -1,512 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \brief Hybrid computation rule.
- * \file hybrid_op.cc
- */
-#include <tvm/operation.h>
-#include <tvm/arith/analyzer.h>
-#include <tvm/ir.h>
-#include <tvm/ir_functor_ext.h>
-#include <tvm/ir_pass.h>
-#include <tvm/expr_operator.h>
-#include <unordered_set>
-#include <string>
-#include <utility>
-#include "op_util.h"
-#include "hybrid_op.h"
-
-namespace tvm {
-using namespace ir;
-// HybridOpNode
-TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<HybridOpNode>([](const ObjectRef& node, NodePrinter* p) {
-    auto* op = static_cast<const HybridOpNode*>(node.get());
-    p->stream << "hybrid(" << op->name << ", " << op << ")";
-  });
-
-TVM_REGISTER_NODE_TYPE(HybridOpNode);
-
-int HybridOpNode::num_outputs() const {
-  return static_cast<int>(outputs.size());
-}
-
-Array<IterVar> HybridOpNode::root_iter_vars() const {
-  return this->axis;
-}
-
-DataType HybridOpNode::output_dtype(size_t i) const {
-  return outputs[i]->dtype;
-}
-
-Array<PrimExpr> HybridOpNode::output_shape(size_t i) const {
-  return outputs[i]->shape;
-}
-
-
-Operation HybridOpNode::make(std::string name,
-                             std::string tag,
-                             Map<std::string, ObjectRef> attrs,
-                             Array<Tensor> inputs,
-                             Array<Tensor> outputs,
-                             Stmt body) {
-  if (!attrs.defined()) {
-    attrs = Map<std::string, ObjectRef>();
-  }
-  auto n = make_object<HybridOpNode>();
-  n->name = std::move(name);
-  n->tag = std::move(tag);
-  n->attrs = std::move(attrs);
-  n->inputs = std::move(inputs);
-  n->outputs = std::move(outputs);
-  n->axis = op::GatherLoopVars(body);
-  n->body = std::move(body);
-  Operation res = Operation(n);
-  return res;
-}
-
-Array<Tensor> HybridOpNode::InputTensors() const {
-  // Because input tensors can be inlined into hybrid scripts,
-  // we need to check which of the input tensors are actually used in the body.
-  std::unordered_set<Tensor> orig_inputs;
-  for (auto t : inputs) {
-    orig_inputs.insert(t);
-  }
-  std::unordered_set<Tensor> visited;
-  Array<Tensor> curr_inputs;
-  ir::PostOrderVisit(body, [&curr_inputs, &orig_inputs, &visited](const ObjectRef& n) {
-      const ir::CallNode *call = n.as<ir::CallNode>();
-      if (call != nullptr && call->func.defined()) {
-        Tensor t = Downcast<Operation>(call->func).output(call->value_index);
-        if (orig_inputs.count(t) && !visited.count(t)) {
-          curr_inputs.push_back(t);
-          visited.insert(t);
-        }
-      }
-  });
-  return curr_inputs;
-}
-
-Operation HybridOpNode::ReplaceInputs(
-    const Operation &self,
-    const std::unordered_map<Tensor, Tensor> &rmap) const {
-  CHECK_EQ(self.operator->(), this);
-  auto n = make_object<HybridOpNode>(*this);
-  n->body = op::ReplaceTensor(this->body, rmap);
-  for (size_t i = 0; i < n->inputs.size(); ++i) {
-    Tensor t = n->inputs[i];
-    if (rmap.count(t)) {
-      n->inputs.Set(i, rmap.at(t));
-    }
-  }
-
-  if (body.same_as(n->body) &&
-      inputs.same_as(n->inputs)) {
-    return self;
-  } else {
-    return Operation(n);
-  }
-}
-
-void HybridOpNode::PropBoundToInputs(
-    const Operation &self,
-    arith::Analyzer* analyzer,
-    const std::unordered_map<const VarNode*, IntSet> &dom_map,
-    std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
-  auto curr_inputs = InputTensors();
-  for (Tensor t : curr_inputs) {
-    auto it = out_dom_map->find(t);
-    if (it == out_dom_map->end()) continue;
-    TensorDom &dom = it->second;
-    for (size_t i = 0; i < t->shape.size(); ++i) {
-      dom.data[i].emplace_back(IntSet::range(
-          Range::make_by_min_extent(
-              make_const(t->shape[i].dtype(), 0), t->shape[i])));
-    }
-  }
-}
-
-void HybridOpNode::GatherBound(
-    const Operation &self,
-    const std::unordered_map<Tensor, TensorDom> &tensor_dom,
-    std::unordered_map<IterVar, Range>* out_dom_map) const {
-  for (auto iter_var : axis) {
-    CHECK(!out_dom_map->count(iter_var));
-    out_dom_map->operator[](iter_var) = iter_var->dom;
-  }
-}
-
-Stmt HybridOpNode::BuildRealize(
-    const Stage &stage,
-    const std::unordered_map<IterVar, Range> &realize_map,
-    const Stmt &body) const {
-  // TODO(@were): Add attribute inject here and remove it from hybrid parser.
-  CHECK_EQ(stage->op.get(), this);
-  Stmt realize_body = body;
-  for (int k = 0; k < num_outputs(); ++k) {
-    Tensor t = stage->op.output(k);
-    Region bounds;
-    for (size_t i = 0; i < t->shape.size(); ++i) {
-      bounds.push_back(
-          Range::make_by_min_extent(
-              make_const(t->shape[i].dtype(), 0), t->shape[i]));
-    }
-    realize_body = ir::RealizeNode::make(
-        t->op, t->value_index, t->dtype,
-        bounds, const_true(), realize_body);
-  }
-  return realize_body;
-}
-
-Stmt HybridOpNode::BuildProvide(
-    const Stage &stage,
-    const std::unordered_map<IterVar, Range> &dom_map,
-    bool debug_keep_trivial_loop) const {
-  CHECK_EQ(stage->op.operator->(), this);
-  Stmt ret = AttrStmtNode::make(make_zero(DataType::Int(32)), attr::extern_scope, 0, this->body);
-  std::unordered_map<Tensor, Tensor> rmap;
-  for (int i = 0; i < this->num_outputs(); ++i) {
-    rmap[outputs[i]] = stage->op.output(i);
-  }
-  auto n = make_object<HybridOpNode>(*this);
-  /* This story is a little bit complicated.
-   * The following two lines of codes replace output tensors' usage.
-   * This is the simplest way I (@were) can come up with to glue
-   * hybrid operation node to TVM op system.
-   * In hybrid script all the tensors, especially the output tensors,
-   * have their own names defined by the users. However, in TVM's
-   * conventional ops:
-   *   1. Output tensors refer to the corresponding op node, so the output
-   *      tensors share the names of the operation that produces them.
-   *   2. Once OpNode is wrapped up by an Operation node, it is finalized.
-   *      Later access will be from a const OpNode*.
-   * This is a chicken-egg paradox. It is impossible to put the output
-   * tensors into the function body without forming the op node. The
-   * function body is immutable after the node is formed.
-   *
-   * Finally, I decided to resolve this issue "lazily". During the
-   * pipeline of compilation, this stage is a very preliminary stage.
-   * Technically, it is before Phase 0. The actual tensors will be replaced
-   * here.
-   * Thus, the operation body is slightly different from the Phase 0 body.
-   * This is a major way in which HybridOpNode is NOT the same as
-   * ExternOpNode.
-   */
-  ret = op::ReplaceTensor(ret, rmap);
-  ret = op::ReplaceProvideTensor(ret, rmap);
-
-  ret = op::ApplySchedule(stage, dom_map, ret);
-  return ret;
-}
-
-namespace op {
-
-
-Stmt ApplyLoopShapes(const Stage &stage,
-                 const std::unordered_map<IterVar, Range> &dom_map, Stmt stmt) {
-  class LoopSpliter : public StmtExprMutator {
-    PrimExpr factor;
-    const VarNode *parent;
-    IterVar inner, outer;
-
-   public:
-    bool splitted;
-    LoopSpliter(const SplitNode *split,
-                const std::unordered_map<IterVar, Range> &dom_map) :
-      factor(split->factor), splitted(false) {
-      parent = split->parent->var.get();
-
-      auto &inner_ = split->inner;
-      CHECK(dom_map.count(inner_));
-      auto &inner_dom = dom_map.find(inner_)->second;
-      CHECK(is_const_int(inner_dom->min, 0));
-
-      auto &outer_ = split->outer;
-      CHECK(dom_map.count(outer_));
-      auto &outer_dom = dom_map.find(outer_)->second;
-      CHECK(is_const_int(outer_dom->min, 0));
-
-      inner = IterVarNode::make(inner_dom, inner_->var, inner_->iter_type);
-      outer = IterVarNode::make(outer_dom, outer_->var, outer_->iter_type);
-    }
-
-    Stmt VisitStmt_(const ForNode *op) final {
-      if (op->loop_var.get() == parent) {
-        std::unordered_map<const VarNode *, PrimExpr> rmap;
-        rmap[op->loop_var.get()] = inner + outer * factor;
-        Stmt ret = ir::Substitute(op->body, rmap);
-        PrimExpr cond = likely(outer * factor < (op->extent - inner));
-        ret = IfThenElseNode::make(cond, ret);
-        ret = ForNode::make(inner->var, PrimExpr(0), inner->dom->extent,
-                        IterVarTypeToForType(inner->iter_type), op->device_api, ret);
-        ret = ForNode::make(outer->var, PrimExpr(0), outer->dom->extent,
-                        IterVarTypeToForType(outer->iter_type), op->device_api, ret);
-        splitted = true;
-        return ret;
-      }
-      return StmtExprMutator::VisitStmt_(op);
-    }
-  };
-
-  class LoopFuser : public StmtExprMutator {
-    const IterVar &parent;
-    const VarNode *inner;
-    const VarNode *outer;
-    bool under_outer;
-    PrimExpr extent;
-
-   public:
-    bool fused;
-    explicit LoopFuser(const FuseNode *fuse_)
-      : parent(fuse_->fused), inner(fuse_->inner->var.get()),
-        outer(fuse_->outer->var.get()), under_outer(false),
-        extent(0), fused(false) {}
-
-    // TODO(@were): Handle imperfect loops
-    Stmt VisitStmt_(const ForNode* op) final {
-      if (op->loop_var.get() == inner) {
-        CHECK(under_outer);
-        std::unordered_map<const VarNode *, PrimExpr> rmap;
-        rmap[op->loop_var.get()] = indexmod(parent, op->extent);
-        extent = op->extent;
-        fused = true;
-        return ir::Substitute(op->body, rmap);
-      } else if (op->loop_var.get() == outer) {
-        under_outer = true;
-        Stmt body = this->VisitStmt(op->body);
-        std::unordered_map<const VarNode *, PrimExpr> rmap;
-        rmap[op->loop_var.get()] = indexdiv(parent, extent);
-        body = ir::Substitute(body, rmap);
-        under_outer = false;
-        return ForNode::make(parent->var, PrimExpr(0), extent * op->extent,
-                         op->for_type, op->device_api, body);
-      } else if (under_outer) {
-        Stmt body = this->VisitStmt(op->body);
-        std::unordered_map<const VarNode *, PrimExpr> rmap;
-        rmap[op->loop_var.get()] = indexmod(indexdiv(parent, extent), op->extent);
-        body = ir::Substitute(body, rmap);
-        extent = extent * op->extent;
-        return body;
-      }
-      return StmtExprMutator::VisitStmt_(op);
-    }
-  };
-
-  for (auto &rel : stage->relations) {
-    if (const SplitNode *split = rel.as<SplitNode>()) {
-      LoopSpliter Spliter(split, dom_map);
-      stmt = Spliter(stmt);
-      CHECK(Spliter.splitted);
-    } else if (const FuseNode *fuse = rel.as<FuseNode>()) {
-      LoopFuser Fuser(fuse);
-      stmt = Fuser(stmt);
-      CHECK(Fuser.fused);
-    }
-  }
-
-  return stmt;
-}
-
-Stmt ApplyLoopAnnotations(const Stage &stage,
-                          const std::unordered_map<IterVar, IterVar> &rebased, Stmt stmt) {
-  class LoopAnnotator : public StmtMutator {
-    const VarNode *var;
-    const IterVarAttr &attr;
-
-   public:
-    LoopAnnotator(const VarNode *var_, const IterVarAttr &attr_) : var(var_), attr(attr_) {}
-
-    Stmt VisitStmt_(const ForNode *op) final {
-      if (op->loop_var.get() == var) {
-        if (attr->bind_thread.defined()) {
-          const auto &iter_var = attr->bind_thread;
-          if (iter_var->dom.defined()) {
-            CHECK(is_const_int(iter_var->dom->min, 0));
-            CHECK(Equal(iter_var->dom->extent, op->extent))
-              << "Thread extent and loop extent mismatch!\n";
-          }
-          std::unordered_map<const VarNode *, PrimExpr> rmap;
-          rmap[op->loop_var.get()] = iter_var;
-          Stmt body = ir::Substitute(op->body, rmap);
-          return AttrStmtNode::make(iter_var, "thread_extent", op->extent, body);
-        } else {
-          return ForNode::make(op->loop_var, op->min, op->extent,
-                           IterVarTypeToForType(attr->iter_type), op->device_api, op->body);
-        }
-      }
-      return StmtMutator::VisitStmt_(op);
-    }
-  };
-
-  for (auto &iter_var : stage->leaf_iter_vars) {
-    bool need_change = false;
-    int found = 0;
-
-    const IterVar &actual = rebased.count(iter_var) ? rebased.find(iter_var)->second : iter_var;
-    const VarNode *var = actual->var.get();
-    ForType expected = IterVarTypeToForType(iter_var->iter_type);
-    IterVarAttr attr;
-    if (stage->iter_var_attrs.count(iter_var)) {
-      attr = stage->iter_var_attrs[iter_var];
-      expected = IterVarTypeToForType(attr->iter_type);
-    }
-
-    PostOrderVisit(stmt,
-    [&found, &var, &attr, &expected, &need_change](const ObjectRef& node) {
-      if (const ForNode *op = node.as<ForNode>()) {
-        if (op->loop_var.get() == var) {
-          ++found;
-          need_change = expected != op->for_type || (attr.defined() && attr->bind_thread.defined());
-        }
-      }
-    });
-
-    CHECK_EQ(found, 1) << " iter var should be found exactly once!";
-    if (need_change) {
-      stmt = LoopAnnotator(var, attr)(std::move(stmt));
-    }
-  }
-  return stmt;
-}
-
-Stmt ApplyLoopOrder(const Stage &stage,
-                    const std::unordered_map<IterVar, Range> &dom_map,
-                    const std::unordered_map<IterVar, IterVar> &rebased, Stmt stmt) {
-  std::vector<const VarNode*> current_order;
-  PostOrderVisit(stmt, [&current_order](const ObjectRef& node) {
-    if (const ForNode *op = node.as<ForNode>())
-      current_order.push_back(op->loop_var.get());
-  });
-  std::reverse(current_order.begin(), current_order.end());
-  auto &required_ord = stage->leaf_iter_vars;
-  CHECK_EQ(current_order.size(), required_ord.size()) << "Cannot reorder the loops!";
-  std::unordered_map<const VarNode *, IterVar> reorder;
-  bool need_reorder = false;
-  for (size_t i = 0; i < current_order.size(); ++i) {
-    auto &current = current_order[i];
-    const IterVar &iter_var = required_ord[i];
-    const IterVar &required = rebased.count(iter_var) ? rebased.find(iter_var)->second : iter_var;
-    CHECK(required->dom.defined() || dom_map.count(required)) << required << "\n";
-    reorder[current] = required;
-    if (current != required->var.get()) {
-      need_reorder = true;
-    }
-  }
-
-  class LoopReorder : public StmtMutator {
-    const Stage &stage;
-    const std::unordered_map<IterVar, Range> &dom_map;
-    const std::unordered_map<const VarNode *, IterVar> &reorder;
-
-   public:
-    LoopReorder(const Stage &stage,
-                const std::unordered_map<IterVar, Range> &dom_map,
-                const std::unordered_map<const VarNode*, IterVar> &reorder)
-      : stage(stage), dom_map(dom_map), reorder(reorder) {}
-
-    Stmt VisitStmt_(const ForNode* op) final {
-      // Reorder from the innermost loop outwards
-      Stmt body_ = this->VisitStmt(op->body);
-      CHECK(reorder.count(op->loop_var.get()));
-      auto target = reorder.find(op->loop_var.get())->second;
-      if (body_.same_as(op->body) && op->loop_var.get() == target->var.get())
-        return GetRef<Stmt>(op);
-      const Stmt &body = op->body.same_as(body_) ? op->body : body_;
-      ForType for_type = IterVarTypeToForType(target->iter_type);
-      if (stage->iter_var_attrs.count(target)) {
-        for_type = IterVarTypeToForType(stage->iter_var_attrs[target]->iter_type);
-      }
-      const Range &range = target->dom.defined() ? target->dom : dom_map.find(target)->second;
-      return ForNode::make(target->var, range->min, range->extent,
-                       for_type, DeviceAPI::None, body);
-    }
-  };
-
-  if (need_reorder)
-    return LoopReorder(stage, dom_map, reorder)(stmt);
-
-  return stmt;
-}
-
-Stmt ApplySchedule(const Stage &stage,
-                   const std::unordered_map<IterVar, Range> &dom_map, Stmt stmt) {
-  // TODO(@were): Eliminate loop rebase in script parser and move the burden here
-  // Gather rebased variables
-  std::unordered_map<IterVar, IterVar> rebased;
-  for (auto rel : stage->relations) {
-    if (const auto* rebase = rel.as<RebaseNode>()) {
-      rebased[rebase->rebased] = rebase->parent;
-      CHECK(rebase->parent->dom.defined());
-      CHECK(dom_map.count(rebase->rebased));
-    }
-  }
-  stmt = ApplyLoopShapes(stage, dom_map, stmt);
-  stmt = ApplyLoopOrder(stage, dom_map, rebased, stmt);
-  stmt = ApplyLoopAnnotations(stage, rebased, stmt);
-  return stmt;
-}
-
-std::vector<IterVar> GatherLoopVars(Stmt stmt) {
-  // TODO(@were): Write a comprehensive pass to analyze iter var types
-  std::vector<IterVar> res_;
-  PostOrderVisit(stmt, [&res_](const ObjectRef& node) {
-    if (const ForNode *op = node.as<ForNode>()) {
-      Var loop_var(op->loop_var);
-      Range dom = Range::make_by_min_extent(op->min, op->extent);
-      res_.push_back(IterVarNode::make(dom, loop_var, ForTypeToIterVarType(op->for_type)));
-    }
-  });
-  std::reverse(res_.begin(), res_.end());
-  return res_;
-}
-
-// replacer to replace tensors' usage in Provide
-class ProviderReplacer : public ir::StmtMutator {
- public:
-  explicit ProviderReplacer(const std::unordered_map<Tensor, Tensor> &vmap)
-      : vmap_(vmap) {}
-
-  Stmt VisitStmt_(const ir::ProvideNode* op) final {
-    Tensor t = Downcast<Operation>(op->func).output(op->value_index);
-    auto it = vmap_.find(t);
-    if (it != vmap_.end()) {
-      Stmt ret = ir::ProvideNode::make(
-        it->second->op, it->second->value_index, op->value, op->args);
-      found = true;
-      return this->VisitStmt(ret);
-    }
-    return StmtMutator::VisitStmt_(op);
-  }
-
-  // whether it is found.
-  bool found{false};
-
- private:
-  const std::unordered_map<Tensor, Tensor> &vmap_;
-};
-
-Stmt ReplaceProvideTensor(Stmt stmt,
-                   const std::unordered_map<Tensor, Tensor> &replace) {
-  ProviderReplacer repl(replace);
-  Stmt ret = repl(stmt);
-  return repl.found ? ret : stmt;
-}
-}  // namespace op
-}  // namespace tvm
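
The LoopSpliter rewrite above can be restated as a plain C++ sketch. Assuming the outer extent recorded in dom_map is ceil(n/f), the parent index is replaced by inner + outer*factor, and the likely(outer * factor < (op->extent - inner)) guard masks the out-of-range tail iterations.

    #include <cstdio>

    void SplitVisit(int n, int f) {
      for (int outer = 0; outer < (n + f - 1) / f; ++outer) {  // outer_dom extent
        for (int inner = 0; inner < f; ++inner) {              // inner_dom extent
          if (outer * f < n - inner) {                         // the likely(...) guard
            std::printf("%d\n", inner + outer * f);            // body, parent replaced
          }
        }
      }
    }
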
diff --git a/src/op/hybrid_op.h b/src/op/hybrid_op.h
deleted file mode 100644 (file)
index f180129..0000000
+++ /dev/null
@@ -1,96 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \brief Helper utilities to implement hybrid_op.
- * \file hybrid_op.h
- */
-#ifndef TVM_OP_HYBRID_OP_H_
-#define TVM_OP_HYBRID_OP_H_
-
-#include <tvm/expr.h>
-#include <tvm/schedule.h>
-#include <unordered_map>
-#include <unordered_set>
-#include <vector>
-#include "../pass/ir_util.h"
-#include "../pass/arg_binder.h"
-#include "../schedule/message_passing.h"
-
-
-namespace tvm {
-namespace op {
-
-/*!
- * \brief Find all the iteration variables in the given statement body.
- * \param stmt The body to be inspected.
- */
-std::vector<IterVar> GatherLoopVars(Stmt stmt);
-
-/*!
- * \brief Replace the tensor reference (especially in Provide's) in stmt by the replace map.
- * \param stmt The statement to be processed.
- * \param replace The replacement rule.
- */
-Stmt ReplaceProvideTensor(Stmt stmt,
-                          const std::unordered_map<Tensor, Tensor>& replace);
-
-/*!
- * \brief Apply the schedule manipulation on the function body.
- * \param stmt The statement to be processed.
- * \param dom_map The extents of the iteration variables that may be used.
- * \param stage The schedule information to be applied.
- */
-Stmt ApplySchedule(const Stage& stage,
-                   const std::unordered_map<IterVar, Range>& dom_map, Stmt stmt);
-
-/*!
- * \brief Apply loop splits and fuses in the schedule on the function body.
- * \param stage The schedule information to be applied.
- * \param dom_map The extents of the iteration variables that may be used.
- * \param stmt The statement to be processed.
- */
-Stmt ApplyLoopShapes(const Stage &stage,
-                     const std::unordered_map<IterVar, Range>& dom_map, Stmt stmt);
-
-
-/*!
- * \brief Apply loop annotation in the schedule on the function body.
- * \param stage The schedule information to be applied.
- * \param rebased The map specifying the rebase (i.e., rename) relationship of these variables.
- * \param stmt The statement to be processed.
- */
-Stmt ApplyLoopAnnotations(const Stage &stage,
-                          const std::unordered_map<IterVar, IterVar>& rebased, Stmt stmt);
-
-/*!
- * \brief Apply loop order in the schedule on the function body.
- * \param stage The schedule information to be applied.
- * \param dom_map The extents of the iteration variables that may be used.
- * \param rebased The map specifying the rebase (i.e., rename) relationship of these variables.
- * \param stmt The statement to be processed.
- */
-Stmt ApplyLoopOrder(const Stage &stage,
-                    const std::unordered_map<IterVar, Range> &dom_map,
-                    const std::unordered_map<IterVar, IterVar> &rebased, Stmt stmt);
-
-}  // namespace op
-}  // namespace tvm
-
-#endif  // TVM_OP_HYBRID_OP_H_
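
ApplyLoopShapes also handles the opposite direction through LoopFuser in hybrid_op.cc above; here is a hand-written C++ sketch of the fused form, with the div/mod decomposition mirroring indexdiv(parent, extent) and indexmod(parent, extent):

    #include <cstdio>

    void FusedVisit(int n_outer, int n_inner) {
      for (int fused = 0; fused < n_outer * n_inner; ++fused) {
        int outer = fused / n_inner;            // indexdiv(parent, extent)
        int inner = fused % n_inner;            // indexmod(parent, extent)
        std::printf("%d %d\n", outer, inner);   // the original loop body
      }
    }
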
diff --git a/src/op/op_util.cc b/src/op/op_util.cc
deleted file mode 100644 (file)
index 52d3b5a..0000000
+++ /dev/null
@@ -1,270 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \brief Utilities to make loop nests.
- * \file op_util.cc
- */
-#include <tvm/ir.h>
-#include <tvm/ir_pass.h>
-#include <tvm/ir_functor_ext.h>
-#include <tvm/operation.h>
-#include <string>
-#include "op_util.h"
-#include "../schedule/message_passing.h"
-#include "../arith/compute_expr.h"
-
-namespace tvm {
-namespace op {
-
-using namespace arith;
-using namespace ir;
-
-std::vector<std::vector<Stmt> >
-MakeLoopNest(const Stage& stage,
-             const std::unordered_map<IterVar, Range>& dom_map,
-             size_t begin_iter_pos,
-             bool new_loop_var,
-             const std::unordered_set<IterVar>& skip_iter,
-             std::unordered_map<IterVar, PrimExpr>* p_value_map,
-             bool debug_keep_trivial_loop) {
-  auto leaf_iter_vars = stage->leaf_iter_vars;
-  Stmt no_op = EvaluateNode::make(0);
-  // create the loop nest
-  std::vector<std::vector<Stmt> > nest;
-  nest.resize(leaf_iter_vars.size() + 1);
-  std::unordered_map<IterVar, PrimExpr>& value_map = *p_value_map;
-
-  for (size_t i = begin_iter_pos; i < leaf_iter_vars.size(); ++i) {
-    auto iv = leaf_iter_vars[i];
-    if (skip_iter.count(iv) || iv->iter_type == kOpaque) {
-      // skip this iteration.
-      value_map[iv] = iv->var;
-      continue;
-    }
-    // iv may be bound to another thread IterVar.
-    IterVar bind_iv = iv;
-    if (stage->iter_var_attrs.count(iv)) {
-      IterVar bind_thread = stage->iter_var_attrs[iv]->bind_thread;
-      if (bind_thread.defined()) bind_iv = bind_thread;
-    }
-
-    Range dom = dom_map.at(iv);
-
-    // initialize the offset and loop_level
-    Var var = bind_iv->var;
-
-    // Mark the iter var in the IR so that we remember this point.
-    if (bind_iv->thread_tag.length() == 0) {
-      // Only generate new loop if we're not bound to a thread.
-      if (new_loop_var) {
-        var = Var(iv->var->name_hint + ".init", bind_iv->var.dtype());
-      }
-
-      ForType for_type = ForType::Serial;
-      IterVarAttr it_attr;
-      if (stage->iter_var_attrs.count(iv)) {
-        it_attr = stage->iter_var_attrs[iv];
-      }
-      if (it_attr.defined()) {
-        switch (it_attr->iter_type) {
-          case kUnrolled: for_type = ForType::Unrolled; break;
-          case kVectorized: for_type = ForType::Vectorized; break;
-          case kParallelized: for_type = ForType::Parallel; break;
-          case kDataPar: break;
-          case kTensorized: break;
-          default: LOG(FATAL) << "Unknown iter type "
-                              << it_attr->iter_type
-                              << " in the iter_var_attrs";
-        }
-        CHECK_EQ(it_attr->pragma_keys.size(), it_attr->pragma_values.size());
-        for (size_t k = 0; k < it_attr->pragma_keys.size(); ++k) {
-          const std::string& pkey = it_attr->pragma_keys[k].as<StringImmNode>()->value;
-          PrimExpr pvalue = it_attr->pragma_values[k];
-          if (!pvalue.defined()) {
-            pvalue = make_const(DataType::Int(32), 1);
-          }
-          nest[i + 1].emplace_back(
-              AttrStmtNode::make(iv, ir::attr::pragma_scope_prefix + pkey, pvalue, no_op));
-        }
-      }
-      if (!debug_keep_trivial_loop && is_one(dom->extent)) {
-        nest[i + 1].emplace_back(
-            LetStmtNode::make(var, dom->min, no_op));
-        value_map[iv] = dom->min;
-      } else if (is_zero(dom->min)) {
-        nest[i + 1].emplace_back(
-            ForNode::make(var, 0, dom->extent,
-                      for_type, DeviceAPI::None, no_op));
-        value_map[iv] = var;
-      } else {
-        Var idx(bind_iv->var->name_hint + ".idx", bind_iv->var.dtype());
-        nest[i + 1].emplace_back(
-            ForNode::make(idx, 0, dom->extent,
-                      for_type, DeviceAPI::None, no_op));
-        PrimExpr new_value = dom->min + idx;
-        value_map[iv] = new_value;
-        nest[i + 1].emplace_back(
-            LetStmtNode::make(var, new_value, no_op));
-      }
-      if (it_attr.defined() && it_attr->prefetch_data.size() != 0) {
-        CHECK(!is_one(dom->extent))
-            << "Cannot prefetch on trivial loop with extent=1";
-        CHECK_EQ(it_attr->prefetch_data.size(),
-                 it_attr->prefetch_offset.size());
-        for (size_t j = 0; j < it_attr->prefetch_data.size(); ++j) {
-          nest[i + 1].emplace_back(
-              AttrStmtNode::make(it_attr->prefetch_data[j],
-                             ir::attr::prefetch_scope,
-                             it_attr->prefetch_offset[j], no_op));
-        }
-      }
-    } else if (bind_iv->thread_tag == "vthread" ||
-               bind_iv->thread_tag == "cthread") {
-      // virtual thread
-      // Always restrict threaded IterVar to start from 0.
-      CHECK(is_zero(dom->min));
-      CHECK(is_positive_const(dom->extent));
-      // annotate the extent of the IterVar
-      nest[i + 1].emplace_back(
-          AttrStmtNode::make(bind_iv, ir::attr::virtual_thread, dom->extent, no_op));
-      value_map[iv] = var;
-    } else if (bind_iv->thread_tag == "pipeline") {
-      // pipeline marker.
-      CHECK(is_zero(dom->min));
-      CHECK(is_one(dom->extent));
-      // annotate the extent of the IterVar
-      nest[i + 1].emplace_back(
-          AttrStmtNode::make(bind_iv, ir::attr::pipeline_exec_scope, dom->extent, no_op));
-      value_map[iv] = dom->min;
-    } else {
-      // Always restrict threaded IterVar to start from 0.
-      CHECK(is_zero(dom->min));
-      // annotate the extent of the IterVar
-      nest[i + 1].emplace_back(
-          AttrStmtNode::make(bind_iv, ir::attr::thread_extent, dom->extent, no_op));
-      if (!debug_keep_trivial_loop && is_one(dom->extent)) {
-        value_map[iv] = dom->min;
-      } else {
-        value_map[iv] = var;
-      }
-    }
-    // mark the loop scope of the IterVar
-    if (!new_loop_var) {
-      nest[i + 1].emplace_back(
-          AttrStmtNode::make(iv, attr::loop_scope, iv->var, no_op));
-    }
-  }
-  // message passing to get offset of root iter vars.
-  schedule::PassUpIndex(stage, dom_map, &value_map);
-  return nest;
-}
-
-std::vector<Stmt> MakeIfNest(const std::vector<PrimExpr>& predicates) {
-  Stmt no_op = EvaluateNode::make(0);
-  std::vector<Stmt> nest;
-  for (const PrimExpr& cond : predicates) {
-    nest.emplace_back(IfThenElseNode::make(cond, no_op));
-  }
-  return nest;
-}
-
-// Mutator that replaces tensor references according to a Tensor-to-Tensor map.
-class TensorReplacer : public ir::StmtExprMutator {
- public:
-  explicit TensorReplacer(const std::unordered_map<Tensor, Tensor>& vmap)
-      : vmap_(vmap) {}
-
-  PrimExpr VisitExpr_(const ir::CallNode* op) final {
-    if (op->call_type == ir::CallNode::Halide) {
-      Tensor t = Downcast<Operation>(op->func).output(op->value_index);
-      auto it = vmap_.find(t);
-      if (it != vmap_.end()) {
-        PrimExpr ret = ir::CallNode::make(
-            op->dtype, it->second->op->name, op->args,
-            op->call_type, it->second->op, it->second->value_index);
-        found = true;
-        return this->VisitExpr(ret);
-      }
-    }
-    return StmtExprMutator::VisitExpr_(op);
-  }
-
-  // Whether any replacement was found.
-  bool found{false};
-
- private:
-  const std::unordered_map<Tensor, Tensor>& vmap_;
-};
-
-Stmt ReplaceTensor(Stmt stmt,
-                   const std::unordered_map<Tensor, Tensor>& replace) {
-  TensorReplacer repl(replace);
-  Stmt ret = repl(stmt);
-  return repl.found ? ret : stmt;
-}
-PrimExpr ReplaceTensor(PrimExpr expr,
-                   const std::unordered_map<Tensor, Tensor>& replace) {
-  TensorReplacer repl(replace);
-  PrimExpr ret = repl(expr);
-  return repl.found ? ret : expr;
-}
-
-
-Stmt Substitute(Stmt s,
-                const std::unordered_map<IterVar, PrimExpr>& value_map) {
-  std::unordered_map<const VarNode*, PrimExpr> init;
-  for (const auto& kv : value_map) {
-    init[kv.first->var.get()] = kv.second;
-  }
-  return ir::Substitute(s, init);
-}
-
-IterVarType ForTypeToIterVarType(ir::ForType for_type) {
-  switch (for_type) {
-  case ForType::Serial:
-    return kDataPar;
-  case ForType::Parallel:
-    return kParallelized;
-  case ForType::Vectorized:
-    return kVectorized;
-  case ForType::Unrolled:
-    return kUnrolled;
-  default:
-    return kDataPar;
-  }
-}
-
-ir::ForType IterVarTypeToForType(IterVarType iter_type) {
-  switch (iter_type) {
-  case kDataPar:
-    return ForType::Serial;
-  case kParallelized:
-    return ForType::Parallel;
-  case kVectorized:
-    return ForType::Vectorized;
-  case kUnrolled:
-    return ForType::Unrolled;
-  default:
-    return ForType::Serial;
-  }
-}
-
-}  // namespace op
-}  // namespace tvm
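
A note on consuming MakeLoopNest's return value: nest[i + 1] holds the statements for loop level i, and callers flatten the whole structure with ir::MergeNest, appending MakeIfNest's predicate guards as the innermost level. A sketch of the usual pattern, mirroring ScanOpNode::BuildProvide later in this diff (`stage`, `dom_map`, `provide` and `debug_keep_trivial_loop` are assumed to be in scope):

std::unordered_map<IterVar, PrimExpr> vmap;
std::unordered_set<IterVar> skip;  // no iter vars skipped
auto nest = op::MakeLoopNest(
    stage, dom_map, /*begin_iter_pos=*/0, /*new_loop_var=*/false,
    skip, &vmap, debug_keep_trivial_loop);
// Guard the body with the bound-check predicates, then collapse the
// per-level statement lists into one nested Stmt around `provide`.
nest.push_back(op::MakeIfNest(
    schedule::MakeBoundCheck(stage, dom_map, vmap, false, skip)));
Stmt lowered = MergeNest(nest, provide);
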
diff --git a/src/op/op_util.h b/src/op/op_util.h
deleted file mode 100644 (file)
index cea050b..0000000
+++ /dev/null
@@ -1,107 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file op_util.h
- * \brief Common utility used in operator construction.
- */
-#ifndef TVM_OP_OP_UTIL_H_
-#define TVM_OP_OP_UTIL_H_
-
-#include <tvm/expr.h>
-#include <tvm/schedule.h>
-#include <unordered_map>
-#include <unordered_set>
-#include <vector>
-#include "../pass/ir_util.h"
-#include "../pass/arg_binder.h"
-#include "../schedule/message_passing.h"
-
-namespace tvm {
-namespace op {
-
-using ir::MergeNest;
-
-/*!
- * \brief Build loop nest for stage.
- *
- * \param stage The stage to create a loop nest.
- * \param dom_map The range of each iter var.
- * \param begin_iter_pos The beginning position of leaf_iter_vars from which to generate loops.
- * \param new_loop_var Whether to create a new loop variable.
- * \param skip_iter Whether to skip certain iterations.
- * \param p_value_map The result value of each IterVar.
- * \param debug_keep_trivial_loop Whether to keep trivial loops with extent 1.
- */
-std::vector<std::vector<Stmt> >
-MakeLoopNest(const Stage& stage,
-             const std::unordered_map<IterVar, Range>& dom_map,
-             size_t begin_iter_pos,
-             bool new_loop_var,
-             const std::unordered_set<IterVar>& skip_iter,
-             std::unordered_map<IterVar, PrimExpr>* p_value_map,
-             bool debug_keep_trivial_loop);
-
-/*!
- * \brief Create a nest of if checking the predicates.
- *
- * \param predicates The predicates to be checked.
- * \return List of If nest that checks the predicates.
- */
-std::vector<Stmt> MakeIfNest(const std::vector<PrimExpr>& predicates);
-
-/*!
- * \brief Replace the tensor references (especially in Call nodes) in stmt according to the replace map.
- * \param stmt The statement to be processed.
- * \param replace The replacement rule.
- */
-Stmt ReplaceTensor(Stmt stmt,
-                   const std::unordered_map<Tensor, Tensor>& replace);
-/*!
- * \brief Replace the tensor references (especially in Call nodes) in expr according to the replace map.
- * \param expr The expression to be processed.
- * \param replace The replacement rule.
- */
-PrimExpr ReplaceTensor(PrimExpr expr,
-                   const std::unordered_map<Tensor, Tensor>& replace);
-
-/*!
- * \brief Substitute the variables of stmt using the value map.
- * \param stmt The statement.
- * \param value_map The value map.
- * \return Substituted result.
- */
-Stmt Substitute(Stmt stmt,
-                const std::unordered_map<IterVar, PrimExpr>& value_map);
-
-/*!
- * \brief Converts Halide ForType to its corresponding IterVarType
- * \param for_type The ForType to be converted
- */
-IterVarType ForTypeToIterVarType(ir::ForType for_type);
-
-/*!
- * \brief Converts IterVarType to its corresponding Halide ForType
- * \param iter_type The IterVarType to be converted
- */
-ir::ForType IterVarTypeToForType(IterVarType iter_type);
-
-}  // namespace op
-}  // namespace tvm
-#endif  // TVM_OP_OP_UTIL_H_
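
Note that the two conversions above are intentionally lossy: only the four loop-level kinds round-trip, while everything else falls through the default case. A small illustration of the asymmetry:

// Only kDataPar/kParallelized/kVectorized/kUnrolled round-trip exactly.
ir::ForType ft = op::IterVarTypeToForType(kTensorized);  // falls through -> ForType::Serial
IterVarType it = op::ForTypeToIterVarType(ft);           // ForType::Serial -> kDataPar
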
diff --git a/src/op/placeholder_op.cc b/src/op/placeholder_op.cc
deleted file mode 100644 (file)
index 22e0ad4..0000000
+++ /dev/null
@@ -1,105 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \brief Placeholder op.
- * \file placeholder_op.cc
- */
-#include <tvm/operation.h>
-
-namespace tvm {
-
-// PlaceholderOpNode
-TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<PlaceholderOpNode>([](const ObjectRef& node, NodePrinter* p) {
-    auto* op = static_cast<const PlaceholderOpNode*>(node.get());
-    p->stream << "placeholder(" << op->name << ", " << op << ")";
-});
-
-TVM_REGISTER_NODE_TYPE(PlaceholderOpNode);
-
-int PlaceholderOpNode::num_outputs() const {
-  return 1;
-}
-
-Array<IterVar> PlaceholderOpNode::root_iter_vars() const {
-  return {};
-}
-
-DataType PlaceholderOpNode::output_dtype(size_t i) const {
-  CHECK_EQ(i, 0U);
-  return dtype;
-}
-
-Array<PrimExpr> PlaceholderOpNode::output_shape(size_t i) const {
-  CHECK_EQ(i, 0U);
-  return shape;
-}
-
-Operation PlaceholderOpNode::make(std::string name,
-                                  Array<PrimExpr> shape,
-                                  DataType dtype) {
-  auto n = make_object<PlaceholderOpNode>();
-  n->name = name;
-  n->shape = shape;
-  n->dtype = dtype;
-  return Operation(n);
-}
-
-Tensor placeholder(Array<PrimExpr> shape, DataType dtype, std::string name) {
-  return PlaceholderOpNode::make(name, shape, dtype).output(0);
-}
-
-Array<Tensor> PlaceholderOpNode::InputTensors() const {
-  return {};
-}
-
-Operation PlaceholderOpNode::ReplaceInputs(
-    const Operation& self,
-    const std::unordered_map<Tensor, Tensor>& rmap) const {
-  return self;
-}
-
-void PlaceholderOpNode::PropBoundToInputs(
-    const Operation& self,
-    arith::Analyzer* analyzer,
-    const std::unordered_map<const VarNode*, IntSet>& dom_map,
-    std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
-}
-
-void PlaceholderOpNode::GatherBound(
-    const Operation& self,
-    const std::unordered_map<Tensor, TensorDom>& tensor_dom,
-    std::unordered_map<IterVar, Range>* out_dom_map) const {
-}
-
-Stmt PlaceholderOpNode::BuildRealize(
-    const Stage& stage,
-    const std::unordered_map<IterVar, Range>& realize_map,
-    const Stmt& body) const {
-  return body;
-}
-
-Stmt PlaceholderOpNode::BuildProvide(
-    const Stage& stage,
-    const std::unordered_map<IterVar, Range>& dom_map,
-    bool debug_keep_trivial_loop) const {
-  return Stmt();
-}
-}  // namespace tvm
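
Placeholder ops are pure graph inputs: as the methods above show, they have one output, no input tensors, and every bound/realize/provide hook is a no-op. Constructing one via the helper defined in this file (a usage sketch; `m` and `n` are assumed symbolic sizes):

Var m("m"), n("n");
Tensor A = placeholder({m, n}, DataType::Float(32), "A");
CHECK_EQ(A->op->num_outputs(), 1);           // always a single output
CHECK_EQ(A->op->InputTensors().size(), 0U);  // and no inputs
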
diff --git a/src/op/scan_op.cc b/src/op/scan_op.cc
deleted file mode 100644 (file)
index f7b16f2..0000000
+++ /dev/null
@@ -1,308 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \brief Scan Operator.
- * \file scan_op.cc
- */
-#include <tvm/operation.h>
-#include <tvm/ir.h>
-#include <tvm/ir_pass.h>
-#include "op_util.h"
-#include "../schedule/graph.h"
-
-namespace tvm {
-
-using namespace ir;
-
-TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<ScanOpNode>([](const ObjectRef& node, NodePrinter* p) {
-    auto* op = static_cast<const ScanOpNode*>(node.get());
-    p->stream << "scan(" << op->name << ", " << op << ")";
-});
-TVM_REGISTER_NODE_TYPE(ScanOpNode);
-
-inline bool prove_equal(PrimExpr lhs, PrimExpr rhs) {
-  return is_zero(ir::Simplify(lhs - rhs));
-}
-
-int ScanOpNode::num_outputs() const {
-  return static_cast<int>(update.size());
-}
-Array<IterVar> ScanOpNode::root_iter_vars() const {
-  Array<IterVar> ret{scan_axis};
-  for (IterVar iv : spatial_axis_) {
-    ret.push_back(iv);
-  }
-  return ret;
-}
-
-DataType ScanOpNode::output_dtype(size_t i) const {
-  return update[i]->dtype;
-}
-
-Array<PrimExpr> ScanOpNode::output_shape(size_t i) const {
-  CHECK_LT(i, state_placeholder.size());
-  return state_placeholder[i]->shape;
-}
-
-Operation ScanOpNode::make(std::string name,
-                           std::string tag,
-                           Map<std::string, ObjectRef> attrs,
-                           IterVar axis,
-                           Array<Tensor> init,
-                           Array<Tensor> update,
-                           Array<Tensor> state_placeholder,
-                           Array<Tensor> inputs) {
-  if (!attrs.defined()) {
-    attrs = Map<std::string, ObjectRef>();
-  }
-  auto n = make_object<ScanOpNode>();
-  CHECK_EQ(init.size(), update.size());
-  CHECK_EQ(init.size(), state_placeholder.size());
-
-  for (size_t i = 0; i < init.size(); ++i) {
-    CHECK_EQ(init[i]->dtype, state_placeholder[i]->dtype);
-    CHECK_EQ(init[i]->dtype, update[i]->dtype);
-    CHECK(prove_equal(init[i]->shape[0], axis->dom->min))
-        << "init.shape[0] needs to match scan_axis.dom.min";
-    CHECK(prove_equal(
-        state_placeholder[i]->shape[0], axis->dom->min + axis->dom->extent))
-        << "state_placeholder.shape[0] needs to match"
-        << " scan_axis.dom.min + scan_axis.dom.extent";
-    CHECK_EQ(state_placeholder[i].ndim(), init[i].ndim())
-        << "The dimension of init needs to match state_placeholder";
-    CHECK_EQ(update[i].ndim(), state_placeholder[i].ndim())
-        << "update.ndim needs to match state_placeholder.ndim";
-    for (size_t k = 0;  k < update[i].ndim(); ++k) {
-      CHECK(prove_equal(
-          update[i]->shape[k], state_placeholder[i]->shape[k]));
-      if (k != 0) {
-        // setup spatial axis
-        std::ostringstream spatial_name;
-        spatial_name << name << ".out" << i << ".i" << k;
-        n->spatial_axis_.push_back(
-            IterVarNode::make(
-                Range::make_by_min_extent(0, update[i]->shape[k]),
-                Var(spatial_name.str()), kOpaque));
-      }
-    }
-
-    for (size_t k = 1;  k < init[i].ndim(); ++k) {
-      CHECK(prove_equal(
-          init[i]->shape[k], state_placeholder[i]->shape[k]));
-    }
-  }
-  n->name = std::move(name);
-  n->tag = std::move(tag);
-  n->attrs = std::move(attrs);
-  n->scan_axis = std::move(axis);
-  n->init = std::move(init);
-  n->update = std::move(update);
-  n->state_placeholder = std::move(state_placeholder);
-  n->inputs = std::move(inputs);
-  return Operation(n);
-}
-
-Array<Tensor> scan(Array<Tensor> init,
-                   Array<Tensor> update,
-                   Array<Tensor> state_placeholder,
-                   Array<Tensor> inputs,
-                   std::string name,
-                   std::string tag,
-                   Map<std::string, ObjectRef> attrs) {
-  IterVar scan_axis =
-      IterVarNode::make(
-          Range::make_by_min_extent(
-              init[0]->shape[0], update[0]->shape[0] - init[0]->shape[0]),
-          Var(name + ".idx"), kOrdered);
-  Operation op = ScanOpNode::make(
-      name, tag, attrs, scan_axis,
-      init, update, state_placeholder, inputs);
-  Array<Tensor> res;
-  for (int i = 0; i < op->num_outputs(); ++i) {
-    res.push_back(op.output(i));
-  }
-  return res;
-}
-
-Array<Tensor> ScanOpNode::InputTensors() const {
-  Array<Tensor> ret;
-  for (Tensor t : init) {
-    ret.push_back(t);
-  }
-  for (Tensor t : update) {
-    ret.push_back(t);
-  }
-  return ret;
-}
-
-Operation ScanOpNode::ReplaceInputs(
-    const Operation& self,
-    const std::unordered_map<Tensor, Tensor>& rmap) const {
-  CHECK_EQ(self.operator->(), this);
-  auto n = make_object<ScanOpNode>(*this);
-  for (size_t i = 0; i < n->init.size(); ++i) {
-    if (rmap.count(n->init[i])) {
-      n->init.Set(i, rmap.at(n->init[i]));
-    }
-    if (rmap.count(n->update[i])) {
-      n->update.Set(i, rmap.at(n->update[i]));
-    }
-  }
-  if (!n->init.same_as(init) ||
-      !n->update.same_as(update)) {
-    return Operation(n);
-  } else {
-    return self;
-  }
-}
-
-void ScanOpNode::PropBoundToInputs(
-    const Operation& self,
-    arith::Analyzer* analyzer,
-    const std::unordered_map<const VarNode*, IntSet>& dom_map,
-    std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
-  CHECK_EQ(self.operator->(), this);
-  for (size_t i = 0, sp_idx = 0; i < this->init.size(); ++i) {
-    TensorDom* init_dom = nullptr;
-    TensorDom* update_dom = nullptr;
-    if (out_dom_map->count(this->init[i])) {
-      init_dom = &out_dom_map->at(this->init[i]);
-    }
-    if (out_dom_map->count(this->update[i])) {
-      update_dom = &out_dom_map->at(this->update[i]);
-    }
-    // first dimension, always needed.
-    if (init_dom) {
-      init_dom->data[0].push_back(IntSet::range(
-          Range::make_by_min_extent(0, this->init[i]->shape[0])));
-    }
-    if (update_dom) {
-      update_dom->data[0].push_back(dom_map.at(this->scan_axis->var.get()));
-    }
-    // The update dimensions
-    for (size_t k = 1; k < this->update[i]->shape.size(); ++k, ++sp_idx) {
-      IterVar sp_ax = this->spatial_axis_[sp_idx];
-      if (init_dom) {
-        init_dom->data[k].push_back(dom_map.at(sp_ax->var.get()));
-      }
-      if (update_dom) {
-        update_dom->data[k].push_back(dom_map.at(sp_ax->var.get()));
-      }
-    }
-  }
-}
-
-void ScanOpNode::GatherBound(
-    const Operation& self,
-    const std::unordered_map<Tensor, TensorDom>& tensor_dom,
-    std::unordered_map<IterVar, Range>* out_dom_map) const {
-  CHECK_EQ(self.operator->(), this);
-  using namespace schedule;
-  CHECK(!out_dom_map->count(this->scan_axis));
-  std::vector<Tensor> output(this->num_outputs());
-  for (size_t i = 0; i < output.size(); ++i) {
-    output[i] = self.output(i);
-  }
-  // Update for time axis.
-  std::vector<IntSet> time_dom;
-  for (size_t i = 0; i < output.size(); ++i) {
-    const TensorDom& d = tensor_dom.at(output[i]);
-    time_dom.insert(time_dom.end(), d.data[0].begin(), d.data[0].end());
-  }
-  CHECK(!out_dom_map->count(this->scan_axis));
-  Range sdom = this->scan_axis->dom;
-  Range r = arith::Union(time_dom).cover_range(sdom);
-  (*out_dom_map)[this->scan_axis] = Range::make_by_min_extent(
-      sdom->min, ir::Simplify(r->extent + r->min - sdom->min));
-  Map<IterVar, PrimExpr> fix_pt = ScanFixPointAnalysis(self);
-  // Update for spatial axis.
-  size_t sp_idx = 0;
-  for (size_t i = 0; i < output.size(); ++i) {
-    const TensorDom& d = tensor_dom.at(output[i]);
-    for (size_t k = 1; k < this->update[i]->shape.size(); ++k, ++sp_idx) {
-      IterVar sp_ax = this->spatial_axis_[sp_idx];
-      CHECK(!out_dom_map->count(sp_ax));
-      CHECK(fix_pt.count(sp_ax));
-      if (fix_pt[sp_ax].as<ir::IntImmNode>()->value) {
-        // fix point, we can slice it.
-        (*out_dom_map)[sp_ax] = arith::Union(d.data[k]).cover_range(sp_ax->dom);
-      } else {
-        // not a fix point, need to include everything.
-        (*out_dom_map)[sp_ax] = sp_ax->dom;
-      }
-    }
-  }
-}
-
-Stmt ScanOpNode::BuildRealize(
-    const Stage& stage,
-    const std::unordered_map<IterVar, Range>& dom_map,
-    const Stmt& body) const {
-  CHECK_EQ(stage->op.get(), this);
-  Range sdom = dom_map.at(this->scan_axis);
-  Range tdom = Range::make_by_min_extent(
-      0, ir::Simplify(sdom->extent + sdom->min));
-  Stmt ret = body;
-  size_t sp_idx = 0;
-  for (size_t i = 0; i < update.size(); ++i) {
-    Tensor t = stage->op.output(i);
-    CHECK_EQ(static_cast<size_t>(t->value_index), i);
-    Region bounds;
-    bounds.push_back(tdom);
-    for (size_t k = 1; k < this->update[i]->shape.size(); ++k, ++sp_idx) {
-      IterVar sp_ax = this->spatial_axis_[sp_idx];
-      bounds.push_back(dom_map.at(sp_ax));
-    }
-    ret = ir::RealizeNode::make(t->op, t->value_index, t->dtype,
-                            bounds, const_true(), ret);
-  }
-  return ret;
-}
-
-Stmt ScanOpNode::BuildProvide(
-    const Stage& stage,
-    const std::unordered_map<IterVar, Range>& dom_map,
-    bool debug_keep_trivial_loop) const {
-  CHECK_EQ(stage->op.operator->(), this);
-  Stmt provide = AttrStmtNode::make(
-      stage->op, attr::scan_update_scope, this->scan_axis->var,
-      EvaluateNode::make(0));
-  Stmt init = AttrStmtNode::make(
-      stage->op, attr::scan_init_scope, 0,
-      EvaluateNode::make(0));
-  size_t begin_scan = 0;
-  for (size_t  i = 0; i < stage->leaf_iter_vars.size(); ++i) {
-    if (stage->leaf_iter_vars[i]->iter_type == kThreadIndex) {
-      CHECK_EQ(begin_scan, i);
-      begin_scan = i + 1;
-    }
-  }
-  std::unordered_map<IterVar, PrimExpr> vmap;
-  std::unordered_set<IterVar> empty;
-  auto nest = op::MakeLoopNest(
-      stage, dom_map, 0, false, empty, &vmap, debug_keep_trivial_loop);
-  nest[begin_scan].push_back(init);
-  nest.push_back(
-      op::MakeIfNest(
-          schedule::MakeBoundCheck(stage, dom_map, vmap, false, empty)));
-  return MergeNest(nest, provide);
-}
-}  // namespace tvm
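
The checks in ScanOpNode::make encode the scan contract: init supplies the first init.shape[0] time steps, the state placeholder spans all of them, and scan() derives the scan axis from the shapes. A shape-level sketch, assuming a state of T time steps over n elements, tensors s_state, s_init, s_update, X built accordingly, and the declaration's default name/tag arguments:

// state_placeholder : [T, n]  -- full state over all time steps
// init              : [1, n]  -- init.shape[0] == scan_axis.dom.min
// update            : [T, n]  -- update.shape[k] == state_placeholder.shape[k]
// The scan axis then covers [init.shape[0], update.shape[0]),
// i.e. min = 1 and extent = T - 1.
Array<Tensor> res = scan({s_init}, {s_update}, {s_state}, {X});
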
diff --git a/src/op/tensor_compute_op.cc b/src/op/tensor_compute_op.cc
deleted file mode 100644 (file)
index 4c758dd..0000000
+++ /dev/null
@@ -1,268 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \brief Tensor Compute Op.
- * \file tensor_compute_op.cc
- */
-#include <tvm/operation.h>
-#include <tvm/arith/analyzer.h>
-#include <tvm/ir.h>
-#include <tvm/ir_pass.h>
-#include <unordered_set>
-#include "./op_util.h"
-#include "./compute_op.h"
-#include "../arith/compute_expr.h"
-
-namespace tvm {
-using namespace ir;
-// TensorComputeOpNode
-TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<TensorComputeOpNode>([](const ObjectRef& node, NodePrinter* p) {
-    auto* op = static_cast<const TensorComputeOpNode*>(node.get());
-    p->stream << "tensor_compute_op(" << op->name << ", " << op << ")";
-  });
-
-TVM_REGISTER_NODE_TYPE(TensorComputeOpNode);
-
-int TensorComputeOpNode::num_outputs() const {
-  return static_cast<int>(this->intrin->buffers.size() - this->inputs.size());
-}
-
-DataType TensorComputeOpNode::output_dtype(size_t i) const {
-  return this->intrin->buffers[this->inputs.size() + i]->dtype;
-}
-
-Operation TensorComputeOpNode::make(std::string name,
-                                    std::string tag,
-                                    Array<IterVar> axis,
-                                    Array<IterVar> reduce_axis,
-                                    int schedulable_ndim,
-                                    TensorIntrin intrin,
-                                    Array<Tensor> tensors,
-                                    Array<Region> regions,
-                                    Array<PrimExpr> scalar_inputs) {
-  auto n = make_object<TensorComputeOpNode>();
-  n->name = std::move(name);
-  n->tag = std::move(tag);
-  n->axis = std::move(axis);
-  n->reduce_axis = std::move(reduce_axis);
-  n->schedulable_ndim = std::move(schedulable_ndim);
-  n->intrin = std::move(intrin);
-  n->inputs = std::move(tensors);
-  n->input_regions = std::move(regions);
-  n->scalar_inputs = std::move(scalar_inputs);
-  return Operation(n);
-}
-
-Array<Tensor> TensorComputeOpNode::InputTensors() const {
-  return inputs;
-}
-
-Operation TensorComputeOpNode::ReplaceInputs(
-    const Operation& self,
-    const std::unordered_map<Tensor, Tensor>& rmap) const {
-  CHECK_EQ(self.operator->(), this);
-  auto n = make_object<TensorComputeOpNode>(*this);
-  auto intrin = make_object<TensorIntrinNode>(*(this->intrin.operator->()));
-  intrin->body = op::ReplaceTensor(this->intrin->body, rmap);
-  if (intrin->reduce_init.defined()) {
-    intrin->reduce_init = op::ReplaceTensor(this->intrin->reduce_init, rmap);
-  }
-  if (intrin->reduce_update.defined()) {
-    intrin->reduce_update = op::ReplaceTensor(this->intrin->reduce_update, rmap);
-  }
-  for (size_t i = 0; i < n->inputs.size(); ++i) {
-    Tensor t = n->inputs[i];
-    if (rmap.count(t)) {
-      n->inputs.Set(i, rmap.at(t));
-    }
-  }
-
-  if (intrin->body.same_as(n->intrin->body) &&
-      intrin->reduce_init.same_as(n->intrin->reduce_init) &&
-      intrin->reduce_update.same_as(n->intrin->reduce_update) &&
-      inputs.same_as(n->inputs)) {
-    return self;
-  } else {
-    n->intrin = TensorIntrin(intrin);
-    return Operation(n);
-  }
-}
-
-void TensorComputeOpNode::PropBoundToInputs(
-    const Operation& self,
-    arith::Analyzer* analyzer,
-    const std::unordered_map<const VarNode*, IntSet>& dom_map,
-    std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
-  for (size_t i = 0; i < this->inputs.size(); ++i) {
-    Tensor t = this->inputs[i];
-    Region region = input_regions[i];
-
-    auto it = out_dom_map->find(t);
-    if (it == out_dom_map->end()) continue;
-    TensorDom& dom = it->second;
-    for (size_t j = 0; j < t.ndim(); ++j) {
-      dom.data[j].emplace_back(EvalSet(region[j], dom_map));
-    }
-  }
-}
-
-size_t TensorComputeOpNode::num_schedulable_dims() const {
-  return schedulable_ndim;
-}
-
-Stmt TensorComputeOpNode::BuildProvide(
-    const Stage& stage,
-    const std::unordered_map<IterVar, Range>& dom_map,
-    bool debug_keep_trivial_loop) const {
-  CHECK_EQ(stage->op.operator->(), this);
-
-  // Start binding data.
-  Stmt nop = EvaluateNode::make(0);
-  std::vector<Stmt> input_bind_nest, output_bind_nest;
-  Array<Tensor> inputs = this->InputTensors();
-
-  // input binding
-  size_t num_inputs = inputs.size();
-  for (size_t i = 0; i < num_inputs; ++i) {
-    Tensor tensor = inputs[i];
-    Region region = this->input_regions[i];
-    Buffer buffer = this->intrin->buffers[i];
-    Array<ObjectRef> bind_spec{buffer, tensor};
-
-    Array<PrimExpr> tuple;
-    for (size_t i = 0; i < region.size(); ++i) {
-      tuple.push_back(region[i]->min);
-      tuple.push_back(region[i]->extent);
-    }
-    input_bind_nest.emplace_back(AttrStmtNode::make(
-        bind_spec, ir::attr::buffer_bind_scope,
-        CallNode::make(DataType::Handle(),
-                       ir::intrinsic::tvm_tuple,
-                       tuple, CallNode::Intrinsic), nop));
-  }
-
-  // output binding
-  for (int i = 0; i < this->num_outputs(); ++i) {
-    Tensor tensor = stage->op.output(i);
-    Buffer buffer = this->intrin->buffers[num_inputs + i];
-    Array<ObjectRef> bind_spec{buffer, tensor};
-
-    Array<PrimExpr> tuple;
-    for (size_t i = 0; i < this->axis.size(); ++i) {
-      auto ivar = this->axis[i];
-      if (i < static_cast<size_t>(this->schedulable_ndim)) {
-        tuple.push_back(ivar->var);
-        tuple.push_back(1);
-      } else {
-        Range dom = ivar->dom;
-        tuple.push_back(dom->min);
-        tuple.push_back(dom->extent);
-      }
-    }
-
-    output_bind_nest.emplace_back(AttrStmtNode::make(
-        bind_spec, ir::attr::buffer_bind_scope,
-        CallNode::make(DataType::Handle(),
-                       ir::intrinsic::tvm_tuple,
-                       tuple, CallNode::Intrinsic), nop));
-  }
-
-  // Check variable remap
-  std::unordered_map<const VarNode*, PrimExpr> vmap;
-  ir::ArgBinder binder(&vmap);
-
-  // Map the expressions passed in the call to the TensorIntrin, to the placeholder
-  // variables
-  Array<PrimExpr> user_expr = this->scalar_inputs;
-  Array<Var> scalar_params = this->intrin->scalar_params;
-  Array<PrimExpr> sp_expr;
-  for (auto sp : scalar_params) {
-    PrimExpr esp = sp;
-    sp_expr.push_back(esp);
-  }
-  CHECK_EQ(sp_expr.size(), user_expr.size());
-  // TODO(jdavies-huawei): what name should be used here?
-  binder.BindArray(sp_expr, user_expr, this->name);
-
-  size_t tloc = stage->leaf_iter_vars.size();
-  ComputeLoopNest n = ComputeLoopNest::make(this, stage, dom_map, debug_keep_trivial_loop);
-
-  if (this->reduce_axis.size() == 0) {
-    std::vector<std::vector<Stmt> > nest(
-        n.main_nest.begin(), n.main_nest.begin() + tloc + 1);
-    nest.emplace_back(op::MakeIfNest(n.main_predicates));
-    CHECK_EQ(n.init_predicates.size(), 0U);
-    CHECK(this->intrin->body.defined())
-        << "Normal store op for intrin " << this << " is not defined";
-    Stmt body = MergeNest(output_bind_nest, this->intrin->body);
-    body = MergeNest(input_bind_nest, body);
-    body = ir::Substitute(body, vmap);
-    body = MergeNest(binder.asserts(), body);
-    body = op::Substitute(body, n.main_vmap);
-    Stmt ret =  MergeNest(nest, body);
-    return ret;
-  } else {
-    // Need to split reduction
-    CHECK(this->intrin->reduce_update.defined())
-        << "Reduction update op is not defined";
-    // Need init and update steps
-    CHECK_NE(this->reduce_axis.size(), 0U);
-    std::vector<std::vector<Stmt> > common(
-        n.main_nest.begin(), n.main_nest.begin() + n.num_common_loop + 1);
-    std::vector<std::vector<Stmt> > update_nest(
-        n.main_nest.begin() + n.num_common_loop + 1, n.main_nest.begin() + tloc + 1);
-    update_nest.emplace_back(op::MakeIfNest(n.main_predicates));
-
-    if (this->intrin->reduce_init.defined()) {
-      // init nest
-      std::vector<std::vector<Stmt> > init_nest(
-          n.init_nest.begin(), n.init_nest.begin() + tloc + 1);
-      init_nest.emplace_back(op::MakeIfNest(n.init_predicates));
-      Stmt init = MergeNest(output_bind_nest, this->intrin->reduce_init);
-      init = op::Substitute(init, n.init_vmap);
-      init = MergeNest(init_nest, init);
-      // The update
-      Stmt update = MergeNest(output_bind_nest, this->intrin->reduce_update);
-      update = MergeNest(input_bind_nest, update);
-      update = ir::Substitute(update, vmap);
-      update = MergeNest(binder.asserts(), update);
-      update = op::Substitute(update, n.main_vmap);
-      update = MergeNest(update_nest, update);
-      return MergeNest(common, SeqStmt::Flatten(init, update));
-    } else {
-      // When init op is not available, use body op for reset in the first iter.
-      CHECK(this->intrin->body.defined())
-          << "Normal body op is not defined";
-      Stmt update = TransformUpdate(stage, dom_map, n,
-                                    this->intrin->body,
-                                    this->intrin->reduce_update);
-      update = MergeNest(output_bind_nest, update);
-      update = MergeNest(input_bind_nest, update);
-      update = ir::Substitute(update, vmap);
-      update = MergeNest(binder.asserts(), update);
-      update = op::Substitute(update, n.main_vmap);
-      update = MergeNest(update_nest, update);
-      return MergeNest(common, update);
-    }
-  }
-}
-
-}  // namespace tvm
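
The indexing above relies on a buffer layout convention: intrin->buffers lists all input buffers first, then the output buffers. That is why num_outputs() is buffers.size() - inputs.size() and output dtypes are looked up at inputs.size() + i:

// Convention assumed by the accessors above, for a TensorComputeOpNode* op:
//   op->intrin->buffers == [in_0, ..., in_{k-1}, out_0, ..., out_{m-1}]
int m = op->num_outputs();          // == buffers.size() - inputs.size()
DataType d0 = op->output_dtype(0);  // == buffers[inputs.size()]->dtype
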
diff --git a/src/op/tensorize.cc b/src/op/tensorize.cc
deleted file mode 100644 (file)
index 4460d90..0000000
+++ /dev/null
@@ -1,536 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \brief Logic related to tensorize, used by ComputeOpNode.
- * \file tensorize.cc
- */
-#include <tvm/ir.h>
-#include <tvm/ir_functor_ext.h>
-#include <tvm/ir_pass.h>
-#include <tvm/runtime/registry.h>
-#include <tvm/packed_func_ext.h>
-
-#include "op_util.h"
-#include "compute_op.h"
-#include "../schedule/message_passing.h"
-
-namespace tvm {
-
-using namespace ir;
-using namespace op;
-
-// Detect the regions of the inputs and output to be tensorized.
-// out_dom: the domain of root iter vars in output op
-// in_region: region of each input tensor.
-// return The location of the tensorized scope start.
-size_t InferTensorizeRegion(
-    const ComputeOpNode* self,
-    const Stage& stage,
-    const std::unordered_map<IterVar, Range>& dom_map,
-    std::unordered_map<IterVar, Range>* out_dom,
-    std::unordered_map<Tensor, Array<Range> >* in_region) {
-  // Get the bound of the tensorized scope.
-  bool found_point = false;
-  size_t loc_scope = 0;
-  std::unordered_map<IterVar, IntSet> up_state;
-  // Loop over the leaves
-  for (size_t i = stage->leaf_iter_vars.size(); i != 0; --i) {
-    IterVar iv = stage->leaf_iter_vars[i - 1];
-    CHECK(iv->iter_type == kDataPar ||
-          iv->iter_type == kCommReduce);
-    auto vit = dom_map.find(iv);
-    CHECK(vit != dom_map.end());
-    const Range& vrange = vit->second;
-    if (is_one(vrange->extent)) {
-      up_state[iv] = IntSet::single_point(vrange->min);
-    } else if (found_point) {
-      CHECK(is_zero(vrange->min));
-      up_state[iv] = IntSet::single_point(iv->var);
-    } else {
-      up_state[iv] = IntSet::range(vrange);
-    }
-    auto iit = stage->iter_var_attrs.find(iv);
-    if (iit != stage->iter_var_attrs.end()) {
-      const IterVarAttr& attr = (*iit).second;
-      if (!found_point) {
-        CHECK(!attr->bind_thread.defined())
-            << "Do not allow threads in tensorize scope";
-      }
-      if (attr->iter_type == kTensorized) {
-        CHECK(!found_point) << "Do not allow two tensorized points";
-        found_point = true;
-        loc_scope = i - 1;
-      }
-    }
-  }
-  CHECK(found_point);
-  // Get domain of the tensorized scope.
-  schedule::PassUpDomain(stage, dom_map, &up_state);
-  // Get domains of the inputs
-  std::unordered_map<Tensor, TensorDom> in_dom;
-  std::unordered_map<const VarNode*, IntSet> temp_dmap;
-  arith::Analyzer analyzer;
-  Array<Tensor> inputs = self->InputTensors();
-  for (Tensor t : inputs) {
-    in_dom.emplace(t, TensorDom(t.ndim()));
-  }
-  for (IterVar iv : self->root_iter_vars()) {
-    IntSet iset = up_state.at(iv);
-    Range iv_range = iset.cover_range(dom_map.at(iv));
-    (*out_dom)[iv] = iv_range;
-    analyzer.Bind(iv->var, iv_range);
-    temp_dmap[iv->var.get()] = iset;
-  }
-  // Input domains
-  self->PropBoundToInputs(stage->op, &analyzer, temp_dmap, &in_dom);
-  Range none;
-  for (const auto& kv : in_dom) {
-    Array<Range> vec;
-    const Tensor& t = kv.first;
-    for (size_t i = 0; i < t.ndim(); ++i) {
-      Range r = arith::Union(kv.second.data.at(i)).cover_range(none);
-      CHECK(r.defined()) << "cannot deduce region of tensorized scope for input " << t;
-      vec.push_back(std::move(r));
-    }
-    (*in_region)[t] = std::move(vec);
-  }
-  return loc_scope;
-}
-
-void VerifyTensorizeLoopNest(const ComputeOpNode* self,
-                             const Stage& stage,
-                             const ComputeLoopNest& n,
-                             size_t tloc) {
-  // Verification step.
-  std::unordered_set<const VarNode*> banned;
-  CHECK_EQ(n.main_nest.size(), stage->leaf_iter_vars.size() + 1);
-  CHECK(n.init_nest.size() == stage->leaf_iter_vars.size() + 1 ||
-        n.init_nest.size() == 0);
-  auto f_push_banned = [&banned](const Stmt& s) {
-    if (const ForNode* op = s.as<ForNode>()) {
-        banned.insert(op->loop_var.get());
-    } else if (const AttrStmtNode* op = s.as<AttrStmtNode>()) {
-      if (const IterVarNode* iv = op->node.as<IterVarNode>()) {
-        banned.insert(iv->var.get());
-      }
-    } else if (const LetStmtNode* op = s.as<LetStmtNode>()) {
-      banned.insert(op->var.get());
-    }
-  };
-  for (size_t i = tloc; i < stage->leaf_iter_vars.size(); ++i) {
-    for (const Stmt& s : n.main_nest[i + 1]) {
-      f_push_banned(s);
-    }
-    if (n.init_nest.size() != 0) {
-      for (const Stmt& s : n.init_nest[i + 1]) {
-        f_push_banned(s);
-      }
-    }
-  }
-  for (const PrimExpr& pred : n.main_predicates) {
-    if (ir::ExprUseVar(pred, banned)) {
-      LOG(FATAL) << "Tensorize failed, split condition "
-                 << pred << " relies on var defined inside tensorize scope";
-    }
-  }
-  for (const PrimExpr& pred : n.init_predicates) {
-    if (ir::ExprUseVar(pred, banned)) {
-      LOG(FATAL) << "Tensorize failed, split condition "
-                 << pred << " relies on var defined inside tensorize scope";
-    }
-  }
-}
-
-// Remap tensor placeholders and indices, and inline where needed.
-class TensorIntrinMatcher final : public StmtExprMutator {
- public:
-  PrimExpr VisitExpr_(const CallNode* op) final {
-    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
-    op = expr.as<CallNode>();
-    if (op->call_type == CallNode::Halide) {
-      Tensor t = Downcast<Operation>(op->func).output(op->value_index);
-      auto it = in_remap_.find(t);
-      if (it != in_remap_.end()) {
-        const InputEntry& e = it->second;
-        CHECK_EQ(op->args.size(), e.region.size());
-        Array<PrimExpr> args;
-        for (size_t i = e.start; i < e.region.size(); ++i) {
-          args.push_back(op->args[i] - e.region[i]->min);
-        }
-        return CallNode::make(
-            op->dtype, e.tensor->op->name, args,
-            op->call_type, e.tensor->op, e.tensor->value_index);
-      }
-    }
-    return expr;
-  }
-
-  PrimExpr VisitExpr_(const VarNode* op) final {
-    auto it = var_remap_.find(op);
-    if (it != var_remap_.end()) {
-      return it->second;
-    } else {
-      return GetRef<PrimExpr>(op);
-    }
-  }
-
-  PrimExpr VisitExpr_(const ReduceNode* op) final {
-    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
-    op = expr.as<ReduceNode>();
-    Array<IterVar> axis;
-    for (size_t i = 0; i < op->axis.size(); ++i) {
-      auto it = axis_remap_.find(op->axis[i]);
-      if (it != axis_remap_.end()) {
-        axis.push_back(it->second);
-      }
-    }
-    return ReduceNode::make(
-        op->combiner, op->source, axis, op->condition, op->value_index);
-  }
-
-  void Init(const ComputeOpNode* self,
-            const Stage& stage,
-            const std::unordered_map<IterVar, Range>& dom_map,
-            const std::unordered_map<IterVar, Range>& out_dom,
-            const std::unordered_map<Tensor, Array<Range> >& in_region,
-            const TensorIntrin& intrin,
-            Map<Var, Range>* compute_intrin_iter_space) {
-    CHECK(self == stage->op.get());
-
-    for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) {
-      IterVar iv = stage->leaf_iter_vars[i];
-      auto vit = dom_map.find(iv);
-      if (vit != dom_map.end()) {
-        const Range vrange = vit->second;
-        compute_intrin_iter_space->Set(iv->var, vrange);
-      }
-    }
-
-    // input remap.
-    Array<Tensor> inputs = self->InputTensors();
-    CHECK_EQ(inputs.size(), intrin->inputs.size());
-    for (size_t i = 0; i < inputs.size(); ++i) {
-      InputEntry e;
-      e.tensor = intrin->inputs[i];
-      e.region = Array<Range>(in_region.at(inputs[i]));
-      CHECK_GE(e.region.size(), e.tensor.ndim());
-      // Enable fuzzy matching, to match [1, n, m] to [n, m]
-      e.start = e.region.size() - e.tensor.ndim();
-      for (size_t j = 0; j < e.start; ++j) {
-        auto canonical_extent = Simplify(e.region[j]->extent, *compute_intrin_iter_space);
-        CHECK(is_one(canonical_extent))
-            << "Tensorize " << intrin->name << ":"
-            << " Input dimension mismatch with tensor intrin "
-            << " expected shape=" << e.tensor->shape
-            << ", given region=" << e.region;
-      }
-      in_remap_[inputs[i]] = e;
-    }
-    // output remap
-    const ComputeOpNode* intrin_compute = intrin->op.as<ComputeOpNode>();
-    CHECK(intrin_compute) << "Only support compute intrinsic for now";
-    CHECK_GE(self->axis.size(), intrin_compute->axis.size())
-        << "Tensorize: Output mismatch with tensor intrin ";
-    // Enable fuzzy matching, to match [1, n, m] to [n, m]
-    size_t axis_start = self->axis.size() - intrin_compute->axis.size();
-    for (size_t i = 0; i < axis_start; ++i) {
-      Range r = out_dom.at(self->axis[i]);
-      CHECK(is_one(r->extent))
-          << "Tensorize: Output mismatch with tensor intrin "
-          << " intrin-dim=" << intrin_compute->axis.size()
-          << ", tensorize-dim=" << self->axis.size();
-      var_remap_[self->axis[i]->var.get()] = r->min;
-    }
-    // Assume we tensorize at region axis i [min, min + extent)
-    // The corresponding intrinsic axis is j [0, extent)
-    // Remap index i to j + min
-    for (size_t i = axis_start; i < self->axis.size(); ++i) {
-      IterVar iv = self->axis[i];
-      IterVar target_iv = intrin_compute->axis[i - axis_start];
-      Range r = out_dom.at(iv);
-      var_remap_[iv->var.get()] = target_iv->var + r->min;
-      axis_remap_[iv] = target_iv;
-      compute_intrin_iter_space->Set(target_iv->var, target_iv->dom);
-    }
-    // Remap reduction axis
-    CHECK_GE(self->reduce_axis.size(), intrin_compute->reduce_axis.size())
-        << "Tensorize: Reduction dimension mismatch with tensor intrin";
-    axis_start = self->reduce_axis.size() - intrin_compute->reduce_axis.size();
-    for (size_t i = 0; i < axis_start; ++i) {
-      Range r = out_dom.at(self->reduce_axis[i]);
-      CHECK(is_one(r->extent))
-          << "Tensorize: Reduction mismatch with tensor intrin "
-          << " intrin-dim=" << intrin_compute->reduce_axis.size()
-          << ", tensorize-dim=" << self->reduce_axis.size();
-      var_remap_[self->reduce_axis[i]->var.get()] = r->min;
-    }
-    for (size_t i = axis_start; i < self->reduce_axis.size(); ++i) {
-      IterVar iv = self->reduce_axis[i];
-      IterVar target_iv = intrin_compute->reduce_axis[i - axis_start];
-      Range r = out_dom.at(iv);
-      var_remap_[iv->var.get()] = target_iv->var + r->min;
-      axis_remap_[iv] = target_iv;
-      compute_intrin_iter_space->Set(target_iv->var, target_iv->dom);
-    }
-  }
-
- private:
-  // Input entry
-  struct InputEntry {
-    Tensor tensor;
-    size_t start;
-    Array<Range> region;
-  };
-  // input data remap
-  std::unordered_map<Tensor, InputEntry> in_remap_;
-  // variable remap.
-  std::unordered_map<const VarNode*, PrimExpr> var_remap_;
-  // IterVar remap.
-  std::unordered_map<IterVar, IterVar> axis_remap_;
-};
-
-// Try to match tensor dataflow of the stage with the intrinsic
-Array<PrimExpr> MatchTensorizeBody(
-    const ComputeOpNode* self,
-    const Stage& stage,
-    const std::unordered_map<IterVar, Range>& dom_map,
-    const std::unordered_map<IterVar, Range>& out_dom,
-    const std::unordered_map<Tensor, Array<Range> >& in_region,
-    const TensorIntrin& intrin,
-    Map<Var, Range>* compute_intrin_iter_space) {
-  TensorIntrinMatcher matcher;
-  matcher.Init(self, stage, dom_map, out_dom, in_region, intrin, compute_intrin_iter_space);
-  Array<PrimExpr> ret;
-  for (PrimExpr expr : self->body) {
-    ret.push_back(matcher(expr));
-  }
-  return ret;
-}
-
-void VerifyTensorizeBody(
-    const ComputeOpNode* self,
-    const Stage& stage,
-    const std::unordered_map<IterVar, Range>& dom_map,
-    const std::unordered_map<IterVar, Range>& out_dom,
-    const std::unordered_map<Tensor, Array<Range> >& in_region,
-    const TensorIntrin& intrin) {
-  Map<Var, Range> compute_intrin_iter_space;
-  Array<PrimExpr> body = MatchTensorizeBody(self, stage, dom_map, out_dom, in_region, intrin,
-                                        &compute_intrin_iter_space);
-  const ComputeOpNode* intrin_compute = intrin->op.as<ComputeOpNode>();
-  CHECK(intrin_compute) << "Only support compute intrinsic for now";
-  CHECK_EQ(body.size(), intrin_compute->body.size())
-      << "Tensorize failed: body size mismatch";
-  for (size_t i = 0; i < body.size(); ++i) {
-    PrimExpr lhs = Simplify(body[i], compute_intrin_iter_space);
-    lhs = CanonicalSimplify(lhs, compute_intrin_iter_space);
-    PrimExpr rhs = Simplify(intrin_compute->body[i], compute_intrin_iter_space);
-    rhs = CanonicalSimplify(rhs, compute_intrin_iter_space);
-    if (lhs.dtype() != rhs.dtype()) {
-      LOG(FATAL)
-          << "Failed to match the data type with TensorIntrin "
-          << intrin->name << "'s declaration "
-          << " provided=" << lhs.dtype()
-          << ", intrin=" << rhs.dtype();
-    }
-    CHECK(Equal(lhs, rhs))
-        << "Failed to match the compute with TensorIntrin "
-        << intrin->name << "'s declaration "
-        << " provided= " << lhs
-        << ", intrin=  " << rhs;
-  }
-}
-
-Stmt MakeTensorize(const ComputeOpNode* self,
-                   const Stage& stage,
-                   const std::unordered_map<IterVar, Range>& dom_map,
-                   bool debug_keep_trivial_loop) {
-  std::unordered_map<IterVar, Range> out_dom;
-  std::unordered_map<Tensor, Array<Range> > in_region;
-  size_t tloc = InferTensorizeRegion(self, stage, dom_map, &out_dom, &in_region);
-  TensorIntrin intrin = stage->iter_var_attrs.at(
-      stage->leaf_iter_vars[tloc])->tensor_intrin;
-  CHECK(intrin.defined());
-  ComputeLoopNest n = ComputeLoopNest::make(self, stage, dom_map, debug_keep_trivial_loop);
-  VerifyTensorizeLoopNest(self, stage, n, tloc);
-  VerifyTensorizeBody(self, stage, dom_map, out_dom, in_region, intrin);
-  // Start binding data.
-  Stmt nop = EvaluateNode::make(0);
-  std::vector<Stmt> input_bind_nest, output_bind_nest;
-  Array<Tensor> inputs = self->InputTensors();
-  CHECK_EQ(inputs.size(), intrin->inputs.size())
-      << "Tensorize failed: input size mismatch ";
-  // input binding
-  for (size_t i = 0; i < intrin->inputs.size(); ++i) {
-    Tensor tensor = inputs[i];
-    Buffer buffer = intrin->buffers[i];
-    Array<ObjectRef> bind_spec{buffer, tensor};
-    auto it = in_region.find(tensor);
-    CHECK(it != in_region.end());
-    const Array<Range>& region = it->second;
-    Array<PrimExpr> tuple;
-    for (const Range r : region) {
-      tuple.push_back(r->min);
-      tuple.push_back(r->extent);
-    }
-    input_bind_nest.emplace_back(AttrStmtNode::make(
-        bind_spec, ir::attr::buffer_bind_scope,
-        CallNode::make(DataType::Handle(),
-                       ir::intrinsic::tvm_tuple,
-                       tuple, CallNode::Intrinsic), nop));
-  }
-  // output binding
-  const ComputeOpNode* intrin_compute = intrin->op.as<ComputeOpNode>();
-  CHECK(intrin_compute) << "Only support compute intrinsic for now";
-  CHECK_EQ(intrin->inputs.size() + intrin_compute->body.size(), intrin->buffers.size());
-  CHECK_EQ(intrin_compute->body.size(), self->body.size());
-  Array<PrimExpr> tuple;
-  for (IterVar iv : self->axis) {
-    auto it = out_dom.find(iv);
-    CHECK(it != out_dom.end());
-    tuple.push_back(it->second->min);
-    tuple.push_back(it->second->extent);
-  }
-  for (size_t i = intrin->inputs.size(); i < intrin->buffers.size(); ++i) {
-    Tensor tensor = stage->op.output(i - intrin->inputs.size());
-    Buffer buffer = intrin->buffers[i];
-    Array<ObjectRef> bind_spec{buffer, tensor};
-    output_bind_nest.emplace_back(AttrStmtNode::make(
-        bind_spec, ir::attr::buffer_bind_scope,
-        CallNode::make(DataType::Handle(),
-                       ir::intrinsic::tvm_tuple,
-                       tuple, CallNode::Intrinsic), nop));
-  }
-  // Check variable remap
-  std::unordered_map<const VarNode*, PrimExpr> vmap;
-  ir::ArgBinder binder(&vmap);
-  CHECK_GE(self->reduce_axis.size(), intrin_compute->reduce_axis.size())
-      << "Tensorization failed: reduction axis sizes do not match";
-  size_t start = self->reduce_axis.size() - intrin_compute->reduce_axis.size();
-  for (size_t i = 0; i < start; ++i) {
-    IterVar iv = self->reduce_axis[i];
-    auto it = out_dom.find(iv);
-    CHECK(it != out_dom.end());
-    CHECK(is_one(it->second->extent))
-        << "Tensorization failed: reduction axis sizes do not match";
-  }
-  for (size_t i = start; i < self->reduce_axis.size(); ++i) {
-    IterVar iv = self->reduce_axis[i];
-    IterVar target = intrin_compute->reduce_axis[i - start];
-    auto it = out_dom.find(iv);
-    CHECK(it != out_dom.end());
-    binder.Bind(target->dom->min, make_const(iv->dom->min.dtype(), 0),
-                "tensor_intrin.reduction.min");
-    binder.Bind(target->dom->extent, it->second->extent,
-                "tensor_intrin.reduction.extent");
-  }
-  if (tloc <= n.num_common_loop) {
-    // No need to split the reduction
-    std::vector<std::vector<Stmt> > nest(
-        n.main_nest.begin(), n.main_nest.begin() + tloc + 1);
-    nest.emplace_back(op::MakeIfNest(n.main_predicates));
-    CHECK_EQ(n.init_predicates.size(), 0U);
-    CHECK(intrin->body.defined())
-        << "Normal store op for intrin " << intrin << " is not defined";
-    Stmt body = MergeNest(output_bind_nest, intrin->body);
-    body = MergeNest(input_bind_nest, body);
-    body = Substitute(body, vmap);
-    body = MergeNest(binder.asserts(), body);
-    body = Substitute(body, n.main_vmap);
-    return MergeNest(nest, body);
-  } else {
-    // Need to split reduction
-    CHECK(intrin->reduce_update.defined())
-        << "Reduction update op for intrin " << intrin << " is not defined";
-    // Need init and update steps
-    CHECK_NE(self->reduce_axis.size(), 0U);
-    std::vector<std::vector<Stmt> > common(
-        n.main_nest.begin(), n.main_nest.begin() + n.num_common_loop + 1);
-    std::vector<std::vector<Stmt> > update_nest(
-        n.main_nest.begin() + n.num_common_loop + 1, n.main_nest.begin() + tloc + 1);
-    update_nest.emplace_back(op::MakeIfNest(n.main_predicates));
-
-    if (intrin->reduce_init.defined()) {
-      // init nest
-      std::vector<std::vector<Stmt> > init_nest(
-          n.init_nest.begin(), n.init_nest.begin() + tloc + 1);
-      init_nest.emplace_back(op::MakeIfNest(n.init_predicates));
-      Stmt init = MergeNest(output_bind_nest, intrin->reduce_init);
-      init = Substitute(init, n.init_vmap);
-      init = MergeNest(init_nest, init);
-      // The update
-      Stmt update = MergeNest(output_bind_nest, intrin->reduce_update);
-      update = MergeNest(input_bind_nest, update);
-      update = Substitute(update, vmap);
-      update = MergeNest(binder.asserts(), update);
-      update = Substitute(update, n.main_vmap);
-      update = MergeNest(update_nest, update);
-      return MergeNest(common, SeqStmt::Flatten(init, update));
-    } else {
-      // When init op is not available, use body op for reset in the first iter.
-      CHECK(intrin->body.defined())
-          << "Normal body op for intrin " << intrin << " is not defined";
-      Stmt update = TransformUpdate(stage, dom_map, n,
-                                    intrin->body,
-                                    intrin->reduce_update);
-      update = MergeNest(output_bind_nest, update);
-      update = MergeNest(input_bind_nest, update);
-      update = Substitute(update, vmap);
-      update = MergeNest(binder.asserts(), update);
-      update = Substitute(update, n.main_vmap);
-      update = MergeNest(update_nest, update);
-      return MergeNest(common, update);
-    }
-  }
-}
-
-// Register functions for unit tests
-TVM_REGISTER_GLOBAL("test.op.InferTensorizeRegion")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-    Stage stage = args[0];
-    Map<IterVar, Range> dmap = args[1];
-    std::unordered_map<IterVar, Range> out_dom;
-    std::unordered_map<Tensor, Array<Range> > in_region;
-    CHECK(stage->op.as<ComputeOpNode>());
-    InferTensorizeRegion(stage->op.as<ComputeOpNode>(),
-                         stage,
-                         as_unordered_map(dmap),
-                         &out_dom, &in_region);
-    *ret = Array<ObjectRef>{Map<IterVar, Range>(out_dom),
-                            Map<Tensor, Array<Range> >(in_region)};
-  });
-
-TVM_REGISTER_GLOBAL("test.op.MatchTensorizeBody")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-    Stage stage = args[0];
-    Map<IterVar, Range> out_dom = args[1];
-    Map<Tensor, Array<Range> > in_region = args[2];
-    TensorIntrin intrin = args[3];
-    Map<Var, Range> vrange;
-    CHECK(stage->op.as<ComputeOpNode>());
-    *ret = MatchTensorizeBody(stage->op.as<ComputeOpNode>(),
-                              stage,
-                              {{}},
-                              as_unordered_map(out_dom),
-                              as_unordered_map(in_region),
-                              intrin,
-                              &vrange);
-  });
-}  // namespace tvm
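The two test.op.* globals registered above are unit-test hooks rather than public API. A minimal sketch of driving one from C++ through the packed-function registry, assuming a Stage and a Map<IterVar, Range> prepared by the caller (only the global name is taken from this diff; RunInferTensorizeRegion is a made-up wrapper):

    #include <tvm/runtime/registry.h>
    #include <tvm/top/schedule.h>

    // Fetch and invoke the tensorize-region test hook registered above.
    tvm::Array<tvm::ObjectRef> RunInferTensorizeRegion(
        tvm::top::Stage stage, tvm::Map<tvm::IterVar, tvm::Range> dmap) {
      const tvm::runtime::PackedFunc* f =
          tvm::runtime::Registry::Get("test.op.InferTensorizeRegion");
      CHECK(f != nullptr) << "hook not registered";
      // Returns {out_dom, in_region}, exactly as packed by the lambda above.
      return (*f)(stage, dmap);
    }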
index d87786385a02279b60979293de51a2cfefa43107..a2895d55a70de411fb99675df82440550eec27a4 100644 (file)
@@ -39,7 +39,7 @@ class PrefetchInjector : public StmtMutator {
     Stmt ret = StmtMutator::VisitStmt_(op);
     op = ret.as<AttrStmtNode>();
     if (op && op->attr_key == attr::prefetch_scope) {
-      Tensor ts = Downcast<Tensor>(op->node);
+      top::Tensor ts = Downcast<top::Tensor>(op->node);
       CHECK_NE(loop_nest_.size(), 0U);
       Domain domain = DomainTouched(op->body, ts, true, false);
       Region region;
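The inject_prefetch change is pure namespace qualification. The same migration pattern applies to any downstream pass; a hedged sketch (AsTensor is a made-up helper name):

    #include <tvm/top/tensor.h>

    // Passes that previously downcast an ObjectRef to tvm::Tensor now spell
    // the type through the new top namespace.
    tvm::top::Tensor AsTensor(const tvm::ObjectRef& node) {
      return tvm::Downcast<tvm::top::Tensor>(node);  // was Downcast<Tensor>(node)
    }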
index fc46ef3a4523dc24257f68853760d764ddb56a05..b5067658c9935f10e937dea432eb5cc2e7df8f25 100644 (file)
@@ -25,7 +25,7 @@
 #include <tvm/arith/analyzer.h>
 #include <tvm/ir.h>
 #include <tvm/expr.h>
-#include <tvm/operation.h>
+#include <tvm/top/operation.h>
 #include <tvm/ir_functor_ext.h>
 #include <tvm/expr_operator.h>
 #include <tvm/ir_pass.h>
@@ -49,7 +49,7 @@ using intrinsic::tvm_address_of;
 
 class StorageFlattener : public StmtExprMutator {
  public:
-  explicit StorageFlattener(Map<Tensor, Buffer> extern_buffer,
+  explicit StorageFlattener(Map<top::Tensor, Buffer> extern_buffer,
                             int cache_line_size, bool create_bound_attributes,
                             IRVisitorWithAnalyzer* bounded_analyzer)
       : bounded_analyzer_(bounded_analyzer),
@@ -82,8 +82,8 @@ class StorageFlattener : public StmtExprMutator {
       storage_scope_[op->node.get()] = op->value.as<StringImmNode>()->value;
       return this->VisitStmt(op->body);
     } else if (op->attr_key == attr::double_buffer_scope &&
-               op->node->IsInstance<OperationNode>()) {
-      Operation func = Downcast<Operation>(op->node);
+               op->node->IsInstance<top::OperationNode>()) {
+      auto func = Downcast<top::Operation>(op->node);
       Stmt body = this->VisitStmt(op->body);
       for (int i = 0; i < func->num_outputs(); ++i) {
         TensorKey key{func, i};
@@ -104,7 +104,7 @@ class StorageFlattener : public StmtExprMutator {
     } else if (op->attr_key == attr::buffer_bind_scope) {
       return HandleBufferBindScope(op);
     } else if (op->attr_key == attr::buffer_dim_align) {
-      Tensor tensor = Downcast<Tensor>(op->node);
+      auto tensor = Downcast<top::Tensor>(op->node);
       const CallNode* tuple = op->value.as<CallNode>();
       CHECK(tuple && tuple->is_intrinsic(intrinsic::tvm_tuple));
       TensorKey key{tensor->op, tensor->value_index};
@@ -406,7 +406,7 @@ class StorageFlattener : public StmtExprMutator {
     Array<ObjectRef> arr = Downcast<Array<ObjectRef> > (op->node);
     CHECK_EQ(arr.size(), 2U);
     const BufferNode* buffer = arr[0].as<BufferNode>();
-    const TensorNode* tensor = arr[1].as<TensorNode>();
+    const top::TensorNode* tensor = arr[1].as<top::TensorNode>();
     const CallNode* tuple = op->value.as<CallNode>();
     CHECK(buffer && tensor);
     CHECK(tuple && tuple->is_intrinsic(intrinsic::tvm_tuple));
@@ -529,7 +529,7 @@ class StorageFlattener : public StmtExprMutator {
   bool create_bound_attributes_{false};
 };
 
-Stmt StorageFlatten(Stmt stmt, Map<Tensor, Buffer> extern_buffer,
+Stmt StorageFlatten(Stmt stmt, Map<top::Tensor, Buffer> extern_buffer,
                     int cache_line_size, bool create_bound_attributes) {
   IRVisitorWithAnalyzer bounded_analyzer;
   bounded_analyzer(stmt);
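With the new signature, callers of StorageFlatten key the extern-buffer map by top::Tensor. A minimal caller sketch, assuming stmt and binds come from an earlier lowering step (Flatten is a made-up wrapper name):

    #include <tvm/ir_pass.h>
    #include <tvm/top/tensor.h>

    tvm::Stmt Flatten(tvm::Stmt stmt,
                      const tvm::Map<tvm::top::Tensor, tvm::Buffer>& binds) {
      // 64 is a typical cache line size; bound attributes are left off here.
      return tvm::ir::StorageFlatten(stmt, binds, /*cache_line_size=*/64,
                                     /*create_bound_attributes=*/false);
    }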
index c5c81cacb68882dce9a013071decd7c7d61f5514..bf36b0a8ffdb662006b4d5cac69adc854f8157fb 100644 (file)
@@ -23,7 +23,7 @@
 // IR Passes for TensorCore CodeGen
 #include <tvm/ir.h>
 #include <tvm/expr.h>
-#include <tvm/operation.h>
+#include <tvm/top/operation.h>
 #include <tvm/ir_functor_ext.h>
 #include <tvm/expr_operator.h>
 #include <tvm/ir_pass.h>
@@ -39,6 +39,7 @@
 namespace tvm {
 namespace ir {
 
+using namespace top;
 using runtime::StorageRank;
 using runtime::StorageScope;
 using runtime::ThreadScope;
@@ -417,7 +418,7 @@ class BufferAnalyser : public StmtExprVisitor {
       storage_scope_[op->node.get()] = op->value.as<StringImmNode>()->value;
       this->VisitStmt(op->body);
     } else if (op->attr_key == attr::buffer_dim_align) {
-      Tensor tensor = Downcast<Tensor>(op->node);
+      top::Tensor tensor = Downcast<top::Tensor>(op->node);
       const CallNode* tuple = op->value.as<CallNode>();
       CHECK(tuple && tuple->is_intrinsic(intrinsic::tvm_tuple));
       auto& vinfo = dim_align_[TensorKey{tensor->op, tensor->value_index}];
@@ -831,7 +832,7 @@ class TensorCoreIRMutator : public StmtExprMutator {
   Stmt VisitStmt_(const AttrStmtNode* op) final {
     Stmt stmt = StmtExprMutator::VisitStmt_(op);
     if (op->attr_key == attr::realize_scope) {
-      auto node = op->node.as<OperationNode>();
+      auto node = op->node.as<top::OperationNode>();
       if (node != nullptr) {
         if (!frag_reg_.count(node->name)) {
           return stmt;
@@ -1119,9 +1120,9 @@ class TensorCoreIRMutator : public StmtExprMutator {
     buffer_node->offset_factor = 1;
     Buffer buffer(buffer_node);
 
-    ObjectPtr<TensorNode> tensor_node = make_object<TensorNode>();
+    ObjectPtr<top::TensorNode> tensor_node = make_object<top::TensorNode>();
     tensor_node->value_index = key.value_index;
-    tensor_node->op = Downcast<Operation>(key.f);
+    tensor_node->op = Downcast<top::Operation>(key.f);
     tensor_node->shape = shape;
     tensor_node->dtype = datatype;
     Tensor tensor(tensor_node);
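The unqualified Tensor on the last line above still compiles only because of the "using namespace top;" added near the head of this file. Fully qualified, the same hand-construction pattern reads as follows (a sketch; MakeTensor is a made-up name and op/shape/dtype are assumed inputs):

    #include <tvm/top/tensor.h>
    #include <utility>

    tvm::top::Tensor MakeTensor(tvm::top::Operation op,
                                tvm::Array<tvm::PrimExpr> shape,
                                tvm::DataType dtype, int value_index) {
      auto node = tvm::make_object<tvm::top::TensorNode>();
      node->op = std::move(op);
      node->shape = std::move(shape);
      node->dtype = dtype;
      node->value_index = value_index;
      return tvm::top::Tensor(node);
    }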
index f6c454d2abf3fb9ca50d05ec137c9f8ee89f1d6b..95dcbddedddc19a8227c0c9f5845d229034317ed 100644 (file)
@@ -25,7 +25,7 @@
 #include <tvm/expr.h>
 #include <tvm/ir.h>
 #include <tvm/ir_functor_ext.h>
-#include <tvm/tensor.h>
+#include <tvm/top/tensor.h>
 
 #include <unordered_map>
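Most of the remaining hunks are this same mechanical include-path update; the renames visible in this section are:

    <tvm/tensor.h>     ->  <tvm/top/tensor.h>
    <tvm/operation.h>  ->  <tvm/top/operation.h>
    <tvm/schedule.h>   ->  <tvm/top/schedule.h>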
 
index 5ee4ce30c96d7395208f48a3b908efd1ec1b023e..14967c15490f8326c1c2d6683108a73f561cbe12 100644 (file)
@@ -23,9 +23,9 @@
  */
 #include "compile_engine.h"
 
-#include <tvm/schedule.h>
+#include <tvm/top/schedule.h>
 #include <tvm/packed_func_ext.h>
-#include <tvm/operation.h>
+#include <tvm/top/operation.h>
 #include <tvm/runtime/registry.h>
 #include <tvm/relay/attrs/device_copy.h>
 #include <tvm/relay/analysis.h>
@@ -101,20 +101,20 @@ Array<IndexExpr> GetShape(const Array<IndexExpr>& shape) {
 // The getter to get schedule from compile engine.
 // Get schedule from functor.
 class ScheduleGetter :
-      public ExprFunctor<Array<Tensor>(const Expr&)> {
+      public ExprFunctor<Array<top::Tensor>(const Expr&)> {
  public:
   explicit ScheduleGetter(Target target)
       : target_(target), device_copy_op_(Op::Get("device_copy")) {}
 
-  std::pair<Schedule, CachedFunc> Create(const Function& prim_func) {
+  std::pair<top::Schedule, CachedFunc> Create(const Function& prim_func) {
     static auto fschedule =
         Op::GetAttr<FTVMSchedule>("FTVMSchedule");
     auto cache_node = make_object<CachedFuncNode>();
     cache_node->target = target_;
     for (Var param : prim_func->params) {
-      Array<tvm::Tensor> inputs;
+      Array<tvm::top::Tensor> inputs;
       if (const auto* ttype = param->checked_type().as<TensorTypeNode>()) {
-        tvm::Tensor tensor = tvm::placeholder(
+        tvm::top::Tensor tensor = tvm::top::placeholder(
             GetShape(ttype->shape), ttype->dtype);
         cache_node->inputs.push_back(tensor);
         inputs.push_back(tensor);
@@ -125,7 +125,7 @@ class ScheduleGetter :
           const auto* ttype = field.as<TensorTypeNode>();
           // TODO(@icemelon): Allow recursive tuple
           CHECK(ttype != nullptr);
-          tvm::Tensor tensor = tvm::placeholder(
+          tvm::top::Tensor tensor = tvm::top::placeholder(
               GetShape(ttype->shape), ttype->dtype);
           cache_node->inputs.push_back(tensor);
           inputs.push_back(tensor);
@@ -150,13 +150,13 @@ class ScheduleGetter :
     // Fusion over tupled results may leave identity relationships
     // between inputs and outputs, and those should not be scheduled.
     // Hence schedule only non-PlaceholderOp outputs.
-    tvm::Array<Tensor> tensor_outs;
+    tvm::Array<top::Tensor> tensor_outs;
     for (const auto& tensor : cache_node->outputs) {
-      if (!tensor->op.as<PlaceholderOpNode>()) {
+      if (!tensor->op.as<top::PlaceholderOpNode>()) {
         tensor_outs.push_back(tensor);
       }
     }
-    Schedule schedule;
+    top::Schedule schedule;
     // No need to register schedule for device copy op.
     if (master_attrs_.as<DeviceCopyAttrs>() == nullptr) {
       schedule =
@@ -170,27 +170,27 @@ class ScheduleGetter :
     return std::make_pair(schedule, cfunc);
   }
 
-  Array<Tensor> VisitExpr(const Expr& expr) {
+  Array<top::Tensor> VisitExpr(const Expr& expr) {
     auto it = memo_.find(expr);
     if (it != memo_.end()) {
       return it->second;
     } else {
-      Array<Tensor> res = ExprFunctor::VisitExpr(expr);
+      Array<top::Tensor> res = ExprFunctor::VisitExpr(expr);
       memo_[expr] = res;
       return res;
     }
   }
 
-  Array<Tensor> VisitExpr_(const VarNode* op) final {
+  Array<top::Tensor> VisitExpr_(const VarNode* op) final {
     LOG(FATAL) << "Free variable " << op->name_hint();
     return {};
   }
 
-  Array<Tensor> VisitExpr_(const ConstantNode* op) final {
+  Array<top::Tensor> VisitExpr_(const ConstantNode* op) final {
     CHECK(op->is_scalar());
     void* data = op->data->data;
     DataType dtype = DataType(op->data->dtype);
-    Tensor value = tvm::compute({}, [&](const Array<tvm::Var>&) {
+    auto value = top::compute({}, [&](const Array<tvm::Var>&) {
         if (dtype == DataType::Int(32)) {
           return make_const(dtype, static_cast<const int32_t*>(data)[0]);
         } else if (dtype == DataType::Int(64)) {
@@ -210,19 +210,19 @@ class ScheduleGetter :
     return {value};
   }
 
-  Array<Tensor> VisitExpr_(const CallNode* call_node) final {
+  Array<top::Tensor> VisitExpr_(const CallNode* call_node) final {
     static auto fcompute =
         Op::GetAttr<FTVMCompute>("FTVMCompute");
     static auto fpattern =
         Op::GetAttr<TOpPattern>("TOpPattern");
 
-    Array<Tensor> inputs;
+    Array<top::Tensor> inputs;
     int count_tuple = 0;
     for (Expr arg : call_node->args) {
       if (arg->checked_type().as<TupleTypeNode>()) {
         ++count_tuple;
       }
-      for (Tensor tensor : VisitExpr(arg)) {
+      for (top::Tensor tensor : VisitExpr(arg)) {
         inputs.push_back(tensor);
       }
     }
@@ -252,12 +252,12 @@ class ScheduleGetter :
     CHECK(call_node->op.as<OpNode>())
         << "Primitive function only allows call into primitive ops";
     Op op = Downcast<Op>(call_node->op);
-    Array<Tensor> outputs;
+    Array<top::Tensor> outputs;
     // Skip fcompute for device copy operators as it is not registered.
     if (op == device_copy_op_) {
       const auto* copy_input = inputs[0].operator->();
-      outputs.push_back(TensorNode::make(copy_input->shape, copy_input->dtype,
-                                         Operation(), 0));
+      outputs.push_back(top::TensorNode::make(copy_input->shape, copy_input->dtype,
+                                              top::Operation(), 0));
     } else {
       outputs = fcompute[op](call_node->attrs, inputs,
                              call_node_type, target_);
@@ -291,33 +291,33 @@ class ScheduleGetter :
     return outputs;
   }
 
-  Array<Tensor> VisitExpr_(const FunctionNode* op) final {
+  Array<top::Tensor> VisitExpr_(const FunctionNode* op) final {
     LOG(FATAL) << "Do not support sub function";
-    return Array<Tensor>();
+    return Array<top::Tensor>();
   }
 
-  Array<Tensor> VisitExpr_(const LetNode* op) final {
-    Array<Tensor> val = VisitExpr(op->value);
+  Array<top::Tensor> VisitExpr_(const LetNode* op) final {
+    Array<top::Tensor> val = VisitExpr(op->value);
     CHECK(!memo_.count(op->var));
     memo_[op->var] = val;
     return VisitExpr(op->body);
   }
 
-  Array<Tensor> VisitExpr_(const TupleNode* op) final {
-    Array<Tensor> fields;
+  Array<top::Tensor> VisitExpr_(const TupleNode* op) final {
+    Array<top::Tensor> fields;
     for (Expr field : op->fields) {
       CHECK(field->checked_type().as<TensorTypeNode>())
           << "Only allow Tuple of Tensor";
-      Array<Tensor> res = VisitExpr(field);
+      Array<top::Tensor> res = VisitExpr(field);
       CHECK_EQ(res.size(), 1);
       fields.push_back(res[0]);
     }
     return fields;
   }
 
-  Array<Tensor> VisitExpr_(const TupleGetItemNode* op) final {
+  Array<top::Tensor> VisitExpr_(const TupleGetItemNode* op) final {
     const auto* tuple_type = op->tuple->type_as<TupleTypeNode>();
-    Array<Tensor> tuple = VisitExpr(op->tuple);
+    Array<top::Tensor> tuple = VisitExpr(op->tuple);
     CHECK_EQ(tuple_type->fields.size(), tuple.size());
     CHECK_GE(op->index, 0);
     CHECK_LT(static_cast<size_t>(op->index), tuple.size());
@@ -330,28 +330,28 @@ class ScheduleGetter :
   Attrs master_attrs_;
   int master_op_pattern_{0};
   std::ostringstream readable_name_stream_;
-  std::unordered_map<Expr, Array<Tensor>, ObjectHash, ObjectEqual> memo_;
-  Array<Operation> scalars_;
+  std::unordered_map<Expr, Array<top::Tensor>, ObjectHash, ObjectEqual> memo_;
+  Array<top::Operation> scalars_;
   // Cache device copy op for equivalence checking to reduce registry lookup
   // overhead for each invocation of call node when retrieving schedules.
   const Op& device_copy_op_;
 };
 
 // Creates shape function from functor.
-class MakeShapeFunc : public ExprFunctor<Array<Tensor>(const Expr&)> {
+class MakeShapeFunc : public ExprFunctor<Array<top::Tensor>(const Expr&)> {
  public:
   MakeShapeFunc() {}
 
-  std::pair<Schedule, CachedFunc> Create(const Function& prim_func) {
+  std::pair<top::Schedule, CachedFunc> Create(const Function& prim_func) {
     for (auto param : prim_func->params) {
       param_states_[param] = kNoNeed;
-      Array<tvm::Tensor> data_inputs;
-      Array<tvm::Tensor> shape_inputs;
+      Array<tvm::top::Tensor> data_inputs;
+      Array<tvm::top::Tensor> shape_inputs;
 
       auto add_placeholder = [&data_inputs, &shape_inputs](const TensorTypeNode* ttype) {
         // Add data placeholder
         Shape shape = GetShape(ttype->shape);
-        tvm::Tensor data_tensor = tvm::placeholder(shape, ttype->dtype);
+        tvm::top::Tensor data_tensor = tvm::top::placeholder(shape, ttype->dtype);
         data_inputs.push_back(data_tensor);
         // Add shape placeholder
         int64_t ndim = shape.size();
@@ -359,7 +359,7 @@ class MakeShapeFunc : public ExprFunctor<Array<Tensor>(const Expr&)> {
         if (ndim > 0) {
           sshape.push_back(tvm::Integer(ndim));
         }
-        tvm::Tensor shape_tensor = tvm::placeholder(sshape, DataType::Int(64));
+        tvm::top::Tensor shape_tensor = tvm::top::placeholder(sshape, DataType::Int(64));
         shape_inputs.push_back(shape_tensor);
       };
 
@@ -410,12 +410,12 @@ class MakeShapeFunc : public ExprFunctor<Array<Tensor>(const Expr&)> {
 
     CachedFunc cfunc(cache_node);
     // generate schedule for shape func
-    Array<Operation> out_ops;
+    Array<top::Operation> out_ops;
     for (auto t : cache_node->outputs) {
       out_ops.push_back(t->op);
     }
-    auto schedule = create_schedule(out_ops);
-    tvm::schedule::AutoInlineInjective(schedule);
+    auto schedule = top::create_schedule(out_ops);
+    tvm::top::AutoInlineInjective(schedule);
     for (const auto& scalar : scalars_) {
       auto scalar_op = scalar->op;
       if (schedule->Contain(scalar_op)) {
@@ -425,12 +425,12 @@ class MakeShapeFunc : public ExprFunctor<Array<Tensor>(const Expr&)> {
     return std::make_pair(schedule, cfunc);
   }
 
-  Array<Tensor> VisitExpr(const Expr& expr) {
+  Array<top::Tensor> VisitExpr(const Expr& expr) {
     auto it = memo_.find(expr);
     if (it != memo_.end()) {
       return it->second;
     } else {
-      Array<Tensor> res = ExprFunctor::VisitExpr(expr);
+      Array<top::Tensor> res = ExprFunctor::VisitExpr(expr);
       if (expr.as<VarNode>() == nullptr) {
         // Do not memoize vars because shape functions could use either the data
         // or the shape of a var each time.
@@ -440,7 +440,7 @@ class MakeShapeFunc : public ExprFunctor<Array<Tensor>(const Expr&)> {
     }
   }
 
-  Array<Tensor> VisitExpr_(const VarNode* var_node) final {
+  Array<top::Tensor> VisitExpr_(const VarNode* var_node) final {
     auto var = GetRef<Var>(var_node);
     auto it = param_states_.find(var);
     if (it == param_states_.end()) {
@@ -459,14 +459,14 @@ class MakeShapeFunc : public ExprFunctor<Array<Tensor>(const Expr&)> {
     }
   }
 
-  Array<Tensor> VisitExpr_(const ConstantNode* op) final {
+  Array<top::Tensor> VisitExpr_(const ConstantNode* op) final {
     CHECK(data_dependants_.size());
     CHECK(op->is_scalar());
     bool data_dependant = data_dependants_.back();
     if (data_dependant) {
       void* data = op->data->data;
       DataType dtype = DataType(op->data->dtype);
-      Tensor value = tvm::compute({}, [&](const Array<tvm::Var>&) {
+      auto value = tvm::top::compute({}, [&](const Array<tvm::Var>&) {
           if (dtype == DataType::Int(32)) {
             return make_const(dtype, static_cast<const int32_t*>(data)[0]);
           } else if (dtype == DataType::Int(64)) {
@@ -485,7 +485,7 @@ class MakeShapeFunc : public ExprFunctor<Array<Tensor>(const Expr&)> {
       scalars_.push_back(value);
       return {value};
     } else {
-      Tensor value = tvm::compute({}, [&](const Array<tvm::Var>&) {
+      auto value = tvm::top::compute({}, [&](const Array<tvm::Var>&) {
           return make_const(DataType::Int(64), 0);
       }, "shape_const", topi::kBroadcast);
       scalars_.push_back(value);
@@ -493,7 +493,7 @@ class MakeShapeFunc : public ExprFunctor<Array<Tensor>(const Expr&)> {
     }
   }
 
-  Array<Tensor> VisitExpr_(const CallNode* call_node) final {
+  Array<top::Tensor> VisitExpr_(const CallNode* call_node) final {
     static auto fshape_func = Op::GetAttr<FShapeFunc>("FShapeFunc");
     static auto tshape_data_dependant = Op::GetAttr<TShapeDataDependant>(
         "TShapeDataDependant");
@@ -510,13 +510,13 @@ class MakeShapeFunc : public ExprFunctor<Array<Tensor>(const Expr&)> {
 
     data_dependants_.push_back(tshape_data_dependant[op]);
     // Visit all inputs
-    Array<Tensor> inputs;
+    Array<top::Tensor> inputs;
     int count_tuple = 0;
     for (Expr arg : call_node->args) {
       if (arg->checked_type().as<TupleTypeNode>()) {
         ++count_tuple;
       }
-      for (Tensor tensor : VisitExpr(arg)) {
+      for (top::Tensor tensor : VisitExpr(arg)) {
         inputs.push_back(tensor);
       }
     }
@@ -546,24 +546,24 @@ class MakeShapeFunc : public ExprFunctor<Array<Tensor>(const Expr&)> {
     return outputs;
   }
 
-  Array<Tensor> VisitExpr_(const FunctionNode* op) final {
+  Array<top::Tensor> VisitExpr_(const FunctionNode* op) final {
     LOG(FATAL) << "Do not support sub function";
-    return Array<Tensor>();
+    return Array<top::Tensor>();
   }
 
-  Array<Tensor> VisitExpr_(const LetNode* op) final {
-    Array<Tensor> val = VisitExpr(op->value);
+  Array<top::Tensor> VisitExpr_(const LetNode* op) final {
+    Array<top::Tensor> val = VisitExpr(op->value);
     CHECK(!memo_.count(op->var));
     memo_[op->var] = val;
     return VisitExpr(op->body);
   }
 
-  Array<Tensor> VisitExpr_(const TupleNode* op) final {
-    Array<Tensor> fields;
+  Array<top::Tensor> VisitExpr_(const TupleNode* op) final {
+    Array<top::Tensor> fields;
     for (Expr field : op->fields) {
       CHECK(field->checked_type().as<TensorTypeNode>())
         << "Only allow Tuple of Tensor";
-      Array<Tensor> res = VisitExpr(field);
+      Array<top::Tensor> res = VisitExpr(field);
       CHECK_EQ(res.size(), 1);
       fields.push_back(res[0]);
     }
@@ -576,15 +576,15 @@ class MakeShapeFunc : public ExprFunctor<Array<Tensor>(const Expr&)> {
   /*! \brief Map from parameter to its shape function usage state */
   std::unordered_map<Expr, int, ObjectHash, ObjectEqual> param_states_;
   /*! \brief Map from parameter to list of data placeholder */
-  std::unordered_map<Expr, Array<Tensor>, ObjectHash, ObjectEqual> param_data_;
+  std::unordered_map<Expr, Array<top::Tensor>, ObjectHash, ObjectEqual> param_data_;
   /*! \brief Map from parameter to list of shape placeholder */
-  std::unordered_map<Expr, Array<Tensor>, ObjectHash, ObjectEqual> param_shapes_;
+  std::unordered_map<Expr, Array<top::Tensor>, ObjectHash, ObjectEqual> param_shapes_;
   /*! \brief Memoized visit result */
-  std::unordered_map<Expr, Array<Tensor>, ObjectHash, ObjectEqual> memo_;
+  std::unordered_map<Expr, Array<top::Tensor>, ObjectHash, ObjectEqual> memo_;
   /*! \brief Stack of data dependencies for shape function */
   std::vector<bool> data_dependants_;
   /*! \brief Scalars used in the shape function */
-  Array<Tensor> scalars_;
+  Array<top::Tensor> scalars_;
 };
 
 class CompileEngineImpl : public CompileEngineNode {
@@ -672,7 +672,7 @@ class CompileEngineImpl : public CompileEngineNode {
    * \return Pair of schedule and cache.
    *  The funcs field in cache is not yet populated.
    */
-  std::pair<Schedule, CachedFunc> CreateSchedule(
+  std::pair<top::Schedule, CachedFunc> CreateSchedule(
       const Function& source_func, const Target& target) {
     return ScheduleGetter(target).Create(source_func);
   }
@@ -723,8 +723,8 @@ class CompileEngineImpl : public CompileEngineNode {
 
     cache_node->func_name = GetUniqueName(cache_node->func_name);
     // NOTE: array will copy on write.
-    Array<Tensor> all_args = cache_node->inputs;
-    for (Tensor arg : cache_node->outputs) {
+    Array<top::Tensor> all_args = cache_node->inputs;
+    for (top::Tensor arg : cache_node->outputs) {
       all_args.push_back(arg);
     }
     // lower the function
@@ -733,7 +733,7 @@ class CompileEngineImpl : public CompileEngineNode {
           spair.first, all_args, cache_node->func_name, key->source_func);
     } else {
       tvm::BuildConfig bcfg = BuildConfig::Create();
-      std::unordered_map<Tensor, Buffer> binds;
+      std::unordered_map<top::Tensor, Buffer> binds;
       cache_node->funcs = tvm::lower(spair.first, all_args, cache_node->func_name, binds, bcfg);
     }
     value->cached_func = CachedFunc(cache_node);
@@ -763,12 +763,12 @@ class CompileEngineImpl : public CompileEngineNode {
     cache_node->func_name = GetUniqueName(cache_node->func_name);
     cache_node->target = key->target;
 
-    Array<Tensor> all_args = cache_node->inputs;
-    for (Tensor arg : cache_node->outputs) {
+    Array<top::Tensor> all_args = cache_node->inputs;
+    for (top::Tensor arg : cache_node->outputs) {
       all_args.push_back(arg);
     }
     tvm::BuildConfig bcfg = BuildConfig::Create();
-    std::unordered_map<Tensor, Buffer> binds;
+    std::unordered_map<top::Tensor, Buffer> binds;
     cache_node->funcs = tvm::lower(spair.first, all_args, cache_node->func_name, binds, bcfg);
     value->cached_func = CachedFunc(cache_node);
     return value;
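Taken together, the compile_engine.cc hunks exercise most of the relocated tensor-DSL entry points. A self-contained sketch of the new spellings (the individual calls appear verbatim above; schedule_pass.h as the post-move home of AutoInlineInjective is an assumption, as is the function name):

    #include <tvm/top/operation.h>
    #include <tvm/top/schedule_pass.h>  // assumed location of AutoInlineInjective

    tvm::top::Schedule BuildTrivialSchedule() {
      tvm::top::Tensor x =
          tvm::top::placeholder({tvm::PrimExpr(4)}, tvm::DataType::Float(32), "x");
      tvm::top::Tensor y = tvm::top::compute(
          x->shape, [&](const tvm::Array<tvm::Var>& i) { return x(i) + 1; }, "y");
      auto schedule = tvm::top::create_schedule({y->op});
      tvm::top::AutoInlineInjective(schedule);  // was tvm::schedule::AutoInlineInjective
      return schedule;
    }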
index f6c38ba6b9a9744dc12004e3447fdb3c522ac514..386eba7f9fd8ea5d9700857c26fbe483aa4c67be 100644 (file)
@@ -51,9 +51,9 @@ struct CachedFuncNode : public Object {
   /*! \brief Function name */
   std::string func_name;
   /* \brief The inputs to the function */
-  tvm::Array<Tensor> inputs;
+  tvm::Array<top::Tensor> inputs;
   /* \brief The outputs to the function */
-  tvm::Array<Tensor> outputs;
+  tvm::Array<top::Tensor> outputs;
   /*! \brief The lowered functions to support the function. */
   tvm::Array<tvm::LoweredFunc> funcs;
   /*! \brief Parameter usage states in the shape function. */
index 3ef740306071e948682ee2072fe41a8494a55ebd..79583685a6a286f6d49a188557aca6f85a288321 100644 (file)
@@ -30,7 +30,7 @@
 #include <tvm/build_module.h>
 #include <tvm/codegen.h>
 #include <tvm/ir_pass.h>
-#include <tvm/operation.h>
+#include <tvm/top/operation.h>
 
 #include <typeinfo>
 #include <string>
index e4a34a3974ae4afa2e0233aaa97b5764becdd306..00e47bcc9c5e4d08863e8d0e36c19309d06cdc88 100644 (file)
@@ -22,7 +22,7 @@
  * \brief A compiler from relay::Module to the VM byte code.
  */
 
-#include <tvm/operation.h>
+#include <tvm/top/operation.h>
 #include <tvm/ir/error.h>
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/interpreter.h>
index 3d03f884e2470f6190969bde4de3c184072bd0f5..2aefbd7ea88352bf3c077a395a9c313e6c1f0d07 100644 (file)
@@ -78,8 +78,8 @@ TVM_ADD_FILELINE)
 .set_attr<TOpIsStateful>("TOpIsStateful", false)
 .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
 .set_attr<FTVMCompute>("FTVMCompute",
-                       [](const Attrs& attrs, const Array<Tensor>& inputs,
-                          const Type& out_dtype, const Target& target) -> Array<Tensor> {
+                       [](const Attrs& attrs, const Array<top::Tensor>& inputs,
+                          const Type& out_dtype, const Target& target) -> Array<top::Tensor> {
                          return {topi::identity(inputs[0])};
                        });
 
@@ -104,8 +104,8 @@ TVM_ADD_FILELINE)
 .set_attr<TOpIsStateful>("TOpIsStateful", false)
 .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
 .set_attr<FTVMCompute>("FTVMCompute",
-                       [](const Attrs& attrs, const Array<Tensor>& inputs,
-                          const Type& out_dtype, const Target& target) -> Array<Tensor> {
+                       [](const Attrs& attrs, const Array<top::Tensor>& inputs,
+                          const Type& out_dtype, const Target& target) -> Array<top::Tensor> {
                          return {topi::identity(inputs[0])};
                        });
 
@@ -122,8 +122,8 @@ Mark the start of bitpacking.
 .set_attr<FInferCorrectLayout>("FInferCorrectLayout",
                                ElemwiseArbitraryLayout)
 .set_attr<FTVMCompute>("FTVMCompute",
-                       [](const Attrs& attrs, const Array<Tensor>& inputs,
-                          const Type& out_dtype, const Target& target) -> Array<Tensor> {
+                       [](const Attrs& attrs, const Array<top::Tensor>& inputs,
+                          const Type& out_dtype, const Target& target) -> Array<top::Tensor> {
                          return {topi::identity(inputs[0])};
                        });
 
@@ -139,8 +139,8 @@ Mark the end of bitpacking.
 .set_attr<FInferCorrectLayout>("FInferCorrectLayout",
                                ElemwiseArbitraryLayout)
 .set_attr<FTVMCompute>("FTVMCompute",
-                       [](const Attrs& attrs, const Array<Tensor>& inputs,
-                          const Type& out_dtype, const Target& target) -> Array<Tensor> {
+                       [](const Attrs& attrs, const Array<top::Tensor>& inputs,
+                          const Type& out_dtype, const Target& target) -> Array<top::Tensor> {
                          return {topi::identity(inputs[0])};
                        });
 
@@ -162,9 +162,9 @@ Mark a checkpoint for checkpointing memory optimization.
 .set_attr<FInferCorrectLayout>("FInferCorrectLayout",
                                ElemwiseArbitraryLayout)
 .set_attr<FTVMCompute>("FTVMCompute",
-                       [](const Attrs& attrs, const Array<Tensor>& inputs,
-                          const Type& out_dtype, const Target& target) -> Array<Tensor> {
-                         Array<Tensor> outputs;
+                       [](const Attrs& attrs, const Array<top::Tensor>& inputs,
+                          const Type& out_dtype, const Target& target) -> Array<top::Tensor> {
+                         Array<top::Tensor> outputs;
                          for (size_t i = 0; i < inputs.size(); ++i) {
                            outputs.push_back(topi::identity(inputs[i]));
                          }
@@ -183,8 +183,8 @@ Beginning of a region that is handled by a given compiler.
 .set_attr<FInferCorrectLayout>("FInferCorrectLayout",
                                ElemwiseArbitraryLayout)
 .set_attr<FTVMCompute>("FTVMCompute",
-                       [](const Attrs& attrs, const Array<Tensor>& inputs,
-                          const Type& out_dtype, const Target& target) -> Array<Tensor> {
+                       [](const Attrs& attrs, const Array<top::Tensor>& inputs,
+                          const Type& out_dtype, const Target& target) -> Array<top::Tensor> {
                          return {topi::identity(inputs[0])};
                        });
 
@@ -208,8 +208,8 @@ End of a region that is handled by a given compiler.
 .set_attr<FInferCorrectLayout>("FInferCorrectLayout",
                                ElemwiseArbitraryLayout)
 .set_attr<FTVMCompute>("FTVMCompute",
-                       [](const Attrs& attrs, const Array<Tensor>& inputs,
-                          const Type& out_dtype, const Target& target) -> Array<Tensor> {
+                       [](const Attrs& attrs, const Array<top::Tensor>& inputs,
+                          const Type& out_dtype, const Target& target) -> Array<top::Tensor> {
                          return {topi::identity(inputs[0])};
                        });
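Every FTVMCompute lambda in this file — and in the memory, nn, pad, pooling, reduce, and transform files below — now matches the updated functor type from include/tvm/relay/op_attr_types.h, which this PR touches as well. Roughly (a sketch; exact qualifiers may differ):

    namespace tvm {
    namespace relay {

    using FTVMCompute = runtime::TypedPackedFunc<
        Array<top::Tensor>(const Attrs& attrs,
                           const Array<top::Tensor>& inputs,
                           const Type& out_type,
                           const Target& target)>;

    }  // namespace relay
    }  // namespace tvm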
 
index cdfdac06dbc5d3f3846dfb5c8850e5afef4e3389..87e579740892f574a9b12af0ae412e4843d39af2 100644 (file)
@@ -35,11 +35,11 @@ namespace relay {
 
 TVM_REGISTER_NODE_TYPE(DebugAttrs);
 
-Array<Tensor> DebugCompute(const Attrs& attrs,
-                           const Array<Tensor>& inputs,
+Array<top::Tensor> DebugCompute(const Attrs& attrs,
+                           const Array<top::Tensor>& inputs,
                            const Type& out_type,
                            const Target& target) {
-  return Array<Tensor>{ topi::identity(inputs[0]) };
+  return Array<top::Tensor>{ topi::identity(inputs[0]) };
 }
 
 RELAY_REGISTER_OP("debug")
index bd3b543659ae04d0ce63e9e4dc5b2b3352b10413..6c4b3ea87b0e9436c63deb6d378a658289dce9c5 100644 (file)
@@ -82,8 +82,8 @@ RELAY_REGISTER_OP("memory.alloc_storage")
     .set_attr<TNonComputational>("TNonComputational", true)
     .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
     .set_attr<FTVMCompute>("FTVMCompute",
-                           [](const Attrs& attrs, const Array<Tensor>& inputs,
-                              const Type& out_dtype, const Target& target) -> Array<Tensor> {
+                           [](const Attrs& attrs, const Array<top::Tensor>& inputs,
+                              const Type& out_dtype, const Target& target) -> Array<top::Tensor> {
                              return {topi::identity(inputs[0])};
                            });
 
@@ -178,8 +178,8 @@ RELAY_REGISTER_OP("memory.alloc_tensor")
     .set_attr<TNonComputational>("TNonComputational", true)
     .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
     .set_attr<FTVMCompute>("FTVMCompute",
-                           [](const Attrs& attrs, const Array<Tensor>& inputs,
-                              const Type& out_dtype, const Target& target) -> Array<Tensor> {
+                           [](const Attrs& attrs, const Array<top::Tensor>& inputs,
+                              const Type& out_dtype, const Target& target) -> Array<top::Tensor> {
                              return {topi::identity(inputs[0])};
                            });
 
@@ -227,8 +227,8 @@ RELAY_REGISTER_OP("memory.invoke_tvm_op")
     .set_attr<TNonComputational>("TNonComputational", true)
     .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
     .set_attr<FTVMCompute>("FTVMCompute",
-                           [](const Attrs& attrs, const Array<Tensor>& inputs,
-                              const Type& out_dtype, const Target& target) -> Array<Tensor> {
+                           [](const Attrs& attrs, const Array<top::Tensor>& inputs,
+                              const Type& out_dtype, const Target& target) -> Array<top::Tensor> {
                              return {topi::identity(inputs[0])};
                            });
 
@@ -251,8 +251,8 @@ RELAY_REGISTER_OP("memory.kill")
     .set_attr<TNonComputational>("TNonComputational", true)
     .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
     .set_attr<FTVMCompute>("FTVMCompute",
-                           [](const Attrs& attrs, const Array<Tensor>& inputs,
-                              const Type& out_dtype, const Target& target) -> Array<Tensor> {
+                           [](const Attrs& attrs, const Array<top::Tensor>& inputs,
+                              const Type& out_dtype, const Target& target) -> Array<top::Tensor> {
                              return {topi::identity(inputs[0])};
                            });
 
@@ -339,8 +339,8 @@ RELAY_REGISTER_OP("memory.shape_func")
     .set_attr<TNonComputational>("TNonComputational", true)
     .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
     .set_attr<FTVMCompute>("FTVMCompute",
-                           [](const Attrs& attrs, const Array<Tensor>& inputs,
-                              const Type& out_dtype, const Target& target) -> Array<Tensor> {
+                           [](const Attrs& attrs, const Array<top::Tensor>& inputs,
+                              const Type& out_dtype, const Target& target) -> Array<top::Tensor> {
                              return {topi::identity(inputs[0])};
                            });
 
index f1d711146a9959f41652ed6c10ace8c4e45735b5..9fca22d777d9aef93f67654a719ae60403c01bd5 100644 (file)
@@ -93,10 +93,10 @@ RELAY_REGISTER_OP("nn.bias_add")
 .add_argument("bias", "1D Tensor", "Bias.")
 .set_support_level(1)
 .add_type_rel("BiasAdd", BiasAddRel)
-.set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs, const Array<Tensor>& inputs,
+.set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs, const Array<top::Tensor>& inputs,
                                         const Type& out_type, const Target& target) {
     const auto* param = attrs.as<BiasAddAttrs>();
-    return tvm::Array<tvm::Tensor>{topi::nn::bias_add(inputs[0], inputs[1], param->axis)};
+    return tvm::Array<tvm::top::Tensor>{topi::nn::bias_add(inputs[0], inputs[1], param->axis)};
 });
 
 
@@ -233,11 +233,11 @@ RELAY_REGISTER_OP("nn.leaky_relu")
 .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
 .set_attr<FTVMCompute>(
   "FTVMCompute", [](const Attrs& attrs,
-                    const Array<Tensor>& inputs,
+                    const Array<top::Tensor>& inputs,
                     const Type& out_type,
                     const Target& target) {
     const auto* param = attrs.as<LeakyReluAttrs>();
-    return Array<Tensor>{ topi::leaky_relu(inputs[0], param->alpha) };
+    return Array<top::Tensor>{ topi::leaky_relu(inputs[0], param->alpha) };
 });
 
 
@@ -314,11 +314,11 @@ where :math:`*` is an channelwise multiplication for each sample in the batch.
 .set_attr<FInferCorrectLayout>("FInferCorrectLayout", PReluInferCorrectLayout<PReluAttrs>)
 .set_attr<FTVMCompute>(
   "FTVMCompute", [](const Attrs& attrs,
-                    const Array<Tensor>& inputs,
+                    const Array<top::Tensor>& inputs,
                     const Type& out_type,
                     const Target& target) {
     const auto* param = attrs.as<PReluAttrs>();
-    return Array<Tensor>{ topi::prelu(inputs[0], inputs[1], param->axis)};
+    return Array<top::Tensor>{ topi::prelu(inputs[0], inputs[1], param->axis)};
 });
 
 
@@ -350,12 +350,12 @@ RELAY_REGISTER_OP("nn.softmax")
 .set_support_level(1)
 .add_type_rel("Identity", IdentityRel)
 .set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs,
-                                         const Array<Tensor>& inputs,
+                                         const Array<top::Tensor>& inputs,
                                          const Type& out_type,
                                          const Target& target) {
   const auto* param = attrs.as<SoftmaxAttrs>();
   CHECK(param != nullptr);
-  return Array<Tensor>{ topi::nn::softmax(inputs[0], param->axis) };
+  return Array<top::Tensor>{ topi::nn::softmax(inputs[0], param->axis) };
 });
 
 
@@ -384,14 +384,14 @@ RELAY_REGISTER_OP("nn.log_softmax")
 .set_support_level(1)
 .add_type_rel("Identity", IdentityRel)
 .set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs,
-                                         const Array<Tensor>& inputs,
+                                         const Array<top::Tensor>& inputs,
                                          const Type& out_type,
                                          const Target& target) {
   const auto* param = attrs.as<SoftmaxAttrs>();
   CHECK(param != nullptr);
   CHECK(param->axis == -1 || param->axis == static_cast<int32_t>(inputs[0].ndim()) - 1)
       << "log_softmax currently only works on last dimension";
-  return Array<Tensor>{ topi::nn::log_softmax(inputs[0]) };
+  return Array<top::Tensor>{ topi::nn::log_softmax(inputs[0]) };
 });
 
 
@@ -461,10 +461,10 @@ Example::
 .add_type_rel("BatchFlatten", BatchFlattenRel)
 .set_attr<FTVMCompute>(
   "FTVMCompute", [](const Attrs& attrs,
-                    const Array<Tensor>& inputs,
+                    const Array<top::Tensor>& inputs,
                     const Type& out_type,
                     const Target& target) {
-    return Array<Tensor>{ topi::nn::flatten(inputs[0]) };
+    return Array<top::Tensor>{ topi::nn::flatten(inputs[0]) };
 });
 
 
@@ -488,10 +488,10 @@ RELAY_REGISTER_OP("nn.relu")
 .add_type_rel("Identity", IdentityRel)
 .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
 .set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs,
-                                         const Array<Tensor>& inputs,
+                                         const Array<top::Tensor>& inputs,
                                          const Type& out_type,
                                          const Target& target) {
-  return Array<Tensor>{ topi::relu(inputs[0], 0.0f) };
+  return Array<top::Tensor>{ topi::relu(inputs[0], 0.0f) };
 });
 
 
index 72ea70f82eb8a8447e10c75450fbedf5b104a3d3..b67f93928a88627851cacb9957cc9e3717c5897c 100644 (file)
@@ -160,8 +160,8 @@ bool PadRel(const Array<Type>& types,
   return true;
 }
 
-Array<Tensor> PadCompute(const Attrs& attrs,
-                         const Array<Tensor>& inputs,
+Array<top::Tensor> PadCompute(const Attrs& attrs,
+                         const Array<top::Tensor>& inputs,
                          const Type& out_type,
                          const Target& target) {
   const auto* param = attrs.as<PadAttrs>();
@@ -180,7 +180,7 @@ Array<Tensor> PadCompute(const Attrs& attrs,
     pad_after.push_back(pad_width[i][1]);
   }
   const auto* out_ttype = out_type.as<TensorTypeNode>();
-  return Array<Tensor>{ topi::pad(inputs[0], pad_before, pad_after,
+  return Array<top::Tensor>{ topi::pad(inputs[0], pad_before, pad_after,
                                   tvm::make_const(out_ttype->dtype, param->pad_value),
                                   "T_pad",
                                   topi::kElementWise,
index 660522590bbf2222578005e983d7078d74c163d5..65fd09d93d4dcaeefc1a7018452c076269201b09 100644 (file)
@@ -166,8 +166,8 @@ bool Pool2DRel(const Array<Type>& types,
 }
 
 template<typename AttrType, topi::nn::PoolType mode>
-Array<Tensor> Pool2DCompute(const Attrs& attrs,
-                            const Array<Tensor>& inputs,
+Array<top::Tensor> Pool2DCompute(const Attrs& attrs,
+                            const Array<top::Tensor>& inputs,
                             const Type& out_type,
                             const Target& target) {
   static const Layout kNCHW("NCHW");
@@ -203,11 +203,11 @@ Array<Tensor> Pool2DCompute(const Attrs& attrs,
   }
   if (mode == topi::nn::kAvgPool) {
     bool count_include_pad = reinterpret_cast<const AvgPool2DAttrs*>(param)->count_include_pad;
-    return Array<Tensor>{
+    return Array<top::Tensor>{
       topi::nn::pool(inputs[0], pool_size, strides, padding,
                      mode, ceil_mode, layout.name(), count_include_pad)};
   } else {
-    return Array<Tensor>{
+    return Array<top::Tensor>{
       topi::nn::pool(inputs[0], pool_size, strides, padding,
                      mode, ceil_mode, layout.name())};
   }
@@ -333,8 +333,8 @@ bool GlobalPool2DRel(const Array<Type>& types,
 
 
 template<topi::nn::PoolType mode>
-Array<Tensor> GlobalPool2DCompute(const Attrs& attrs,
-                                  const Array<Tensor>& inputs,
+Array<top::Tensor> GlobalPool2DCompute(const Attrs& attrs,
+                                  const Array<top::Tensor>& inputs,
                                   const Type& out_type,
                                   const Target& target) {
   static const Layout kNCHW("NCHW");
@@ -351,7 +351,7 @@ Array<Tensor> GlobalPool2DCompute(const Attrs& attrs,
   CHECK(inputs[0].ndim() == 4U || inputs[0].ndim() == 5U)
     << "Pool2D only support 4-D input (e.g., NCHW)"
     << " or 5-D input (last dimension is a split of channel)";
-  return Array<Tensor>{
+  return Array<top::Tensor>{
     topi::nn::global_pool(inputs[0], mode, layout.name()) };
 }
 
@@ -467,8 +467,8 @@ bool AdaptivePool2DRel(const Array<Type>& types,
 }
 
 template<topi::nn::PoolType mode>
-Array<Tensor> AdaptivePool2DCompute(const Attrs& attrs,
-                                    const Array<Tensor>& inputs,
+Array<top::Tensor> AdaptivePool2DCompute(const Attrs& attrs,
+                                    const Array<top::Tensor>& inputs,
                                     const Type& out_type,
                                     const Target& target) {
   static const Layout kNCHW("NCHW");
@@ -500,7 +500,7 @@ Array<Tensor> AdaptivePool2DCompute(const Attrs& attrs,
     output_height = output_size[0];
     output_width = output_size[1];
   }
-  return Array<Tensor>{
+  return Array<top::Tensor>{
     topi::nn::adaptive_pool(inputs[0], Array<IndexExpr>{ output_height, output_width },
                             mode, layout.name()) };
 }
@@ -596,7 +596,7 @@ bool Pool2DGradRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
 }
 
 template <typename AttrType, topi::nn::PoolType mode>
-Array<Tensor> Pool2DGradCompute(const Attrs& attrs, const Array<Tensor>& inputs,
+Array<top::Tensor> Pool2DGradCompute(const Attrs& attrs, const Array<top::Tensor>& inputs,
                                 const Type& out_type, const Target& target) {
   static const Layout kNCHW("NCHW");
   const auto* param = attrs.as<AttrType>();
@@ -633,10 +633,10 @@ Array<Tensor> Pool2DGradCompute(const Attrs& attrs, const Array<Tensor>& inputs,
   }
   if (mode == topi::nn::kAvgPool) {
     bool count_include_pad = reinterpret_cast<const AvgPool2DAttrs*>(param)->count_include_pad;
-    return Array<Tensor>{topi::nn::pool_grad(inputs[0], inputs[1], pool_size, strides, padding,
+    return Array<top::Tensor>{topi::nn::pool_grad(inputs[0], inputs[1], pool_size, strides, padding,
         mode, ceil_mode, layout.name(), count_include_pad)};
   } else {
-    return Array<Tensor>{topi::nn::pool_grad(inputs[0], inputs[1], pool_size, strides, padding,
+    return Array<top::Tensor>{topi::nn::pool_grad(inputs[0], inputs[1], pool_size, strides, padding,
         mode, ceil_mode, layout.name())};
   }
 }
@@ -798,8 +798,8 @@ bool Pool1DRel(const Array<Type>& types,
 
 
 template<typename AttrType, topi::nn::PoolType mode>
-Array<Tensor> Pool1DCompute(const Attrs& attrs,
-                            const Array<Tensor>& inputs,
+Array<top::Tensor> Pool1DCompute(const Attrs& attrs,
+                            const Array<top::Tensor>& inputs,
                             const Type& out_type,
                             const Target& target) {
   static const Layout kNCW("NCW");
@@ -825,11 +825,11 @@ Array<Tensor> Pool1DCompute(const Attrs& attrs,
 
   if (mode == topi::nn::kAvgPool) {
     bool count_include_pad = reinterpret_cast<const AvgPool1DAttrs*>(param)->count_include_pad;
-    return Array<Tensor>{
+    return Array<top::Tensor>{
       topi::nn::pool1d(inputs[0], pool_size, strides, padding,
                        mode, ceil_mode, layout.name(), count_include_pad)};
   } else {
-    return Array<Tensor>{
+    return Array<top::Tensor>{
       topi::nn::pool1d(inputs[0], pool_size, strides, padding,
                        mode, ceil_mode, layout.name())};
   }
@@ -993,8 +993,8 @@ bool Pool3DRel(const Array<Type>& types,
 
 
 template<typename AttrType, topi::nn::PoolType mode>
-Array<Tensor> Pool3DCompute(const Attrs& attrs,
-                            const Array<Tensor>& inputs,
+Array<top::Tensor> Pool3DCompute(const Attrs& attrs,
+                            const Array<top::Tensor>& inputs,
                             const Type& out_type,
                             const Target& target) {
   static const Layout kNCDHW("NCDHW");
@@ -1033,11 +1033,11 @@ Array<Tensor> Pool3DCompute(const Attrs& attrs,
   }
   if (mode == topi::nn::kAvgPool) {
     bool count_include_pad = reinterpret_cast<const AvgPool3DAttrs*>(param)->count_include_pad;
-    return Array<Tensor>{
+    return Array<top::Tensor>{
       topi::nn::pool3d(inputs[0], pool_size, strides, padding,
                        mode, ceil_mode, layout.name(), count_include_pad)};
   } else {
-    return Array<Tensor>{
+    return Array<top::Tensor>{
       topi::nn::pool3d(inputs[0], pool_size, strides, padding,
                        mode, ceil_mode, layout.name())};
   }
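The pooling helpers are templated over the attrs type and the topi pool mode, so a single body serves both max and average pooling. A hedged sketch of how one is wired in at registration time (this pattern lives in the parts of the file not shown in this diff):

    // Instantiate the templated compute for max pooling.
    RELAY_REGISTER_OP("nn.max_pool2d")
    .set_attr<FTVMCompute>("FTVMCompute",
                           Pool2DCompute<MaxPool2DAttrs, topi::nn::kMaxPool>);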
index f2f899657bc0b6150023c3b030a53ab81a2e76bf..00ebddb7e79c1c402093ae4f0e52e2b2fd0f31b6 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -32,9 +32,9 @@ namespace relay {
 
 #define RELAY_BINARY_COMPUTE(FTOPI)                        \
   [] (const Attrs& attrs,                                  \
-      const Array<Tensor>& inputs,                         \
+      const Array<top::Tensor>& inputs,                         \
       const Type& out_type,                                \
-      const Target& target) -> Array<Tensor> {             \
+      const Target& target) -> Array<top::Tensor> {             \
     CHECK_EQ(inputs.size(), 2U);                           \
     return {FTOPI(inputs[0], inputs[1])};                  \
   }                                                        \
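Since RELAY_BINARY_COMPUTE only stamps out a lambda, its expansion for a concrete topi op is mechanical; RELAY_BINARY_COMPUTE(topi::add), for instance, denotes (sketch):

    [](const Attrs& attrs,
       const Array<top::Tensor>& inputs,
       const Type& out_type,
       const Target& target) -> Array<top::Tensor> {
      CHECK_EQ(inputs.size(), 2U);
      return {topi::add(inputs[0], inputs[1])};
    }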
index dde3ef2f0420e98d853e85b8852d2a3421813c34..e20b7cfc52c1a950367cbf113f3aca0dadb0b7f6 100644 (file)
@@ -173,8 +173,8 @@ Array<Array<Layout>> ReduceInferCorrectLayout(const Attrs& attrs,
 }
 
 template<typename F>
-Array<Tensor> ReduceCompute(const Attrs& attrs,
-                            const Array<Tensor>& inputs,
+Array<top::Tensor> ReduceCompute(const Attrs& attrs,
+                            const Array<top::Tensor>& inputs,
                             const Type& out_type,
                             const Target& target,
                             F f) {
@@ -320,8 +320,8 @@ bool ReduceRel(const Array<Type>& types,
   .add_argument("data", "Tensor", "The input tensor.")
 
 
-Array<Tensor> ArgMaxCompute(const Attrs& attrs,
-                            const Array<Tensor>& inputs,
+Array<top::Tensor> ArgMaxCompute(const Attrs& attrs,
+                            const Array<top::Tensor>& inputs,
                             const Type& out_type,
                             const Target& target) {
   return ReduceCompute(attrs, inputs, out_type, target, topi::argmax);
@@ -340,8 +340,8 @@ values over a given axis.
 .set_attr<TOpPattern>("TOpPattern", kCommReduce);
 
 
-Array<Tensor> ArgMinCompute(const Attrs& attrs,
-                            const Array<Tensor>& inputs,
+Array<top::Tensor> ArgMinCompute(const Attrs& attrs,
+                            const Array<top::Tensor>& inputs,
                             const Type& out_type,
                             const Target& target) {
   return ReduceCompute(attrs, inputs, out_type, target, topi::argmin);
@@ -358,8 +358,8 @@ values over a given axis.
 .set_attr<FTVMCompute>("FTVMCompute", ArgMinCompute)
 .set_attr<TOpPattern>("TOpPattern", kCommReduce);
 
-Array<Tensor> SumCompute(const Attrs& attrs,
-                         const Array<Tensor>& inputs,
+Array<top::Tensor> SumCompute(const Attrs& attrs,
+                         const Array<top::Tensor>& inputs,
                          const Type& out_type,
                          const Target& target) {
   return ReduceCompute(attrs, inputs, out_type, target, topi::sum);
@@ -392,8 +392,8 @@ Example::
 .set_attr<TOpPattern>("TOpPattern", kCommReduce);
 
 
-Array<Tensor> AllCompute(const Attrs& attrs,
-                         const Array<Tensor>& inputs,
+Array<top::Tensor> AllCompute(const Attrs& attrs,
+                         const Array<top::Tensor>& inputs,
                          const Type& out_type,
                          const Target& target) {
   return ReduceCompute(attrs, inputs, out_type, target, topi::all);
@@ -429,8 +429,8 @@ Example::
 .set_attr<TOpPattern>("TOpPattern", kCommReduce);
 
 
-Array<Tensor> AnyCompute(const Attrs& attrs,
-                         const Array<Tensor>& inputs,
+Array<top::Tensor> AnyCompute(const Attrs& attrs,
+                         const Array<top::Tensor>& inputs,
                          const Type& out_type,
                          const Target& target) {
   return ReduceCompute(attrs, inputs, out_type, target, topi::any);
@@ -466,8 +466,8 @@ Example::
 .set_attr<TOpPattern>("TOpPattern", kCommReduce);
 
 
-Array<Tensor> MaxCompute(const Attrs& attrs,
-                         const Array<Tensor>& inputs,
+Array<top::Tensor> MaxCompute(const Attrs& attrs,
+                         const Array<top::Tensor>& inputs,
                          const Type& out_type,
                          const Target& target) {
   return ReduceCompute(attrs, inputs, out_type, target, topi::max);
@@ -484,8 +484,8 @@ RELAY_REGISTER_REDUCE_OP("max")
 .set_attr<TOpPattern>("TOpPattern", kCommReduce);
 
 
-Array<Tensor> MinCompute(const Attrs& attrs,
-                         const Array<Tensor>& inputs,
+Array<top::Tensor> MinCompute(const Attrs& attrs,
+                         const Array<top::Tensor>& inputs,
                          const Type& out_type,
                          const Target& target) {
   return ReduceCompute(attrs, inputs, out_type, target, topi::min);
@@ -503,8 +503,8 @@ RELAY_REGISTER_REDUCE_OP("min")
 .set_attr<TOpPattern>("TOpPattern", kCommReduce);
 
 
-Array<Tensor> ProdCompute(const Attrs& attrs,
-                          const Array<Tensor>& inputs,
+Array<top::Tensor> ProdCompute(const Attrs& attrs,
+                          const Array<top::Tensor>& inputs,
                           const Type& out_type,
                           const Target& target) {
   return ReduceCompute(attrs, inputs, out_type, target, topi::prod);
@@ -533,8 +533,8 @@ Example::
 .set_attr<TOpPattern>("TOpPattern", kCommReduce);
 
 
-Array<Tensor> MeanCompute(const Attrs& attrs,
-                          const Array<Tensor>& inputs,
+Array<top::Tensor> MeanCompute(const Attrs& attrs,
+                          const Array<top::Tensor>& inputs,
                           const Type& out_type,
                           const Target& target) {
   IndexExpr count = make_const(inputs[0]->dtype, 1);
@@ -598,8 +598,8 @@ bool VarianceRel(const Array<Type>& types,
   return true;
 }
 
-Array<Tensor> VarianceCompute(const Attrs& attrs,
-                              const Array<Tensor>& inputs,
+Array<top::Tensor> VarianceCompute(const Attrs& attrs,
+                              const Array<top::Tensor>& inputs,
                               const Type& out_type,
                               const Target& target) {
   IndexExpr count = make_const(inputs[0]->dtype, 1);
index 59bb432e756d7bafb0fa903f1b725a8826641ef4..c2af56d68d46c79e29fb9aa25c7415872d980b06 100644 (file)
@@ -65,8 +65,8 @@ bool CastRel(const Array<Type>& types,
   return true;
 }
 
-Array<Tensor> CastCompute(const Attrs& attrs,
-                          const Array<Tensor>& inputs,
+Array<top::Tensor> CastCompute(const Attrs& attrs,
+                          const Array<top::Tensor>& inputs,
                           const Type& out_type,
                           const Target& target) {
   const CastAttrs *param = attrs.as<CastAttrs>();
@@ -125,8 +125,8 @@ bool CastLikeRel(const Array<Type>& types,
 }
 
 
-Array<Tensor> CastLikeCompute(const Attrs& attrs,
-                              const Array<Tensor>& inputs,
+Array<top::Tensor> CastLikeCompute(const Attrs& attrs,
+                              const Array<top::Tensor>& inputs,
                               const Type& out_type,
                               const Target& target) {
   return { topi::cast(inputs[0], inputs[1]->dtype) };
@@ -156,7 +156,7 @@ RELAY_REGISTER_OP("cast_like")
 .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout);
 
 
-Array<Tensor> ReinterpretCompute(const Attrs& attrs, const Array<Tensor>& inputs,
+Array<top::Tensor> ReinterpretCompute(const Attrs& attrs, const Array<top::Tensor>& inputs,
                                  const Type& out_type, const Target& target) {
   const CastAttrs* param = attrs.as<CastAttrs>();
   CHECK(param != nullptr);
@@ -230,8 +230,8 @@ bool ExpandDimsRel(const Array<Type>& types,
   return true;
 }
 
-Array<Tensor> ExpandDimsCompute(const Attrs& attrs,
-                                const Array<Tensor>& inputs,
+Array<top::Tensor> ExpandDimsCompute(const Attrs& attrs,
+                                const Array<top::Tensor>& inputs,
                                 const Type& out_type,
                                 const Target& target) {
   const ExpandDimsAttrs *param = attrs.as<ExpandDimsAttrs>();
@@ -269,8 +269,8 @@ RELAY_REGISTER_OP("expand_dims")
 // relay.concatenate
 TVM_REGISTER_NODE_TYPE(ConcatenateAttrs);
 
-Array<Tensor> ConcatenateCompute(const Attrs& attrs,
-                          const Array<Tensor>& inputs,
+Array<top::Tensor> ConcatenateCompute(const Attrs& attrs,
+                          const Array<top::Tensor>& inputs,
                           const Type& out_type,
                           const Target& target) {
   const ConcatenateAttrs *param = attrs.as<ConcatenateAttrs>();
@@ -412,8 +412,8 @@ bool StackRel(const Array<Type>& types,
   return true;
 }
 
-Array<Tensor> StackCompute(const Attrs& attrs,
-                           const Array<Tensor>& inputs,
+Array<top::Tensor> StackCompute(const Attrs& attrs,
+                           const Array<top::Tensor>& inputs,
                            const Type& out_type,
                            const Target& target) {
   const StackAttrs *param = attrs.as<StackAttrs>();
@@ -504,13 +504,13 @@ bool TransposeRel(const Array<Type>& types,
   return true;
 }
 
-Array<Tensor> TransposeCompute(const Attrs& attrs,
-                               const Array<Tensor>& inputs,
+Array<top::Tensor> TransposeCompute(const Attrs& attrs,
+                               const Array<top::Tensor>& inputs,
                                const Type& out_type,
                                const Target& target) {
   const auto* param = attrs.as<TransposeAttrs>();
   CHECK(param != nullptr);
-  return Array<Tensor>{ topi::transpose(inputs[0], param->axes) };
+  return Array<top::Tensor>{ topi::transpose(inputs[0], param->axes) };
 }
 
 Expr MakeTranspose(Expr data,
@@ -687,8 +687,8 @@ bool ReshapeRel(const Array<Type>& types,
   return true;
 }
 
-Array<Tensor> ReshapeCompute(const Attrs& attrs,
-                             const Array<Tensor>& inputs,
+Array<top::Tensor> ReshapeCompute(const Attrs& attrs,
+                             const Array<top::Tensor>& inputs,
                              const Type& out_type,
                              const Target& target) {
   const auto* out_ttype = out_type.as<TensorTypeNode>();
@@ -922,16 +922,16 @@ bool TakeRel(const Array<Type>& types,
   return true;
 }
 
-Array<Tensor> TakeCompute(const Attrs& attrs,
-                          const Array<Tensor>& inputs,
+Array<top::Tensor> TakeCompute(const Attrs& attrs,
+                          const Array<top::Tensor>& inputs,
                           const Type& out_type,
                           const Target& target) {
   const auto* param = attrs.as<TakeAttrs>();
   CHECK(param != nullptr);
   if (!param->axis.defined()) {
-    return Array<Tensor>{ topi::take(inputs[0], inputs[1], param->mode) };
+    return Array<top::Tensor>{ topi::take(inputs[0], inputs[1], param->mode) };
   } else {
-    return Array<Tensor>{ topi::take(inputs[0], inputs[1], param->axis, param->mode) };
+    return Array<top::Tensor>{ topi::take(inputs[0], inputs[1], param->axis, param->mode) };
   }
 }
 
@@ -1009,8 +1009,8 @@ bool FullRel(const Array<Type>& types,
   return true;
 }
 
-Array<Tensor> FullCompute(const Attrs& attrs,
-                          const Array<Tensor>& inputs,
+Array<top::Tensor> FullCompute(const Attrs& attrs,
+                          const Array<top::Tensor>& inputs,
                           const Type& out_type,
                           const Target& target) {
   const auto* out_ttype = out_type.as<TensorTypeNode>();
@@ -1117,8 +1117,8 @@ bool FullLikeRel(const Array<Type>& types,
   return true;
 }
 
-Array<Tensor> FullLikeCompute(const Attrs& attrs,
-                              const Array<Tensor>& inputs,
+Array<top::Tensor> FullLikeCompute(const Attrs& attrs,
+                              const Array<top::Tensor>& inputs,
                               const Type& out_type,
                               const Target& target) {
   return { topi::full_like(inputs[0], inputs[1]()) };
@@ -1217,26 +1217,26 @@ bool ArangeRel(const Array<Type>& types,
   }
 }
 
-inline Tensor DynamicArange(const tvm::Tensor& start,
-                            const tvm::Tensor& stop,
-                            const tvm::Tensor& step,
-                            tvm::DataType dtype,
-                            std::string name = "tensor",
-                            std::string tag = topi::kInjective) {
+inline top::Tensor DynamicArange(const top::Tensor& start,
+                                 const top::Tensor& stop,
+                                 const top::Tensor& step,
+                                 tvm::DataType dtype,
+                                 std::string name = "tensor",
+                                 std::string tag = topi::kInjective) {
   tvm::PrimExpr num_elem = tvm::Var("num_elem");
-  return tvm::compute({num_elem}, [&](const Array<tvm::Var>& indices) {
+  return top::compute({num_elem}, [&](const Array<tvm::Var>& indices) {
     return tvm::cast(dtype, start[0] + step[0] * indices[0]);
   }, name, tag);
 }
 
-Array<Tensor> ArangeCompute(const Attrs& attrs,
-                            const Array<Tensor>& inputs,
-                            const Type& out_type,
-                            const Target& target) {
+Array<top::Tensor> ArangeCompute(const Attrs& attrs,
+                                 const Array<top::Tensor>& inputs,
+                                 const Type& out_type,
+                                 const Target& target) {
   const ArangeAttrs* param = attrs.as<ArangeAttrs>();
-  Tensor start = inputs[0];
-  Tensor stop =  inputs[1];
-  Tensor step = inputs[2];
+  top::Tensor start = inputs[0];
+  top::Tensor stop =  inputs[1];
+  top::Tensor step = inputs[2];
   Array<tvm::PrimExpr> empty = {0};
   return { DynamicArange(start, stop, step, param->dtype) };
 }
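
DynamicArange is the one spot in this file where a free function moves as well: tvm::compute becomes top::compute. A standalone sketch of the new call, with an illustrative body:

    #include <tvm/top/operation.h>

    // Sketch: a 1-D iota tensor via top::compute. The lambda receives one
    // index Var per output dimension, exactly as tvm::compute did before.
    tvm::top::Tensor MakeIota(tvm::PrimExpr n) {
      return tvm::top::compute({n}, [&](const tvm::Array<tvm::Var>& i) {
        return tvm::cast(tvm::DataType::Int(32), i[0]);
      }, "iota");
    }
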
@@ -1324,8 +1324,8 @@ bool RepeatRel(const Array<Type>& types,
   return true;
 }
 
-Array<Tensor> RepeatCompute(const Attrs& attrs,
-                            const Array<Tensor>& inputs,
+Array<top::Tensor> RepeatCompute(const Attrs& attrs,
+                            const Array<top::Tensor>& inputs,
                             const Type& out_type,
                             const Target& target) {
   const RepeatAttrs *param = attrs.as<RepeatAttrs>();
@@ -1435,8 +1435,8 @@ bool TileRel(const Array<Type>& types,
   return true;
 }
 
-Array<Tensor> TileCompute(const Attrs& attrs,
-                          const Array<Tensor>& inputs,
+Array<top::Tensor> TileCompute(const Attrs& attrs,
+                          const Array<top::Tensor>& inputs,
                           const Type& out_type,
                           const Target& target) {
   const TileAttrs *param = attrs.as<TileAttrs>();
@@ -1496,8 +1496,8 @@ bool ReverseRel(const Array<Type>& types,
   return true;
 }
 
-Array<Tensor> ReverseCompute(const Attrs& attrs,
-                             const Array<Tensor>& inputs,
+Array<top::Tensor> ReverseCompute(const Attrs& attrs,
+                             const Array<top::Tensor>& inputs,
                              const Type& out_type,
                              const Target& target) {
   const ReverseAttrs *param = attrs.as<ReverseAttrs>();
@@ -1570,8 +1570,8 @@ Expr MakeWhere(const Expr& condition, const Expr& x, const Expr& y) {
   return CallNode::make(op, {condition, x, y});
 }
 
-Array<Tensor> WhereCompute(const Attrs& attrs,
-                           const Array<Tensor>& inputs,
+Array<top::Tensor> WhereCompute(const Attrs& attrs,
+                           const Array<top::Tensor>& inputs,
                            const Type& out_type,
                            const Target& target) {
   return { topi::where(inputs[0], inputs[1], inputs[2]) };
@@ -1687,8 +1687,8 @@ bool SqueezeRel(const Array<Type>& types,
   return true;
 }
 
-Array<Tensor> SqueezeCompute(const Attrs& attrs,
-                             const Array<Tensor>& inputs,
+Array<top::Tensor> SqueezeCompute(const Attrs& attrs,
+                             const Array<top::Tensor>& inputs,
                              const Type& out_type,
                              const Target& target) {
   const SqueezeAttrs *param = attrs.as<SqueezeAttrs>();
@@ -1728,8 +1728,8 @@ Expr MakeCollapseSumLike(Expr data,
   return CallNode::make(op, {data, collapse_type}, Attrs(), {});
 }
 
-Array<Tensor> CollapseSumLikeCompute(const Attrs& attrs,
-                                     const Array<Tensor>& inputs,
+Array<top::Tensor> CollapseSumLikeCompute(const Attrs& attrs,
+                                     const Array<top::Tensor>& inputs,
                                      const Type& out_type,
                                      const Target& target) {
   const auto* out_ttype = out_type.as<TensorTypeNode>();
@@ -1773,8 +1773,8 @@ Expr MakeBroadCastTo(Expr data, Array<IndexExpr> shape) {
   return CallNode::make(op, {data}, Attrs(attrs), {});
 }
 
-Array<Tensor> BroadCastToCompute(const Attrs& attrs,
-                                 const Array<Tensor>& inputs,
+Array<top::Tensor> BroadCastToCompute(const Attrs& attrs,
+                                 const Array<top::Tensor>& inputs,
                                  const Type& out_type,
                                  const Target& target) {
   auto ioattrs = attrs.as<InitOpAttrs>();
@@ -1811,8 +1811,8 @@ Expr MakeBroadCastToLike(Expr data,
   return CallNode::make(op, {data, broadcast_type}, Attrs(), {});
 }
 
-Array<Tensor> BroadCastToLikeCompute(const Attrs& attrs,
-                                     const Array<Tensor>& inputs,
+Array<top::Tensor> BroadCastToLikeCompute(const Attrs& attrs,
+                                     const Array<top::Tensor>& inputs,
                                      const Type& out_type,
                                      const Target& target) {
   const auto* out_ttype = out_type.as<TensorTypeNode>();
@@ -2018,13 +2018,13 @@ Expr MakeStridedSlice(Expr data,
   return CallNode::make(op, {data}, Attrs(attrs), {});
 }
 
-Array<Tensor> StridedSliceCompute(const Attrs& attrs,
-                                  const Array<Tensor>& inputs,
+Array<top::Tensor> StridedSliceCompute(const Attrs& attrs,
+                                  const Array<top::Tensor>& inputs,
                                   const Type& out_type,
                                   const Target& target) {
   const StridedSliceAttrs *param = attrs.as<StridedSliceAttrs>();
   CHECK(param != nullptr);
-  return Array<Tensor>{
+  return Array<top::Tensor>{
     topi::strided_slice(inputs[0], param->begin, param->end, param->strides)
   };
 }
@@ -2175,8 +2175,8 @@ bool SplitRel(const Array<Type>& types,
   return true;
 }
 
-Array<Tensor> SplitCompute(const Attrs& attrs,
-                           const Array<Tensor>& inputs,
+Array<top::Tensor> SplitCompute(const Attrs& attrs,
+                           const Array<top::Tensor>& inputs,
                            const Type& out_type,
                            const Target& target) {
   const auto param = attrs.as<SplitAttrs>();
@@ -2184,11 +2184,11 @@ Array<Tensor> SplitCompute(const Attrs& attrs,
 
   if (const IntImmNode* sections = param->indices_or_sections.as<IntImmNode>()) {
     int64_t num_sections = sections->value;
-    return Array<Tensor>{
+    return Array<top::Tensor>{
       topi::split_sections(inputs[0], num_sections, param->axis) };
   } else {
     auto indices = Downcast<Array<Integer> >(param->indices_or_sections);
-    return Array<Tensor>{ topi::split(inputs[0], indices, param->axis) };
+    return Array<top::Tensor>{ topi::split(inputs[0], indices, param->axis) };
   }
 }
 
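
The two branches above mirror the two topi entry points: an integer indices_or_sections selects topi::split_sections, an explicit index array selects topi::split. Illustratively (shapes assumed):

    // split(x[6, 4], indices_or_sections=3, axis=0)
    //   -> topi::split_sections(x, 3, 0)   // three [2, 4] pieces
    // split(x[6, 4], indices_or_sections={2, 5}, axis=0)
    //   -> topi::split(x, {2, 5}, 0)       // rows [0,2), [2,5), [5,6)
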
@@ -2304,8 +2304,8 @@ Expr MakeSliceLike(Expr data,
   return CallNode::make(op, {data, shape_like}, Attrs(attrs), {});
 }
 
-Array<Tensor> SliceLikeCompute(const Attrs& attrs,
-                               const Array<Tensor>& inputs,
+Array<top::Tensor> SliceLikeCompute(const Attrs& attrs,
+                               const Array<top::Tensor>& inputs,
                                const Type& out_type,
                                const Target& target) {
   const auto* param = attrs.as<SliceLikeAttrs>();
@@ -2342,7 +2342,7 @@ Array<Tensor> SliceLikeCompute(const Attrs& attrs,
         << topi::GetConstInt(src_shape[axis]);
     }
   }
-  return Array<Tensor>{
+  return Array<top::Tensor>{
     topi::strided_slice(inputs[0],
                         GetIntArray(begin_idx),
                         GetIntArray(end_idx),
@@ -2370,13 +2370,13 @@ RELAY_REGISTER_OP("slice_like")
 // relay.layout_transform
 TVM_REGISTER_NODE_TYPE(LayoutTransformAttrs);
 
-Array<Tensor> LayoutTransformCompute(const Attrs& attrs,
-                                     const Array<Tensor>& inputs,
+Array<top::Tensor> LayoutTransformCompute(const Attrs& attrs,
+                                     const Array<top::Tensor>& inputs,
                                      const Type& out_type,
                                      const Target& target) {
   const auto* param = attrs.as<LayoutTransformAttrs>();
   CHECK(param != nullptr);
-  return Array<Tensor>{
+  return Array<top::Tensor>{
     topi::layout_transform(inputs[0], param->src_layout, param->dst_layout)
   };
 }
@@ -2503,8 +2503,8 @@ bool GatherNDRel(const Array<Type>& types,
   return true;
 }
 
-Array<Tensor> GatherNDCompute(const Attrs& attrs,
-                              const Array<Tensor>& inputs,
+Array<top::Tensor> GatherNDCompute(const Attrs& attrs,
+                              const Array<top::Tensor>& inputs,
                               const Type& out_type,
                               const Target& target) {
   return { topi::gather_nd(inputs[0], inputs[1]) };
@@ -2557,13 +2557,14 @@ bool SequenceMaskRel(const Array<Type>& types,
   return true;
 }
 
-Array<Tensor> SequenceMaskCompute(const Attrs& attrs,
-                                  const Array<Tensor>& inputs,
+Array<top::Tensor> SequenceMaskCompute(const Attrs& attrs,
+                                  const Array<top::Tensor>& inputs,
                                   const Type& out_type,
                                   const Target& target) {
   const auto* param = attrs.as<SequenceMaskAttrs>();
   CHECK(param != nullptr);
-  return Array<Tensor>{ topi::sequence_mask(inputs[0], inputs[1], param->mask_value, param->axis) };
+  return Array<top::Tensor>{
+    topi::sequence_mask(inputs[0], inputs[1], param->mask_value, param->axis) };
 }
 
 Expr MakeSequenceMask(Expr data,
@@ -2669,13 +2670,13 @@ bool OneHotRel(const Array<Type>& types,
   return true;
 }
 
-Array<Tensor> OneHotCompute(const Attrs& attrs,
-                            const Array<Tensor>& inputs,
+Array<top::Tensor> OneHotCompute(const Attrs& attrs,
+                            const Array<top::Tensor>& inputs,
                             const Type& out_type,
                             const Target& target) {
   const auto* param = attrs.as<OneHotAttrs>();
   CHECK(param != nullptr);
-  return Array<Tensor> {
+  return Array<top::Tensor> {
     topi::one_hot(inputs[0],
                   inputs[1](),
                   inputs[2](),
index 5189f7d97ec33f43c495dcac1076a593389697ac..331653b494454499b1f47a3376f1673c501cc9d5 100644 (file)
@@ -34,9 +34,9 @@ namespace relay {
 
 #define RELAY_UNARY_COMPUTE(FTOPI)                      \
   [] (const Attrs& attrs,                               \
-      const Array<Tensor>& inputs,                      \
+      const Array<top::Tensor>& inputs,                 \
       const Type& out_type,                             \
-      const Target& target) -> Array<Tensor> {          \
+      const Target& target) -> Array<top::Tensor> {     \
     return {FTOPI(inputs[0])};                          \
   }                                                     \
 
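
With this change, every RELAY_UNARY_COMPUTE(...) use expands to an FTVMCompute-shaped lambda over top::Tensor. An illustrative expansion (not literal patch content):

    // RELAY_UNARY_COMPUTE(topi::negative) expands, roughly, to:
    [] (const Attrs& attrs,
        const Array<top::Tensor>& inputs,
        const Type& out_type,
        const Target& target) -> Array<top::Tensor> {
      return {topi::negative(inputs[0])};
    }
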
@@ -290,8 +290,8 @@ bool ShapeOfRel(const Array<Type>& types,
   return true;
 }
 
-Array<Tensor> ShapeOfCompute(const Attrs& attrs,
-                             const Array<Tensor>& inputs,
+Array<top::Tensor> ShapeOfCompute(const Attrs& attrs,
+                             const Array<top::Tensor>& inputs,
                              const Type& out_type,
                              const Target& target) {
   CHECK_EQ(inputs.size(), 1);
@@ -341,14 +341,14 @@ bool NdarraySizeRel(const Array<Type>& types,
   return true;
 }
 
-Array<Tensor> NdarraySizeCompute(const Attrs& attrs,
-                          const Array<Tensor>& inputs,
+Array<top::Tensor> NdarraySizeCompute(const Attrs& attrs,
+                          const Array<top::Tensor>& inputs,
                           const Type& out_type,
                           const Target& target) {
   CHECK_EQ(inputs.size(), 1);
   const auto* param = attrs.as<NdarraySizeAttrs>();
   CHECK(param != nullptr);
-  return Array<Tensor>{topi::ndarray_size(inputs[0], param->dtype)};
+  return Array<top::Tensor>{topi::ndarray_size(inputs[0], param->dtype)};
 }
 
 TVM_REGISTER_GLOBAL("relay.op.contrib._make.ndarray_size")
index 616dc2a5510b8877dd23a8deb2ab2211e93063f7..9964a8274392b1c6cdca1f26fec6d40947aa04c7 100644 (file)
@@ -82,12 +82,12 @@ Its function is mostly shape transform.")doc" TVM_ADD_FILELINE)
 .set_attrs_type<YoloReorgAttrs>()
 .add_type_rel("YoloReorg", YoloReorgRel)
 .set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs,
-                                         const Array<Tensor>& inputs,
+                                         const Array<top::Tensor>& inputs,
                                          const Type& out_type,
                                          const Target& target) {
   const auto* params = attrs.as<YoloReorgAttrs>();
   CHECK(params != nullptr);
-  return Array<Tensor>{ topi::vision::reorg(inputs[0], params->stride) };
+  return Array<top::Tensor>{ topi::vision::reorg(inputs[0], params->stride) };
 });
 
 }  // namespace relay
index b027e5edb571ba4c02f734a6f1213ab41d1c2c79..9746182099316c3f66fed2ed6a9835be26e10090 100644 (file)
@@ -28,7 +28,7 @@
 #include <tvm/relay/op_attr_types.h>
 #include <tvm/relay/attrs/transform.h>
 #include <tvm/relay/transform.h>
-#include <tvm/operation.h>
+#include <tvm/top/operation.h>
 #include <tuple>
 #include <vector>
 #include <functional>
@@ -78,10 +78,10 @@ class AlterTransformMemorizer : public TransformMemorizer {
     Expr new_e;
     bool modified = false;
     if (falter_layout.count(op)) {
-      tvm::Array<tvm::Tensor> tinfos;
+      tvm::Array<tvm::top::Tensor> tinfos;
       for (auto expr : ref_call->args) {
         auto ttype = expr->type_as<TensorTypeNode>();
-        tinfos.push_back(tvm::placeholder(ttype->shape, ttype->dtype));
+        tinfos.push_back(tvm::top::placeholder(ttype->shape, ttype->dtype));
       }
       Expr altered_value = falter_layout[op](ref_call->attrs, new_args, tinfos);
       if (altered_value.defined()) {
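
The layout-alteration hook builds dummy tensors only to carry shape and dtype into falter_layout; those placeholders now come from tvm::top::placeholder. A minimal sketch (shape and dtype illustrative):

    #include <tvm/top/operation.h>

    // Sketch: a type-only placeholder after the refactor. No computation is
    // attached; it merely carries shape/dtype metadata into the callback.
    tvm::top::Tensor MakeTypeCarrier() {
      return tvm::top::placeholder({1, 3, 224, 224}, tvm::DataType::Float(32));
    }
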
index 20007d289327d5eca65193fc64109f8bffe2bdee..d435efd249d3f44ad922b2d7f17ff77481850c98 100644 (file)
@@ -28,7 +28,7 @@
 #include <tvm/relay/op_attr_types.h>
 #include <tvm/relay/attrs/transform.h>
 #include <tvm/relay/transform.h>
-#include <tvm/operation.h>
+#include <tvm/top/operation.h>
 #include <tuple>
 #include <vector>
 #include <functional>
@@ -86,10 +86,10 @@ class ConvertTransformMemorizer : public TransformMemorizer {
     Expr new_e;
     bool modified = false;
     if (fconvert_layout.count(op)) {
-      tvm::Array<tvm::Tensor> tinfos;
+      tvm::Array<tvm::top::Tensor> tinfos;
       for (auto expr : ref_call->args) {
         auto ttype = expr->type_as<TensorTypeNode>();
-        tinfos.push_back(tvm::placeholder(ttype->shape, ttype->dtype));
+        tinfos.push_back(tvm::top::placeholder(ttype->shape, ttype->dtype));
       }
       Expr altered_value =
           fconvert_layout[op](ref_call->attrs, new_args, tinfos, operator->()->desired_layout_);
index e236de72be545e0bdd13126b139d7fc6d1f2a938..8d89c0a53d3ae65f89e706932281b1bc8da5709f 100644 (file)
@@ -23,7 +23,7 @@
  */
 
 #include <tvm/lowered_func.h>
-#include <tvm/operation.h>
+#include <tvm/top/operation.h>
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/analysis.h>
 #include <tvm/relay/transform.h>
index 63608084e3f8ff406cf12e32688713023075b21f..12e72cf1fad2160af174d4dcb1a4ebb6d3cdaf6d 100644 (file)
@@ -23,7 +23,7 @@
  * shape, dtype or layout to another op or a sequence of ops.
  */
 
-#include <tvm/operation.h>
+#include <tvm/top/operation.h>
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/op_attr_types.h>
 #include <tvm/relay/transform.h>
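
For code outside this repository, the migration implied by this patch is a header-path and namespace respelling; a hedged summary of the mapping (the schedule headers follow the same pattern as the operation header changed above):

    //   #include <tvm/operation.h>   ->  #include <tvm/top/operation.h>
    //   #include <tvm/schedule.h>    ->  #include <tvm/top/schedule.h>
    //   tvm::Tensor                  ->  tvm::top::Tensor
    //   tvm::compute / tvm::placeholder  ->  tvm::top::compute / tvm::top::placeholder
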
diff --git a/src/schedule/auto_inline_elem_wise.cc b/src/schedule/auto_inline_elem_wise.cc
deleted file mode 100644 (file)
index 3e32923..0000000
+++ /dev/null
@@ -1,115 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file auto_inline_elem_wise.cc
- */
-#include <tvm/schedule_pass.h>
-#include <tvm/operation.h>
-#include <tvm/ir_functor_ext.h>
-
-namespace tvm {
-namespace schedule {
-
-using namespace ir;
-
-class ElemWiseDetector : public ir::ExprVisitor {
- public:
-  explicit ElemWiseDetector(Array<IterVar> axis) : axis_(axis) {}
-
-  void VisitExpr(const PrimExpr& e) final {
-    if (!is_elem_wise_) return;
-    ExprVisitor::VisitExpr(e);
-  }
-
-  void VisitExpr_(const CallNode* op) final {
-    Array<PrimExpr> axis = op->args;
-    if (axis_.size() != axis.size()) {
-      is_elem_wise_ = false;
-      return;
-    }
-
-    for (size_t i = 0; i < axis_.size(); ++i) {
-      if (!axis[i].same_as(axis_[i]->var)) {
-        is_elem_wise_ = false;
-        return;
-      }
-    }
-    ExprVisitor::VisitExpr_(op);
-  }
-
-  bool is_elem_wise_{true};
-
- private:
-  Array<IterVar> axis_;
-};
-
-
-bool IsElemWise(const Operation& op) {
-  if (const ComputeOpNode* compute = op.as<ComputeOpNode>()) {
-    ElemWiseDetector v = ElemWiseDetector(compute->axis);
-    for (auto& e : compute->body) v(e);
-    return v.is_elem_wise_;
-  }
-  return false;
-}
-
-void AutoInlineElemWise(Schedule sch) {
-  for (Stage s : sch->stages) {
-    if (!s.is_scheduled() && IsElemWise(s->op) && !s->is_output) {
-      s.compute_inline();
-    }
-  }
-}
-
-bool IsBroadcast(const Operation& op) {
-  if (const ComputeOpNode* compute = op.as<ComputeOpNode>()) {
-    if (compute->reduce_axis.size()) {
-      return false;
-    }
-    // TODO(nicolasvasilache): Implement Me
-  }
-  return false;
-}
-
-void AutoInlineBroadcast(Schedule sch) {
-  for (Stage s : sch->stages) {
-    if (!s.is_scheduled() && IsBroadcast(s->op) && !s->is_output) {
-      s.compute_inline();
-    }
-  }
-}
-
-bool IsInjective(const Operation& op) {
-  if (const ComputeOpNode* compute = op.as<ComputeOpNode>()) {
-    return compute->reduce_axis.size() == 0;
-  }
-  return false;
-}
-
-void AutoInlineInjective(Schedule sch) {
-  for (Stage s : sch->stages) {
-    if (!s.is_scheduled() && IsInjective(s->op) && !s->is_output) {
-      s.compute_inline();
-    }
-  }
-}
-
-}  // namespace schedule
-}  // namespace tvm
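
auto_inline_elem_wise.cc is moved, not removed; per this PR's direction it lands under src/top/. Assuming the passes keep their signatures and move into the tvm::top namespace, typical usage looks like:

    #include <tvm/top/operation.h>
    #include <tvm/top/schedule_pass.h>

    // Sketch (namespace placement assumed): auto-inline all injective
    // stages, i.e. compute stages with no reduce axis, before lowering.
    void InlineExample() {
      using namespace tvm;
      top::Tensor A = top::placeholder({1024}, DataType::Float(32), "A");
      top::Tensor B = top::compute({1024}, [&](const Array<Var>& i) {
        return A(i[0]) + 1.0f;
      }, "B");
      top::Schedule s = top::create_schedule({B->op});
      top::AutoInlineInjective(s);  // IsInjective(B->op) holds: no reduce_axis
    }
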
diff --git a/src/schedule/bound.cc b/src/schedule/bound.cc
deleted file mode 100644 (file)
index 5f363db..0000000
+++ /dev/null
@@ -1,263 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file bound.cc
- * \brief The bound inference logic.
- */
-#include <tvm/schedule_pass.h>
-#include <tvm/operation.h>
-#include <tvm/ir_pass.h>
-#include <unordered_map>
-#include <unordered_set>
-#include "graph.h"
-#include "message_passing.h"
-#include "../runtime/thread_storage_scope.h"
-
-namespace tvm {
-namespace schedule {
-
-using runtime::StorageRank;
-using runtime::StorageScope;
-using runtime::ThreadScope;
-
-/*! \brief The graph context used during bound inference. */
-struct GraphContext {
-  /*! \brief The feed graph */
-  FeedGraph feed_graph;
-  /*! \brief Attachment path */
-  AttachPath attach_path;
-  /*! \brief The bind map */
-  std::unordered_map<IterVar, IterVar> bind_map;
-  /*! \brief map from op to stage */
-  std::unordered_map<const Object*, Stage> op2stage_;
-};
-
-bool NeedRelax(const IterVar& iv,
-               bool found_attach,
-               const std::unordered_map<IterVar, IterVar>& bind_map,
-               const runtime::StorageScope& scope) {
-  auto it = bind_map.find(iv);
-  const std::string& tag = (
-      it != bind_map.end() ? it->second->thread_tag : iv->thread_tag);
-  if (tag.length() == 0 || tag == "pipeline") {
-    return !found_attach;
-  }
-  ThreadScope ts = ThreadScope::make(tag);
-
-  // When there is warp memory
-  // threadIdx.x must be set to be warp index.
-  if (scope.rank == StorageRank::kWarp &&
-      ts.rank == 1 &&
-      ts.dim_index == 0) {
-    return true;
-  }
-  return static_cast<int>(scope.rank) <= ts.rank;
-}
-
-// infer storage scope, if not given
-StorageScope InferStorageScope(
-    const Stage& stage, const GraphContext& ctx) {
-  if (stage->scope.length() != 0) {
-    return StorageScope::make(stage->scope);
-  }
-  int max_rank = -1;
-  for (IterVar iv : ctx.attach_path.at(stage->op)) {
-    auto it = ctx.bind_map.find(iv);
-    const std::string& tag = (
-        it != ctx.bind_map.end() ? it->second->thread_tag : iv->thread_tag);
-    if (tag != "pipeline" && tag.length() != 0) {
-      max_rank = std::max(max_rank, ThreadScope::make(tag).rank);
-    }
-  }
-  StorageScope s;
-  s.rank = runtime::DefaultStorageRank(max_rank);
-  return s;
-}
-
-
-void InferRootBound(const Stage& stage,
-                    const GraphContext& ctx,
-                    std::unordered_map<IterVar, Range>* rmap) {
-  CHECK_NE(stage->attach_type, kInline)
-      << "call schedule.normalize before scheduleops";
-  if (stage->attach_type == kInlinedAlready) return;
-  if (stage->is_output) {
-    // verify correctness.
-    CHECK_EQ(stage.GetAttachSpec()->attach_type, kGroupRoot)
-          << "Output must be attached at root";
-  }
-  if (stage->is_output || stage->op.as<PlaceholderOpNode>()) {
-    for (auto iv :  stage->op->root_iter_vars()) {
-      CHECK(iv->dom.defined());
-      CHECK(!rmap->count(iv));
-      (*rmap)[iv] = iv->dom;
-    }
-    return;
-  }
-  // The tensor domain.
-  std::unordered_map<Tensor, TensorDom> tmap;
-  // The consumers of the op.
-  std::unordered_set<Operation> consumers;
-  for (int i = 0; i < stage->op->num_outputs(); ++i) {
-    Tensor t = stage->op.output(i);
-    tmap.emplace(t, TensorDom(static_cast<int>(t.ndim())));
-    auto it = ctx.feed_graph.find(t);
-    if (it != ctx.feed_graph.end()) {
-      for (const Operation& op : it->second) {
-        consumers.insert(op);
-      }
-    } else {
-      LOG(INFO) << "not in feed graph consumer = " << stage->op;
-    }
-  }
-  // storage scope.
-  runtime::StorageScope scope = InferStorageScope(stage, ctx);
-  // Bound prop by other consumers.
-  // - Compute bound by relaxation rules: NeedRelax
-  //   - For normal index, use relative location of loop nest.
-  //   - For thread index, use the thread scope.
-  //
-  Array<IterVar> stage_attach = ctx.attach_path.at(stage->op);
-  // The parent set.
-  for (const Operation& op : consumers) {
-    std::unordered_map<const VarNode*, IntSet> relax_set;
-    std::unordered_map<IterVar, IntSet> up_state;
-    bool found_attach = false;
-    CHECK(ctx.op2stage_.count(op.get()));
-    const Stage& op_stage = ctx.op2stage_.at(op.get());
-    // Consumer nest
-    for (size_t i = op_stage->leaf_iter_vars.size(); i != 0; --i) {
-      IterVar iv = op_stage->leaf_iter_vars[i - 1];
-      if (stage_attach.size() != 0 && iv == stage_attach[0]) {
-        found_attach = true;
-      }
-      auto it = rmap->find(iv);
-      CHECK(it != rmap->end());
-      const Range& vrange = it->second;
-      if (is_one(vrange->extent)) {
-        up_state[iv] = IntSet::single_point(vrange->min);
-      } else if (!NeedRelax(iv, found_attach, ctx.bind_map, scope)) {
-        CHECK(is_zero(vrange->min))
-            << "InferBound requires every leaf iter var's min equals 0, "
-            << " call schedule.normalize to achieve this. ";
-        if (ctx.bind_map.count(iv)) {
-          up_state[iv] = IntSet::single_point(ctx.bind_map.at(iv)->var);
-        } else {
-          up_state[iv] = IntSet::single_point(iv->var);
-        }
-      } else {
-        up_state[iv] = IntSet::range(vrange);
-      }
-    }
-    // Consumer's attach nest
-    for (IterVar iv : ctx.attach_path.at(op)) {
-      if (stage_attach.size() != 0 && iv == stage_attach[0]) {
-        found_attach = true;
-      }
-      Range vrange = rmap->at(iv);
-      CHECK(is_zero(vrange->min))
-          << "InferBound requires every leaf iter var's min equals 0, "
-          << "call schedule.normalize to achieve this.";
-      if (NeedRelax(iv, found_attach, ctx.bind_map, scope)) {
-        relax_set[iv->var.get()] = IntSet::range(vrange);
-        if (ctx.bind_map.count(iv)) {
-          relax_set[ctx.bind_map.at(iv)->var.get()] = IntSet::range(vrange);
-        }
-      }
-    }
-    CHECK(found_attach || stage_attach.size() == 0)
-        << "Invalid Schedule, cannot find the producer " << stage->op
-        << " along the loop nest specified by compute_at of consumer " << op;
-    // Get the domain of the consumer
-    PassUpDomain(op_stage, *rmap, &up_state);
-    // Relax if needed.
-    std::unordered_map<const VarNode*, IntSet> dom_map;
-    arith::Analyzer analyzer;
-    for (auto iv : op->root_iter_vars()) {
-      Range r;
-      if (up_state.count(iv)) {
-        r = up_state.at(iv).cover_range(iv->dom);
-      } else {
-        r = iv->dom;
-      }
-      if (relax_set.size() != 0) {
-        dom_map[iv->var.get()] = EvalSet(r, relax_set);
-      } else {
-        dom_map[iv->var.get()] = IntSet::range(r);
-      }
-      analyzer.Bind(iv->var, r);
-    }
-    op->PropBoundToInputs(op, &analyzer, dom_map, &tmap);
-  }
-  stage->op->GatherBound(stage->op, tmap, rmap);
-}
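
// Illustrative summary of the relaxation rule applied above, for one
// consumer leaf itervar iv:
//   extent(iv) == 1                       -> up_state[iv] = single point (min)
//   inside the attach scope (not relaxed) -> up_state[iv] = single point (iv->var)
//   otherwise (relaxed)                   -> up_state[iv] = the whole range
// PassUpDomain then folds these leaf sets back onto the consumer's root
// itervars, and relax_set widens whatever lies past the attach point.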
-
-Map<IterVar, Range> InferBound(const Schedule& sch) {
-  // Prepare context
-  GraphContext ctx;
-  Array<Operation> roots;
-  arith::Analyzer analyzer;
-
-  for (Operation op : sch->outputs) {
-    roots.push_back(sch->stage_map[op]->op);
-  }
-  ctx.feed_graph = CreateFeedGraph(CreateReadGraph(roots));
-
-  for (Stage stage : sch->stages) {
-    for (auto kv : stage->iter_var_attrs) {
-      if (kv.second->bind_thread.defined()) {
-        CHECK(!ctx.bind_map.count(kv.first));
-        ctx.bind_map[kv.first] = kv.second->bind_thread;
-      }
-    }
-    ctx.op2stage_[stage->op.get()] = stage;
-  }
-  ctx.attach_path = CreateAttachPath(sch);
-  // Run inference.
-  std::unordered_map<IterVar, Range> ret;
-  for (size_t i = sch->stages.size(); i != 0; --i) {
-    const Stage& stage = sch->stages[i - 1];
-    InferRootBound(stage, ctx, &ret);
-
-    // bind bound of root iter vars.
-    for (auto iv : stage->op->root_iter_vars()) {
-      auto it = ret.find(iv);
-      if (it != ret.end()) {
-        analyzer.Bind(iv->var, it->second);
-      }
-    }
-
-    // pass down to get bound of all iter vars.
-    PassDownDomain(stage, &ret, &analyzer);
-    for (IterVar iv : stage->env_threads) {
-      CHECK(iv->dom.defined());
-      ret[iv] = iv->dom;
-    }
-  }
-  for (auto& p : ret) {
-    ret[p.first] = Range::make_by_min_extent(
-        analyzer.Simplify(p.second->min),
-        analyzer.Simplify(p.second->extent));
-  }
-  return Map<IterVar, Range>(ret.begin(), ret.end());
-}
-
-}  // namespace schedule
-}  // namespace tvm
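
bound.cc likewise moves rather than disappears. Its entry point, InferBound, feeds the follow-up lowering step; assuming the functions keep their signatures under the relocated namespace (tvm::top per this PR), the typical call sequence is:

    #include <tvm/top/schedule_pass.h>

    // Sketch (namespace placement assumed): the standard
    // normalize -> InferBound -> ScheduleOps sequence. normalize() matters
    // because InferBound CHECKs that every leaf itervar's min is zero.
    tvm::Stmt LowerExample(tvm::top::Schedule s) {
      s = s.normalize();
      tvm::Map<tvm::IterVar, tvm::Range> bounds = tvm::top::InferBound(s);
      return tvm::top::ScheduleOps(s, bounds, /*debug_keep_trivial_loop=*/false);
    }
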
diff --git a/src/schedule/graph.cc b/src/schedule/graph.cc
deleted file mode 100644 (file)
index 3cf30f4..0000000
+++ /dev/null
@@ -1,433 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file graph.cc
- * \brief Utilities to get information about schedule graph.
- */
-#include <tvm/ir.h>
-#include <tvm/ir_functor_ext.h>
-#include <tvm/operation.h>
-#include <utility>
-#include <unordered_set>
-#include <unordered_map>
-#include "graph.h"
-
-namespace tvm {
-namespace schedule {
-// key to specific tensor dimension.
-struct TensorDimKey {
-  ir::FunctionRef f;
-  int value_index;
-  int dim;
-  TensorDimKey() {}
-  TensorDimKey(const ir::CallNode* op, int dim)
-      : f(op->func), value_index(op->value_index), dim(dim) {
-  }
-  TensorDimKey(const Tensor& t, int dim)
-      : f(t->op), value_index(t->value_index), dim(dim) {
-  }
-  TensorDimKey(const Tensor& t, size_t dim)
-      : f(t->op), value_index(t->value_index), dim(static_cast<int>(dim)) {
-  }
-  inline bool operator==(const TensorDimKey& other) const {
-    return f == other.f &&
-        value_index == other.value_index &&
-        dim == other.dim;
-  }
-  inline bool operator!=(const TensorDimKey& other) const {
-    return !operator==(other);
-  }
-};
-}  // namespace schedule
-}  // namespace tvm
-
-namespace std {
-template <>
-struct hash<::tvm::schedule::TensorDimKey> {
-  std::size_t operator()(const ::tvm::schedule::TensorDimKey& k) const {
-    size_t lhs = ::tvm::ObjectHash()(k.f);
-    size_t rhs = static_cast<size_t>(k.value_index) << 16UL |
-        static_cast<size_t>(k.dim);
-    lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2);
-    return lhs;
-  }
-};
-}  // namespace std
-
-
-namespace tvm {
-namespace schedule {
-
-// Construct a read graph that gives, for each operation the roots
-// depend on, the tensors it reads.
-ReadGraph CreateReadGraph(const Array<Operation>& roots) {
-  ReadGraph rmap;
-  std::vector<Operation> stack;
-  std::unordered_set<const Object*> visited;
-  // initialize the roots
-  for (Operation op : roots) {
-    stack.push_back(op);
-    visited.insert(op.get());
-  }
-
-  while (!stack.empty()) {
-    Operation op = stack.back();
-    stack.pop_back();
-    Array<Tensor> deps = op->InputTensors();
-    rmap.Set(op, deps);
-    for (Tensor t : deps) {
-      if (t->op.defined() && visited.count(t->op.get()) == 0) {
-        visited.insert(t->op.get());
-        stack.push_back(t->op);
-      }
-    }
-  }
-  return rmap;
-}
-
-// Do DFS visit to get the subgraph.
-// Return whether op is inside the subgraph.
-bool GetSubGraphByPostDFS_(
-    const Operation& op,
-    const std::unordered_set<const Object*>& boundary,
-    bool include_boundary,
-    std::unordered_map<const Object*, bool>* visited,
-    Array<Operation>* result) {
-  if (visited->count(op.get())) {
-    return visited->at(op.get());
-  }
-  if (boundary.count(op.get())) {
-    (*visited)[op.get()] = true;
-    if (include_boundary) {
-      result->push_back(op);
-    }
-    return true;
-  }
-  // mark to avoid loop
-  // Not necessary for DAG.
-  (*visited)[op.get()] = false;
-  // check if we can reach boundary.
-  bool reach_boundary = false;
-  for (Tensor t : op->InputTensors()) {
-    if (GetSubGraphByPostDFS_(t->op, boundary,
-                              include_boundary,
-                              visited, result)) {
-      reach_boundary = true;
-    }
-  }
-  (*visited)[op.get()] = reach_boundary;
-  if (reach_boundary) {
-    result->push_back(op);
-  }
-  return reach_boundary;
-}
-
-Array<Operation> GetSubGraph(const Array<Tensor>& outputs,
-                             const Array<Tensor>& inputs,
-                             bool include_inputs) {
-  Array<Operation> result;
-  std::unordered_set<const Object*> boundary;
-  for (Tensor t : inputs) {
-    boundary.insert(t->op.get());
-  }
-  std::unordered_map<const Object*, bool> visited;
-  for (Tensor t : outputs) {
-    GetSubGraphByPostDFS_(t->op, boundary, include_inputs,
-                          &visited, &result);
-  }
-  return result;
-}
-
-
-void PostDFSOrder(const Operation& op,
-                  const ReadGraph& g,
-                  std::unordered_set<Operation>* visited,
-                  Array<Operation>* post_order) {
-  if (visited->count(op)) return;
-  visited->insert(op);
-  for (const auto& t : g.at(op)) {
-    PostDFSOrder(t->op, g, visited, post_order);
-  }
-  post_order->push_back(op);
-}
-
-Array<Operation> PostDFSOrder(
-    const Array<Operation>& roots,
-    const ReadGraph& g) {
-  std::unordered_set<Operation> visited;
-  Array<Operation> post_order;
-  for (Operation op : roots) {
-    PostDFSOrder(op, g, &visited, &post_order);
-  }
-  return post_order;
-}
-
-FeedGraph CreateFeedGraph(const ReadGraph& g) {
-  FeedGraph fg;
-  for (auto kv : g) {
-    for (Tensor t : kv.second) {
-      fg[t].push_back(kv.first);
-    }
-  }
-  return fg;
-}
-
-AttachPath CreateAttachPath(Schedule sch) {
-  AttachPath ret;
-  for (Stage stage : sch->stages) {
-    std::unordered_set<const Object*> visited;
-    Array<IterVar> path;
-    for (Stage s = stage; s.defined();) {
-      CHECK(!visited.count(s.get()))
-          << "Find loop in compute_at attach group";
-      visited.insert(s.get());
-      Stage spec = s.GetAttachSpec();
-      bool start_attach;
-      IterVar attach_ivar;
-      if (spec->attach_type == kScope) {
-        attach_ivar = spec->attach_ivar;
-        s = spec->attach_stage;
-        start_attach = false;
-        CHECK(attach_ivar.defined());
-      } else if (spec->attach_type == kScanUpdate) {
-        s = spec->attach_stage;
-        start_attach = true;
-      } else {
-        break;
-      }
-      CHECK(s.defined());
-      for (size_t i = s->leaf_iter_vars.size(); i != 0; --i) {
-        IterVar iv = s->leaf_iter_vars[i - 1];
-        if (!start_attach && iv.same_as(attach_ivar)) {
-          start_attach = true;
-        }
-        if (start_attach) path.push_back(iv);
-      }
-      CHECK(start_attach)
-          << "Invalid Schedule: cannot find attach point " << attach_ivar
-          << " in the schedule of " << s->op;
-    }
-    if (!ret.count(stage->op)) {
-      ret.Set(stage->op, path);
-    }
-  }
-  return ret;
-}
-
-// Graph of the push-reach relation between tensor dimensions.
-using ReachGraph = std::unordered_map<TensorDimKey, std::vector<TensorDimKey> >;
-
-ReachGraph GetReachGraph(const Array<Operation>& ops) {
-  ReachGraph reach;
-  std::unordered_set<const Object*> bset;
-  for (size_t i = 0; i < ops.size(); ++i) {
-    bset.insert(ops[i].get());
-  }
-
-  for (Operation op : ops) {
-    if (const auto* scan_op = op.as<ScanOpNode>()) {
-      const auto& update = scan_op->update;
-      const auto& init = scan_op->init;
-      for (size_t i = 0; i < update.size(); ++i) {
-        Tensor t = op.output(i);
-        for (int k = 1; k < static_cast<int>(update[i]->shape.size()); ++k) {
-          reach[TensorDimKey(t, k)].emplace_back(
-              TensorDimKey(update[i], k));
-          reach[TensorDimKey(t, k)].emplace_back(
-              TensorDimKey(init[i], k));
-        }
-      }
-    } else if (const auto* compute_op = op.as<ComputeOpNode>()) {
-      std::unordered_map<const Object*, TensorDimKey> vmap;
-      const auto& axis = compute_op->axis;
-      Tensor t = op.output(0);
-      for (size_t i = 0; i < axis.size(); ++i) {
-        vmap[axis[i]->var.get()] = TensorDimKey(t, i);
-        reach[TensorDimKey(t, i)] = {};
-      }
-      auto fvisit = [&vmap, &reach, &bset](const ObjectRef& n) {
-        const ir::CallNode *call = n.as<ir::CallNode>();
-        if (call != nullptr && call->func.defined()) {
-          if (!bset.count(call->func.get())) return;
-          for (size_t i = 0; i < call->args.size(); ++i) {
-            TensorDimKey dkey(call, static_cast<int>(i));
-            auto fpush = [&dkey, &vmap, &reach](const ObjectRef& node) {
-              const VarNode *v = node.as<VarNode>();
-              auto it = vmap.find(v);
-              if (it != vmap.end()) {
-                reach[it->second].push_back(dkey);
-              }
-            };
-            ir::PostOrderVisit(call->args[i], fpush);
-          }
-        }
-      };
-      for (auto& e : compute_op->body) {
-        ir::PostOrderVisit(e, fvisit);
-      }
-    }
-  }
-  return reach;
-}
-
-Array<Operation> ScanGetBody(const Operation& scan_op) {
-  const ScanOpNode* scan = scan_op.as<ScanOpNode>();
-  // Get the body.
-  Array<Tensor> inputs;
-  for (Tensor t : scan->state_placeholder) {
-    inputs.push_back(t);
-  }
-  for (Tensor t : scan->inputs) {
-    inputs.push_back(t);
-  }
-  return GetSubGraph(scan->update, inputs, false);
-}
-
-Map<IterVar, PrimExpr> ScanFixPointAnalysis(const Operation& scan_op) {
-  const ScanOpNode* scan = scan_op.as<ScanOpNode>();
-  Array<Operation> body = ScanGetBody(scan_op);
-
-  std::unordered_map<TensorDimKey, const Object*> exact_reach;
-  std::unordered_set<const Object*> fail_set;
-
-  for (size_t i = 0, sp_idx = 0; i < scan->update.size(); ++i) {
-    for (size_t k = 1; k < scan->update[i]->shape.size(); ++k, ++sp_idx) {
-      TensorDimKey key(scan->state_placeholder[i], k);
-      exact_reach[key] = scan->spatial_axis_[sp_idx].get();
-    }
-  }
-  // merge exact reach
-  auto f_merge_key = [&exact_reach, &fail_set](
-      const TensorDimKey& dst, const TensorDimKey& src) {
-    auto sit = exact_reach.find(src);
-    if (sit == exact_reach.end()) return;
-    auto dit = exact_reach.find(dst);
-    if (dit == exact_reach.end()) {
-      exact_reach[dst] = sit->second;
-    } else {
-      if (dit->second != sit->second) {
-        fail_set.insert(dit->second);
-        fail_set.insert(sit->second);
-      }
-    }
-  };
-  // propagate exact reach back.
-  for (size_t i = 0; i < body.size(); ++i) {
-    const Operation& op = body[i];
-    if (const auto* scan_op = op.as<ScanOpNode>()) {
-      const auto& update = scan_op->update;
-      const auto& init = scan_op->init;
-      for (size_t i = 0; i < update.size(); ++i) {
-        Tensor t = op.output(i);
-        for (size_t k = 1; k < update[i]->shape.size(); ++k) {
-          f_merge_key(TensorDimKey(t, k), TensorDimKey(update[i], k));
-          f_merge_key(TensorDimKey(t, k), TensorDimKey(init[i], k));
-        }
-      }
-    } else if (const auto* compute_op = op.as<ComputeOpNode>()) {
-      std::unordered_map<const Object*, std::vector<TensorDimKey> > vmap;
-      const auto& axis = compute_op->axis;
-      for (size_t i = 0; i < axis.size(); ++i) {
-        std::vector<TensorDimKey> keys;
-        for (int j = 0; j < op->num_outputs(); ++j) {
-          keys.emplace_back(op.output(j), i);
-        }
-        vmap[axis[i]->var.get()] = std::move(keys);
-      }
-      auto fvisit = [&vmap, &f_merge_key, &exact_reach, &fail_set](
-          const ObjectRef& n) {
-        const ir::CallNode *call = n.as<ir::CallNode>();
-        if (call != nullptr && call->func.defined()) {
-          for (size_t i = 0; i < call->args.size(); ++i) {
-            auto it = vmap.find(call->args[i].get());
-            TensorDimKey src(call, static_cast<int>(i));
-            if (it != vmap.end()) {
-              const std::vector<TensorDimKey>& keys = it->second;
-              for (const auto& key : keys) {
-                f_merge_key(key, src);
-              }
-            } else {
-              if (exact_reach.count(src)) {
-                fail_set.insert(exact_reach.at(src));
-              }
-            }
-          }
-        }
-      };
-      for (auto& e : compute_op->body) {
-        ir::PostOrderVisit(e, fvisit);
-      }
-    }
-  }
-  ReachGraph reach;
-  Map<IterVar, PrimExpr> ret;
-  std::unordered_set<TensorDimKey> place_holder_ref;
-  for (size_t i = 0; i < scan->state_placeholder.size(); ++i) {
-    for (size_t k = 0; k < scan->state_placeholder[i]->shape.size(); ++k) {
-      place_holder_ref.insert(TensorDimKey(scan->state_placeholder[i], k));
-    }
-  }
-
-  for (size_t i = 0, sp_idx = 0; i < scan->update.size(); ++i) {
-    for (size_t k = 1; k < scan->update[i]->shape.size(); ++k, ++sp_idx) {
-      TensorDimKey key(scan->update[i], k);
-      TensorDimKey target(scan->state_placeholder[i], k);
-      IterVar sp_iv = scan->spatial_axis_[sp_idx];
-      if (fail_set.count(sp_iv.get()) ||
-          !exact_reach.count(key) ||
-          exact_reach.at(key) != sp_iv.get()) {
-        ret.Set(sp_iv, make_const(DataType::Int(32), 0));
-      } else {
-        // now we proved an exact match; still need to prove no interference with the rest of the graph.
-        if (reach.size() == 0) reach = GetReachGraph(body);
-        // do a DFS
-        std::unordered_set<TensorDimKey> visited;
-        std::vector<TensorDimKey> stack{key};
-        visited.insert(key);
-        while (!stack.empty()) {
-          TensorDimKey k = stack.back();
-          if (k != target && place_holder_ref.count(k)) break;
-          stack.pop_back();
-          if (!reach.count(k)) {
-            LOG(FATAL) << "cannot find reach of " << k.f << "-" << k.dim;
-          }
-
-          for (TensorDimKey kk : reach.at(k)) {
-            if (visited.count(kk)) {
-              continue;
-            }
-            visited.insert(kk);
-            stack.push_back(kk);
-          }
-        }
-        if (!stack.empty()) {
-          // failed the proof.
-          ret.Set(sp_iv, make_const(DataType::Int(32), 0));
-        } else {
-          ret.Set(sp_iv, make_const(DataType::Int(32), 1));
-        }
-      }
-    }
-  }
-  return ret;
-}
-
-}  // namespace schedule
-}  // namespace tvm
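
The graph utilities deleted here back the bound inference above; a usage sketch of the read/feed-graph pair (namespaces assumed as in the other moved files; C is an illustrative output tensor):

    #include "graph.h"  // internal header; post-move location under src/top/ assumed

    // Sketch: building the dependency graphs that bound inference consumes.
    void GraphExample(tvm::top::Tensor C) {
      using namespace tvm;
      Array<top::Operation> roots = { C->op };
      // op -> the tensors it directly reads
      top::ReadGraph rg = top::CreateReadGraph(roots);
      // dependency-respecting visit order over the read graph
      Array<top::Operation> order = top::PostDFSOrder(roots, rg);
      // tensor -> the ops that consume it (the reverse of the read graph)
      top::FeedGraph fg = top::CreateFeedGraph(rg);
      (void)order; (void)fg;  // illustration only
    }
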
diff --git a/src/schedule/graph.h b/src/schedule/graph.h
deleted file mode 100644 (file)
index 99ba6e3..0000000
+++ /dev/null
@@ -1,131 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file graph.h
- * \brief Utilities to get information about schedule graph.
- */
-#ifndef TVM_SCHEDULE_GRAPH_H_
-#define TVM_SCHEDULE_GRAPH_H_
-
-#include <tvm/expr.h>
-#include <tvm/schedule.h>
-#include <tvm/operation.h>
-#include <unordered_map>
-#include <unordered_set>
-#include <vector>
-
-namespace tvm {
-namespace schedule {
-
-/*!
- * \brief Data structure mapping each Operation to the Tensors it reads.
- */
-using ReadGraph = Map<Operation, Array<Tensor> >;
-
-/*!
- * \brief AttachPath maps each op to a list of IterVars.
- */
-using AttachPath = Map<Operation, Array<IterVar> >;
-
-/*!
- * \brief The map between a tensor and the operations it feeds into.
- */
-using FeedGraph = std::unordered_map<Tensor, std::vector<Operation> >;
-
-/*!
- * \brief Get the read graph mapping each operation to all the
- *  Tensors that it directly depends on.
- *
- *  The result map contains all Operations needed to finish the root Operations.
- * \param roots The root operations.
- * \return The result map.
- */
-ReadGraph CreateReadGraph(const Array<Operation>& roots);
-
-/*!
- * \brief Get minimum subgraph between outputs and inputs.
- *  The result contains the operations that are reachable from the inputs
- *  and can reach the outputs.
- *
- *  The inputs won't be included in the subgraph; the outputs will be included.
- *
- * \param outputs The outputs of the subgraph
- * \param inputs The inputs to the subgraph.
- * \param include_inputs Whether to include inputs
- *
- * \return The subgraph.
- */
-Array<Operation> GetSubGraph(const Array<Tensor>& outputs,
-                             const Array<Tensor>& inputs,
-                             bool include_inputs);
-
-/*!
- * \brief Get a post-DFS ordering of the operations in the graph.
- * \param roots The roots of the graph.
- * \param g The read graph.
- * \return vector of Operations in post-DFS order.
- *
- * \note PostDFSOrder is a special case of topological order,
- *   and can be used when a topological order is needed.
- */
-Array<Operation> PostDFSOrder(
-    const Array<Operation>& roots, const ReadGraph& g);
-
-/*!
- * \brief Create a feed graph from the given read graph.
- * \param g The read graph.
- * \return The created feed graph.
- */
-FeedGraph CreateFeedGraph(const ReadGraph& g);
-
-/*!
- * \brief Create an AttachPath that maps each op to the list of IterVars
- *  representing the loop nest the op sits in, from innermost to outermost.
- *  Also inserts the attach stage for scan updates when needed.
- *
- * \param sch The schedule.
- * \return The attach path.
- */
-AttachPath CreateAttachPath(Schedule sch);
-
-/*!
- * \brief Get all operations inside the recursion of scan.
- * \param scan_op The scan operation.
- * \return The body operations, in read dependency order.
- */
-Array<Operation> ScanGetBody(const Operation& scan_op);
-
-/*!
- * \brief Analyze each spatial dimension of scan's result.
- *  Check whether each dimension is a fixed point.
- *  An axis is a fixed point if it only refers back to itself in the recursion
- *  and is not used in the axes of other recursion fields.
- *
- *  next_state[t, ..., axis, ...] = f(prev_state[t-1, ..., axis, ...])
- *
- * \param scan The scan node.
- * \return Map of spatial_axis -> IntImm
- */
-Map<IterVar, PrimExpr> ScanFixPointAnalysis(const Operation& scan);
-
-}  // namespace schedule
-}  // namespace tvm
-
-#endif  // TVM_SCHEDULE_GRAPH_H_
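
Of the utilities above, CreateAttachPath is the subtlest: the attach path records, for each op, the loop itervars it is nested under, from innermost to outermost. A worked example of what it records (schedule names hypothetical):

    // Given s[B].compute_at(s[C], yo), where C's leaf loop order is
    // (xo, yo, yi, xi), the attach path of B->op is {yo, xo}: the loops
    // enclosing B's body, starting at the attach point and moving outward.
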
diff --git a/src/schedule/message_passing.cc b/src/schedule/message_passing.cc
deleted file mode 100644 (file)
index 816ea44..0000000
+++ /dev/null
@@ -1,543 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file message_passing.cc
- * \brief The message passing domain.
- */
-#include <tvm/arith/analyzer.h>
-#include <tvm/ir.h>
-#include <tvm/ir_pass.h>
-#include "message_passing.h"
-#include "../arith/compute_expr.h"
-
-namespace tvm {
-namespace schedule {
-
-using namespace ir;
-
-void Update(std::unordered_map<IterVar, Range>* p_state,
-            const IterVar& iv,
-            Range r,
-            arith::Analyzer* analyzer) {
-  auto it = p_state->find(iv);
-  if (it == p_state->end()) {
-    (*p_state)[iv] = r;
-    analyzer->Bind(iv->var, r);
-  } else {
-    bool match = is_zero(it->second->min) &&
-        analyzer->CanProve(r->extent - it->second->extent == 0);
-    CHECK(match)
-        << iv
-        << " domain already inferred,"
-        << " cannot prove their extents are the same "
-        << it->second->extent << " vs " << r->extent;
-  }
-}
-
-void PassDownDomain(const Stage& stage,
-                    std::unordered_map<IterVar, Range>* p_state,
-                    arith::Analyzer* actx,
-                    bool allow_missing) {
-  auto ceil_div = [actx](PrimExpr a, PrimExpr b) {
-    if (actx->CanProve(indexmod(a, b) == 0)) {
-      return actx->Simplify(indexdiv(a, b));
-    }
-    return actx->Simplify(indexdiv(a + (b - 1), b));
-  };
-
-  auto& state = *p_state;
-  // forward iteration over the relations
-  for (IterVarRelation rel : stage->relations) {
-    if (const SplitNode* r = rel.as<SplitNode>()) {
-      if (!state.count(r->parent)) {
-        CHECK(allow_missing);
-        continue;
-      }
-      CHECK(!state.count(r->inner));
-      const Range& range_parent = state.at(r->parent);
-      if (r->factor.defined()) {
-        Update(p_state, r->inner,
-               Range::make_by_min_extent(0, r->factor), actx);
-        Update(p_state, r->outer,
-               Range::make_by_min_extent(
-                   0, ceil_div(range_parent->extent, r->factor)), actx);
-      } else {
-        Update(p_state, r->outer, Range::make_by_min_extent(0, r->nparts), actx);
-        Update(p_state, r->inner,
-               Range::make_by_min_extent(
-                   0, ceil_div(range_parent->extent, r->nparts)), actx);
-      }
-    } else if (const FuseNode* r = rel.as<FuseNode>()) {
-      if (!state.count(r->outer) || !state.count(r->inner)) {
-        CHECK(allow_missing);
-        continue;
-      }
-      const Range& range_outer = state.at(r->outer);
-      const Range& range_inner = state.at(r->inner);
-      state[r->fused] = Range::make_by_min_extent(
-          0, range_outer->extent * range_inner->extent);
-    } else if (const RebaseNode* r = rel.as<RebaseNode>()) {
-      if (!state.count(r->parent)) {
-        CHECK(allow_missing);
-        continue;
-      }
-      Update(p_state, r->rebased,
-             Range::make_by_min_extent(
-                 0, state.at(r->parent)->extent), actx);
-    } else if (const SingletonNode* s = rel.as<SingletonNode>()) {
-      Update(p_state, s->iter, Range::make_by_min_extent(0, 1), actx);
-    } else {
-      LOG(FATAL) << "unknown relation type";
-    }
-  }
-  // update the extents of bound threads.
-  for (auto kv : stage->iter_var_attrs) {
-    if (kv.second->bind_thread.defined()) {
-      CHECK(state.count(kv.first));
-      Update(p_state, kv.second->bind_thread, state.at(kv.first), actx);
-    }
-  }
-}
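
// Worked example (illustrative) of PassDownDomain across one split:
//   parent range [0, 10) with split factor 4
//     -> inner gets [0, 4), outer gets [0, ceil_div(10, 4)) = [0, 3)
// ceil_div above returns the exact quotient whenever factor | extent is
// provable, avoiding the (a + b - 1) / b form in the divisible case.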
-
-void PassUpIndex(const Stage& stage,
-                 const Map<IterVar, Range>& dom_map,
-                 std::unordered_map<IterVar, PrimExpr>* p_state,
-                 bool allow_missing) {
-  auto& state = *p_state;
-  for (size_t i = stage->relations.size(); i != 0; --i) {
-    IterVarRelation rel = stage->relations[i - 1];
-    if (const SplitNode* s = rel.as<SplitNode>()) {
-      if (!state.count(s->outer) || !state.count(s->inner)) {
-        CHECK(allow_missing);
-        continue;
-      }
-      PrimExpr outer = state.at(s->outer);
-      PrimExpr inner = state.at(s->inner);
-      PrimExpr factor = dom_map.at(s->inner)->extent;
-      PrimExpr parent_min = dom_map.at(s->parent)->min;
-      state[s->parent] = inner + outer * factor;
-      // add min if they exist
-      if (!is_zero(parent_min)) {
-        state[s->parent] = state[s->parent] + parent_min;
-      }
-    } else if (const FuseNode* s = rel.as<FuseNode>()) {
-      if (!state.count(s->fused)) {
-        CHECK(allow_missing);
-        continue;
-      }
-      PrimExpr value = state.at(s->fused);
-      PrimExpr factor = dom_map.at(s->inner)->extent;
-      PrimExpr outer_min = dom_map.at(s->outer)->min;
-      PrimExpr inner_min = dom_map.at(s->inner)->min;
-      state[s->outer] = indexdiv(value, factor);
-      state[s->inner] = indexmod(value, factor);
-      // add the min offsets if they are nonzero
-      if (!is_zero(outer_min)) {
-        state[s->outer] = state[s->outer] + outer_min;
-      }
-      if (!is_zero(inner_min)) {
-        state[s->inner] = state[s->inner] + inner_min;
-      }
-    } else if (const RebaseNode* s = rel.as<RebaseNode>()) {
-      if (!state.count(s->rebased)) {
-        CHECK(allow_missing);
-        continue;
-      }
-      PrimExpr value = state.at(s->rebased);
-      PrimExpr parent_min = dom_map.at(s->parent)->min;
-      // add the parent's min offset if it is nonzero
-      if (!is_zero(parent_min)) {
-        state[s->parent] = value + parent_min;
-      } else {
-        state[s->parent] = value;
-      }
-    } else if (rel.as<SingletonNode>()) {
-    } else {
-      LOG(FATAL) << "unknown relation type";
-    }
-  }
-}
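-
-// Sketch of the upward index recovery above (hypothetical values): with an
-// inner extent of 4, a fused leaf value v is mapped back as
-//   outer = indexdiv(v, 4), inner = indexmod(v, 4),
-// e.g. v = 10 gives outer = 2, inner = 2; a split maps its two leaves back
-// as parent = inner + outer * 4 (plus the parent min when nonzero).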
-
-void PassDownIndex(const Stage& stage,
-                   const Map<IterVar, Range>& dom_map,
-                   std::unordered_map<IterVar, PrimExpr>* p_state,
-                   bool allow_missing) {
-  auto& state = *p_state;
-  for (IterVarRelation rel : stage->relations) {
-    if (const SplitNode* s = rel.as<SplitNode>()) {
-      if (!state.count(s->parent)) {
-        CHECK(allow_missing);
-        continue;
-      }
-      Range r = dom_map.at(s->inner);
-      CHECK(is_zero(r->min));
-      PrimExpr parent = state.at(s->parent);
-      PrimExpr factor = r->extent;
-      state[s->outer] = indexdiv(parent, factor);
-      state[s->inner] = indexmod(parent, factor);
-    } else if (const FuseNode* s = rel.as<FuseNode>()) {
-      if (!state.count(s->inner) && !state.count(s->outer)) {
-        CHECK(allow_missing);
-        continue;
-      }
-      PrimExpr factor = dom_map.at(s->inner)->extent;
-      PrimExpr outer_min = dom_map.at(s->outer)->min;
-      PrimExpr inner_min = dom_map.at(s->inner)->min;
-      PrimExpr inner = state.at(s->inner);
-      PrimExpr outer = state.at(s->outer);
-      CHECK(is_zero(outer_min));
-      CHECK(is_zero(inner_min));
-      state[s->fused] = outer * factor + inner;
-    } else if (const RebaseNode* s = rel.as<RebaseNode>()) {
-      if (!state.count(s->parent)) {
-        CHECK(allow_missing);
-        continue;
-      }
-      PrimExpr value = state.at(s->parent);
-      PrimExpr parent_min = dom_map.at(s->parent)->min;
-      CHECK(is_zero(parent_min));
-      state[s->rebased] = value;
-    } else if (const SingletonNode* s = rel.as<SingletonNode>()) {
-      state[s->iter] = make_zero(s->iter->var.dtype());
-    } else {
-      LOG(FATAL) << "unknown relation type";
-    }
-  }
-}
-
-// Domain message passing.
-void PassUpDomain(const SplitNode* s,
-                  const std::unordered_map<IterVar, Range>& dom_map,
-                  const IntSet& outer,
-                  const IntSet& inner,
-                  IntSet* parent) {
-  if (dom_map.count(s->outer) &&
-      dom_map.count(s->inner) &&
-      dom_map.count(s->parent) &&
-      outer.match_range(dom_map.at(s->outer)) &&
-      inner.match_range(dom_map.at(s->inner))) {
-    *parent = IntSet::range(dom_map.at(s->parent));
-    return;
-  }
-  PrimExpr factor = dom_map.at(s->inner)->extent;
-  PrimExpr parent_min = dom_map.at(s->parent)->min;
-  CHECK(outer.defined());
-  CHECK(inner.defined());
-  CHECK(factor.defined());
-  *parent = arith::EvalSet(
-      s->outer->var * factor + s->inner->var + parent_min,
-      {{s->outer, outer}, {s->inner, inner}});
-}
-
-void PassUpDomain(const FuseNode* s,
-                  const std::unordered_map<IterVar, Range>& dom_map,
-                  const IntSet& fused,
-                  IntSet* outer,
-                  IntSet* inner) {
-  CHECK(dom_map.count(s->outer));
-  CHECK(dom_map.count(s->inner));
-  CHECK(dom_map.count(s->fused));
-
-  if (fused.match_range(dom_map.at(s->fused))) {
-    *outer = IntSet::range(dom_map.at(s->outer));
-    *inner = IntSet::range(dom_map.at(s->inner));
-    return;
-  }
-  PrimExpr outer_min = dom_map.at(s->outer)->min;
-  PrimExpr inner_min = dom_map.at(s->inner)->min;
-
-  if (fused.is_single_point()) {
-    PrimExpr value = fused.point_value();
-    PrimExpr factor = dom_map.at(s->inner)->extent;
-    PrimExpr v_outer  = indexdiv(value, factor);
-    PrimExpr v_inner  = indexmod(value, factor);
-    if (!is_zero(outer_min)) v_outer = v_outer + outer_min;
-    if (!is_zero(inner_min)) v_inner = v_inner + inner_min;
-    *outer = IntSet::single_point(v_outer);
-    *inner = IntSet::single_point(v_inner);
-  } else {
-    PrimExpr fused_extent = (fused.max() - fused.min() + 1);
-    PrimExpr inner_extent = dom_map.at(s->inner)->extent;
-    *outer = IntSet::interval(
-        outer_min + indexdiv(fused.min(), inner_extent),
-        outer_min + indexdiv(fused.max(), inner_extent));
-    if (is_zero(Simplify(indexmod(inner_extent, fused_extent))) &&
-        is_zero(Simplify(indexmod(fused.min(), fused_extent)))) {
-      // fused never spans multiple rows, make a tight bounding box
-      // there may be other cases when bounding box could be tightened
-      *inner = IntSet::interval(inner_min + indexmod(fused.min(), inner_extent),
-                                inner_min + indexmod(fused.max(), inner_extent));
-    } else {  // fused may span multiple rows, use full row widths
-      if (!is_zero(Simplify(indexmod(fused_extent, inner_extent))) ||
-          !is_zero(Simplify(indexmod(fused.min(), inner_extent)))) {
-        LOG(WARNING) <<
-          "fused and original axes are not aligned, this may cause redundant computations";
-      }
-      *inner = IntSet::range(dom_map.at(s->inner));
-    }
-    return;
-  }
-}
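-
-// Worked example for the fused interval case above (hypothetical values):
-// with inner extent 8 and fused = [0, 4), 8 % 4 == 0 and 0 % 4 == 0 hold,
-// so the tight box inner = [0, 3] is used; with fused = [0, 12) the fused
-// range spans several inner rows, so inner falls back to the full [0, 8).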
-
-void PassUpDomain(const RebaseNode* s,
-                  const std::unordered_map<IterVar, Range>& dom_map,
-                  const IntSet& rebased,
-                  IntSet* parent) {
-  CHECK(dom_map.count(s->parent));
-  if (rebased.match_range(dom_map.at(s->rebased))) {
-    *parent = IntSet::range(dom_map.at(s->parent));
-    return;
-  }
-  PrimExpr parent_min = dom_map.at(s->parent)->min;
-  *parent = arith::EvalSet(s->rebased->var + parent_min,
-                           {{s->rebased, rebased}});
-}
-
-void PassUpDomain(const Stage& stage,
-                  const std::unordered_map<IterVar, Range>& dom_map,
-                  std::unordered_map<IterVar, IntSet>* p_state) {
-  auto& state = *p_state;
-  for (size_t i = stage->relations.size(); i != 0; --i) {
-    IterVarRelation rel = stage->relations[i - 1];
-    if (const SplitNode* r = rel.as<SplitNode>()) {
-      IntSet parent;
-      PassUpDomain(r, dom_map,
-                   state.at(r->outer), state.at(r->inner),
-                   &parent);
-      state[r->parent] = parent;
-    } else if (const FuseNode* r = rel.as<FuseNode>()) {
-      IntSet outer, inner;
-      PassUpDomain(r, dom_map,
-                   state.at(r->fused),
-                   &outer, &inner);
-      state[r->outer] = outer;
-      state[r->inner] = inner;
-    } else if (const RebaseNode* r = rel.as<RebaseNode>()) {
-      IntSet parent;
-      PassUpDomain(r, dom_map,
-                   state.at(r->rebased),
-                   &parent);
-      state[r->parent] = parent;
-    } else if (rel.as<SingletonNode>()) {
-    } else {
-      LOG(FATAL) << "unknown relation type";
-    }
-  }
-}
-
-// Pass up bit mask with or relation.
-void PassUpBitMaskOr(const Stage& stage,
-                     std::unordered_map<IterVar, int>* p_state,
-                     bool allow_missing) {
-  auto& state = *p_state;
-  for (size_t i = stage->relations.size(); i != 0; --i) {
-    IterVarRelation rel = stage->relations[i - 1];
-    if (const SplitNode* s = rel.as<SplitNode>()) {
-      if (!state.count(s->inner) && !state.count(s->outer)) {
-        CHECK(allow_missing);
-        continue;
-      }
-      int res = 0;
-      if (state.count(s->parent)) res |= state[s->parent];
-      if (state.count(s->inner)) res |= state[s->inner];
-      if (state.count(s->outer)) res |= state[s->outer];
-      state[s->parent] = res;
-    } else if (const FuseNode* s = rel.as<FuseNode>()) {
-      if (!state.count(s->fused)) {
-        CHECK(allow_missing);
-        continue;
-      }
-      if (!state.count(s->outer)) {
-        state[s->outer] = state[s->fused];
-      } else {
-        state[s->outer] |= state[s->fused];
-      }
-      if (!state.count(s->inner)) {
-        state[s->inner] = state[s->fused];
-      } else {
-        state[s->inner] |= state[s->fused];
-      }
-    } else if (const RebaseNode* s = rel.as<RebaseNode>()) {
-      if (!state.count(s->rebased)) {
-        CHECK(allow_missing);
-        continue;
-      }
-      if (!state.count(s->parent)) {
-        state[s->parent] = state[s->rebased];
-      } else {
-        state[s->parent] |= state[s->rebased];
-      }
-    } else if (rel.as<SingletonNode>()) {
-    } else {
-      LOG(FATAL) << "unknown relation type";
-    }
-  }
-}
-
-void PassDownBitMaskOr(const Stage& stage,
-                       std::unordered_map<IterVar, int>* p_state,
-                       bool allow_missing) {
-  auto& state = *p_state;
-  for (IterVarRelation rel : stage->relations) {
-    if (const SplitNode* s = rel.as<SplitNode>()) {
-      if (!state.count(s->parent)) {
-        CHECK(allow_missing);
-        continue;
-      }
-      if (!state.count(s->outer)) {
-        state[s->outer] = state.at(s->parent);
-      } else {
-        state[s->outer] |= state.at(s->parent);
-      }
-      if (!state.count(s->inner)) {
-        state[s->inner] = state.at(s->parent);
-      } else {
-        state[s->inner] |= state.at(s->parent);
-      }
-    } else if (const FuseNode* s = rel.as<FuseNode>()) {
-      if (!state.count(s->outer) && !state.count(s->inner)) {
-        CHECK(allow_missing);
-        continue;
-      }
-      int res = 0;
-      if (state.count(s->outer)) res |= state.at(s->outer);
-      if (state.count(s->inner)) res |= state.at(s->inner);
-      if (state.count(s->fused)) res |= state.at(s->fused);
-      state[s->fused] = res;
-    } else if (const RebaseNode* s = rel.as<RebaseNode>()) {
-      if (!state.count(s->parent)) {
-        CHECK(allow_missing);
-        continue;
-      }
-      if (!state.count(s->rebased)) {
-        state[s->rebased] = state.at(s->parent);
-      } else {
-        state[s->rebased] |= state.at(s->parent);
-      }
-    } else if (const SingletonNode* s = rel.as<SingletonNode>()) {
-      state[s->iter] = 0;
-    } else {
-      LOG(FATAL) << "unknown relation type";
-    }
-  }
-}
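-
-// Sketch of how the two bitmask passes combine (hypothetical use, as in
-// rfactor): seeding touch_map[axis] = 1 on one leaf and running
-// PassUpBitMaskOr followed by PassDownBitMaskOr marks every IterVar related
-// to `axis` through split/fuse/rebase, e.g. a split's parent receives
-// state[parent] |= state[inner] | state[outer].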
-
-
-/*!
- * \brief Message passing to find if boundary checking on an IterVar is needed.
- * \param s The stage to be used.
- * \param dom_map The domain map of each iteration variable's domain.
- * \param p_state The message passing state, mapping IterVar->flag.
- * \param analyzer The arithmetic analyzer used to prove bound equalities.
- */
-void PassUpBoundCheck(const Stage& s,
-                      const Map<IterVar, Range>& dom_map,
-                      std::unordered_map<IterVar, bool>* p_state,
-                      arith::Analyzer* analyzer) {
-  auto& state = *p_state;
-  for (size_t i = s->relations.size(); i != 0; --i) {
-    IterVarRelation rel = s->relations[i - 1];
-    if (const SplitNode* s = rel.as<SplitNode>()) {
-      bool outer = state.at(s->outer);
-      bool inner = state.at(s->inner);
-
-      if (dom_map.count(s->inner) && dom_map.count(s->outer)) {
-        PrimExpr factor = dom_map.at(s->inner)->extent;
-        PrimExpr step = dom_map.at(s->outer)->extent;
-        if (outer || inner) {
-          state[s->parent] = true;
-        } else {
-          if (analyzer->CanProve(dom_map.at(s->parent)->extent == factor * step)) {
-            state[s->parent] = false;
-          } else {
-            state[s->parent] = true;
-          }
-        }
-      } else {
-        state[s->parent] = true;
-      }
-    } else if (const FuseNode* s = rel.as<FuseNode>()) {
-      bool fused = state.at(s->fused);
-      state[s->outer] = fused;
-      state[s->inner] = fused;
-    } else if (const RebaseNode* s = rel.as<RebaseNode>()) {
-      state[s->parent] = state.at(s->rebased);
-    } else if (rel.as<SingletonNode>()) {
-      // nop
-    } else {
-      LOG(FATAL) << "unknown relation type";
-    }
-  }
-}
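-
-// Worked example (hypothetical values): splitting a parent of extent 10 by
-// factor 4 gives an outer extent of 3; 10 == 4 * 3 cannot be proven, so
-// state[parent] becomes true and MakeBoundCheck below emits a predicate of
-// the form (inner + outer * 4 < 10).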
-
-std::vector<PrimExpr> MakeBoundCheck(
-    const Stage& stage,
-    const Map<IterVar, Range>& dom_map,
-    const std::unordered_map<IterVar, PrimExpr>& value_map,
-    bool skip_ivar_domain,
-    const std::unordered_set<IterVar>& skip_iter) {
-  arith::Analyzer analyzer;
-
-  std::unordered_map<IterVar, bool> bound_state;
-  for (IterVar iv : stage->leaf_iter_vars) {
-    bound_state[iv] = false;
-  }
-  PassUpBoundCheck(stage, dom_map, &bound_state, &analyzer);
-
-  std::vector<PrimExpr> preds;
-  std::unordered_map<const VarNode*, IntSet> iset_dmap;
-
-  // setup domain map for set analysis
-  for (const auto& kv : dom_map) {
-    iset_dmap[kv.first->var.get()] = IntSet::range(kv.second);
-  }
-
-  for (const IterVar& iv : stage->all_iter_vars) {
-    if (skip_iter.count(iv) || iv->iter_type == kOpaque) continue;
-    if (bound_state.at(iv)) {
-      Range dom = dom_map.at(iv);
-      PrimExpr value = value_map.at(iv) - dom->min;
-      PrimExpr vmax = EvalSet(value, iset_dmap).max();
-      if (vmax.dtype() != value.dtype() || !analyzer.CanProve(vmax < dom->extent)) {
-        preds.emplace_back(value < dom->extent);
-      }
-    }
-  }
-  for (const IterVar& iv : stage->op->root_iter_vars()) {
-    if (skip_iter.count(iv) || iv->iter_type == kOpaque) continue;
-    Range dom = dom_map.at(iv);
-    CHECK(iv->dom.defined());
-    if (!skip_ivar_domain && !iv->dom.same_as(dom)) {
-      PrimExpr value = value_map.at(iv) - iv->dom->min;
-      IntSet s = EvalSet(value, iset_dmap);
-      PrimExpr vmin = s.min();
-      PrimExpr vmax = s.max();
-      // The range of `value` resides in [vmin, vmax]
-      if (vmin.dtype() != value.dtype() || !analyzer.CanProve(vmin >= 0)) {
-        preds.emplace_back(value >= 0);
-      }
-      if (vmax.dtype() != value.dtype() || !analyzer.CanProve(vmax < iv->dom->extent)) {
-        preds.emplace_back(value < iv->dom->extent);
-      }
-    }
-  }
-  return preds;
-}
-}  // namespace schedule
-}  // namespace tvm
diff --git a/src/schedule/message_passing.h b/src/schedule/message_passing.h
deleted file mode 100644 (file)
index 4b81bcc..0000000
+++ /dev/null
@@ -1,133 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file message_passing.h
- * \brief Common utilities to do message passing
- *  on the schedule hyper graph.
- */
-#ifndef TVM_SCHEDULE_MESSAGE_PASSING_H_
-#define TVM_SCHEDULE_MESSAGE_PASSING_H_
-
-#include <tvm/expr.h>
-#include <tvm/schedule.h>
-#include <tvm/operation.h>
-#include <tvm/arith/analyzer.h>
-#include <unordered_map>
-#include <unordered_set>
-#include <vector>
-
-namespace tvm {
-namespace schedule {
-/*!
- * \brief Downward inference of domain of each IterVar.
- *  The caller sets the range of the root; the function then
- *  propagates it towards the leaves.
- *
- * \param stage The stage to operate on.
- * \param p_state The state of the message passing.
- * \param analyzer Analyzer context, storing information about bounds in p_state.
- * \param allow_missing Whether to allow missing values.
- */
-void PassDownDomain(
-    const Stage& stage,
-    std::unordered_map<IterVar, Range>* p_state,
-    arith::Analyzer* analyzer,
-    bool allow_missing = false);
-
-/*!
- * \brief Upward inference of the index of each IterVar,
- *  given the index assignment of the leaves.
- *
- * \param stage The stage to operate on.
- * \param dom_map The domain map of each iteration variable's domain.
- * \param p_state The index state of each IterVar.
- * \param allow_missing Whether to allow missing values.
- */
-void PassUpIndex(const Stage& stage,
-                 const Map<IterVar, Range>& dom_map,
-                 std::unordered_map<IterVar, PrimExpr>* p_state,
-                 bool allow_missing = false);
-
-/*!
- * \brief Downward inference of the index of each IterVar,
- *  given the index assignment of the roots.
- *
- * \param stage The stage to operate on.
- * \param dom_map The domain map of each iteration variable's domain.
- * \param p_state The index state of each IterVar.
- * \param allow_missing Whether to allow missing values.
- */
-void PassDownIndex(const Stage& stage,
-                   const Map<IterVar, Range>& dom_map,
-                   std::unordered_map<IterVar, PrimExpr>* p_state,
-                   bool allow_missing = false);
-
-/*!
- * \brief Upward inference of the domain set of each IterVar,
- *  given the domain assignment of the leaves.
- *
- * \param stage The stage to operate on.
- * \param dom_map The domain map of each iteration variable's maximum domain.
- * \param p_state The index state of each IterVar.
- */
-void PassUpDomain(const Stage& stage,
-                  const std::unordered_map<IterVar, Range>& dom_map,
-                  std::unordered_map<IterVar, IntSet>* p_state);
-
-/*!
- * \brief Upward message passing of bitmask with or relation.
- * \param stage The stage to operate on.
- * \param p_state The index state of each IterVar.
- * \param allow_missing Whether to allow missing values.
- */
-void PassUpBitMaskOr(const Stage& stage,
-                     std::unordered_map<IterVar, int>* p_state,
-                     bool allow_missing = false);
-
-/*!
- * \brief Downward message passing of bitmask with or relation.
- * \param stage The stage to operate on.
- * \param p_state The index state of each IterVar.
- * \param allow_missing Whether to allow missing values.
- */
-void PassDownBitMaskOr(const Stage& stage,
-                       std::unordered_map<IterVar, int>* p_state,
-                       bool allow_missing = false);
-
-/*!
- * \brief Create boundary check predicates given remapped value of root
- * \param stage The stage we operate on
- * \param dom_map The domain map of each value.
- * \param value_map The value map of the root iter var.
- * \param skip_ivar_domain Whether we skip check for IterVar's original domain.
- * \param skip_iter The set of variables to skip bound condition.
- * \return List of predicates that we need to check.
- */
-std::vector<PrimExpr>
-MakeBoundCheck(
-    const Stage& stage,
-    const Map<IterVar, Range>& dom_map,
-    const std::unordered_map<IterVar, PrimExpr>& value_map,
-    bool skip_ivar_domain,
-    const std::unordered_set<IterVar>& skip_iter);
-
-}  // namespace schedule
-}  // namespace tvm
-#endif  // TVM_SCHEDULE_MESSAGE_PASSING_H_
diff --git a/src/schedule/schedule_dataflow_rewrite.cc b/src/schedule/schedule_dataflow_rewrite.cc
deleted file mode 100644 (file)
index f06cf96..0000000
+++ /dev/null
@@ -1,885 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file schedule_dataflow_rewrite.cc
- */
-#include <tvm/schedule.h>
-#include <tvm/operation.h>
-#include <tvm/ir_functor_ext.h>
-#include <tvm/ir_pass.h>
-#include <unordered_set>
-#include "message_passing.h"
-#include "../pass/ir_util.h"
-#include "../arith/compute_expr.h"
-
-namespace tvm {
-
-// find the first occurrence location in leaf
-template<typename T>
-size_t FindNodeRef(ArrayNode* array_node, const T& v) {
-  const Object* n = v.get();
-  for (size_t i = 0; i < array_node->data.size(); ++i) {
-    if (array_node->data[i].get() == n) return i;
-  }
-  return array_node->data.size();
-}
-
-// The replacer of cache.
-class VarReplacer : public ir::StmtExprMutator {
- public:
-  explicit VarReplacer(
-      const std::unordered_map<const VarNode*, PrimExpr>& vsub)
-      : vsub_(vsub) {}
-  PrimExpr VisitExpr_(const VarNode* op) final {
-    auto it = vsub_.find(op);
-    if (it != vsub_.end()) return it->second;
-    return GetRef<PrimExpr>(op);
-  }
-
-  ir::CommReducer MutateCommReducer(ir::CommReducer combiner) {
-    // Replace free variables in combiner
-    auto new_identity = ir::UpdateArray(combiner->identity_element, [this] (const PrimExpr& e) {
-      return this->VisitExpr(e);
-      });
-    auto new_result = ir::UpdateArray(combiner->result, [this] (const PrimExpr& e) {
-      return this->VisitExpr(e);
-      });
-
-    if (combiner->identity_element.same_as(new_identity) &&
-        combiner->result.same_as(new_result)) {
-      return combiner;
-    } else {
-      return ir::CommReducerNode::make(
-        combiner->lhs, combiner->rhs, new_result, new_identity);
-    }
-  }
-
-  PrimExpr VisitExpr_(const ir::ReduceNode* op) final {
-    PrimExpr new_e = StmtExprMutator::VisitExpr_(op);
-    const ir::ReduceNode* new_reduce = new_e.as<ir::ReduceNode>();
-    ir::CommReducer new_combiner = MutateCommReducer(op->combiner);
-    if (op->combiner.same_as(new_combiner)) {
-      return new_e;
-    } else {
-      return ir::ReduceNode::make(
-        new_combiner,
-        new_reduce->source,
-        new_reduce->axis,
-        new_reduce->condition,
-        new_reduce->value_index);
-    }
-  }
-
- private:
-  const std::unordered_map<const VarNode*, PrimExpr>& vsub_;
-};
-
-PrimExpr InjectPredicate(const Array<PrimExpr>& predicates,
-                     PrimExpr body) {
-  using ir::ReduceNode;
-  using ir::SelectNode;
-  if (predicates.size() == 0) return body;
-  const ReduceNode* reduce = body.as<ReduceNode>();
-  if (reduce) {
-    auto n = make_object<ReduceNode>(*reduce);
-    n->condition = n->condition && arith::ComputeReduce<ir::AndNode>(predicates, PrimExpr());
-    return PrimExpr(n);
-  }
-  return SelectNode::make(arith::ComputeReduce<ir::AndNode>(predicates, PrimExpr()),
-                      body,
-                      make_zero(body.dtype()));
-}
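-
-// Sketch of InjectPredicate (hypothetical predicates): for {i < n, j < m}
-// and a non-reduce body b the result is
-//   Select((i < n) && (j < m), b, make_zero(b.dtype()));
-// for a Reduce body the conjunction is folded into the reduction's
-// condition instead, so no Select is created.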
-
-// Replace the dataflow that appears in all stages, given the tensor change.
-// Also update vmap if subsequent dataflow needs to be replaced.
-// Keep the transitive-closure property of vmap up to date via a reverse map.
-void ReplaceDataFlow(const Array<Stage>& stages,
-                     std::unordered_map<Tensor, Tensor>* vmap,
-                     std::unordered_map<Tensor, Tensor>* rvmap) {
-  for (Stage s : stages) {
-    Operation op = s->op->ReplaceInputs(s->op, *vmap);
-    if (!op.same_as(s->op)) {
-      for (int i = 0; i < op->num_outputs(); ++i) {
-        auto it = rvmap->find(s->op.output(i));
-        if (it != rvmap->end()) {
-          (*vmap)[it->second] = op.output(i);
-        } else {
-          (*vmap)[s->op.output(i)] = op.output(i);
-          (*rvmap)[op.output(i)] = s->op.output(i);
-        }
-      }
-      s->op = op;
-    }
-  }
-}
-
-inline bool ReduceEqual(const ir::ReduceNode* a, const ir::ReduceNode* b) {
-  return (a->combiner.same_as(b->combiner)) &&
-         (a->source.same_as(b->source)) &&
-         (a->axis.same_as(b->axis)) &&
-         (a->condition.same_as(b->condition));
-}
-
-Tensor Schedule::cache_read(const Tensor& tensor,
-                            const std::string& scope,
-                            const Array<Operation>& readers) {
-  (*this)->InvalidateCache();
-  // create identity mapping.
-  std::ostringstream os;
-  os << tensor->op->name;
-  if (tensor->op->num_outputs() != 1) {
-    os << ".v" << tensor->value_index;
-  }
-  os << "." << scope;
-
-  std::unordered_map<Tensor, Tensor> vsub;
-  Stage s = operator[](tensor->op);
-  Tensor sugar_tensor = s->op.output(tensor->value_index);
-  Tensor cache = compute(sugar_tensor->shape, [&sugar_tensor](const Array<Var>& i) {
-      return sugar_tensor(Array<PrimExpr>(i.begin(), i.end()));
-    }, os.str());
-  vsub[sugar_tensor] = cache;
-
-  std::unordered_map<Tensor, Tensor> vmap;
-  std::unordered_map<Tensor, Tensor> rvmap;
-  for (Operation op : readers) {
-    Stage s = operator[](op);
-    Operation repl_op = s->op->ReplaceInputs(s->op, vsub);
-    CHECK(!repl_op.same_as(s->op))
-        << "Cannot find " << tensor
-        << " in the inputs of " << s->op;
-    vmap[s->op.output(0)] = repl_op.output(0);
-    rvmap[repl_op.output(0)] = s->op.output(0);
-    s->op = repl_op;
-  }
-  ReplaceDataFlow((*this)->stages, &vmap, &rvmap);
-  ArrayNode* stages = (*this)->stages.CopyOnWrite();
-  Stage op_stage = operator[](tensor->op);
-  size_t pos = FindNodeRef(stages, op_stage);
-  Stage cache_stage = Stage(cache->op);
-  cache_stage.set_scope(scope);
-  CHECK_LT(pos, stages->data.size());
-  stages->data.insert(stages->data.begin() + pos + 1,
-                      cache_stage);
-  (*this)->stage_map.Set(cache->op, cache_stage);
-  // Update group
-  cache_stage->group = op_stage->group;
-  if (cache_stage->group.defined()) {
-    ++cache_stage->group->num_child_stages;
-  }
-  return cache;
-}
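-
-// Minimal usage sketch (tensors A and B and the "shared" scope are
-// hypothetical):
-//   Schedule s = create_schedule({B->op});          // B reads A
-//   Tensor AA = s.cache_read(A, "shared", {B->op});
-// AA becomes a new stage inserted right after A's stage, and B is rewritten
-// to read AA instead of A.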
-
-template<typename OpType>
-void PrepareAxisMapping(Stage orig_stage,
-                        OpType* op,
-                        std::unordered_set<IterVar>* p_red_axis,
-                        Array<IterVar>* p_new_axis,
-                        std::unordered_map<IterVar, Range>* p_dom_map,
-                        std::unordered_map<const VarNode*, PrimExpr>* p_vsub,
-                        std::unordered_map<const VarNode*, PrimExpr>* p_vsub2newvar,
-                        std::vector<PrimExpr>* p_predicates) {
-  auto& red_axis = *p_red_axis;
-  auto& new_axis = *p_new_axis;
-  auto& dom_map = *p_dom_map;
-  auto& vsub = *p_vsub;
-  auto& vsub2newvar = *p_vsub2newvar;
-  auto& predicates = *p_predicates;
-  arith::Analyzer analyzer;
-
-  for (IterVar iv : op->reduce_axis) {
-    red_axis.insert(iv);
-  }
-  for (IterVar iv : op->axis) {
-    dom_map[iv] = iv->dom;
-    analyzer.Bind(iv->var, iv->dom);
-  }
-  schedule::PassDownDomain(orig_stage, &dom_map, &analyzer, true);
-  {
-    // The source->cache
-    std::unordered_map<IterVar, PrimExpr> value_map;
-    for (IterVar iv : orig_stage->leaf_iter_vars) {
-      if (red_axis.count(iv)) continue;
-      CHECK_EQ(iv->iter_type, kDataPar)
-          << "Can only relayout with in data parallel dimensions";
-      Range dom = dom_map.at(iv);
-      IterVar new_iv = IterVarNode::make(
-          dom, iv->var.copy_with_suffix(".c"), iv->iter_type);
-      new_axis.push_back(new_iv);
-      if (is_one(dom->extent)) {
-        value_map[iv] = dom->min;
-      } else {
-        value_map[iv] = iv->var;
-        vsub2newvar[iv->var.get()] = new_iv->var;
-      }
-    }
-    // skip reduction iteration.
-    std::unordered_set<IterVar> skip_bound_check;
-    for (IterVar iv : op->reduce_axis) {
-      skip_bound_check.insert(iv);
-    }
-    schedule::PassUpIndex(orig_stage, dom_map, &value_map, true);
-    predicates = schedule::MakeBoundCheck(
-        orig_stage, dom_map, value_map, true, skip_bound_check);
-    // The root axis
-    for (IterVar iv : op->axis) {
-      if (value_map.count(iv)) {
-        vsub[iv->var.get()] = value_map.at(iv);
-      }  // to handle tensor axis
-    }
-  }
-}
-
-Array<Tensor> ReplaceOriginalOp(Schedule sch,
-                                Stage orig_stage,
-                                const std::string& scope,
-                                Operation cache_op,
-                                Operation orig_new_op,
-                                size_t tensor_size) {
-  Array<Tensor> cache_tensor_list;
-  for (size_t i = 0; i < tensor_size; i++) {
-    Tensor cache_tensor = cache_op.output(i);
-    cache_tensor_list.push_back(cache_tensor);
-  }
-  // The replace of the dataflow
-  std::unordered_map<Tensor, Tensor> vmap;
-  std::unordered_map<Tensor, Tensor> rvmap;
-  for (size_t i = 0; i < tensor_size; i++) {
-    vmap[orig_stage->op.output(i)] = orig_new_op.output(i);
-    rvmap[orig_new_op.output(i)] = orig_stage->op.output(i);
-  }
-  ReplaceDataFlow(sch->stages, &vmap, &rvmap);
-  // mutate orig stage
-  orig_stage->op = orig_new_op;
-  orig_stage->all_iter_vars = orig_stage->op->root_iter_vars();
-  orig_stage->leaf_iter_vars = orig_stage->all_iter_vars;
-  orig_stage->relations = Array<IterVarRelation>();
-  // create schedule for new cached stage.
-  ArrayNode* stages = sch->stages.CopyOnWrite();
-  size_t pos = FindNodeRef(stages, orig_stage);
-  Stage cache_stage = Stage(cache_op);
-  cache_stage.set_scope(scope);
-  CHECK_LT(pos, stages->data.size());
-  stages->data.insert(stages->data.begin() + pos,
-                      cache_stage);
-  sch->stage_map.Set(cache_op, cache_stage);
-  // Update group
-  cache_stage->group = orig_stage->group;
-  if (cache_stage->group.defined()) {
-    ++cache_stage->group->num_child_stages;
-  }
-  return cache_tensor_list;
-}
-
-
-// Cache write and relayout the data according to loop pattern
-Array<Tensor> CacheWriteWithReLayout(Schedule sch,
-                                     const Array<Tensor>& tensor_array,
-                                     const std::string& scope) {
-  size_t tensor_size = tensor_array.size();
-  sch->InvalidateCache();
-  Tensor tensor = tensor_array[0];
-  Stage orig_stage = sch[tensor->op];
-  const ComputeOpNode* compute = orig_stage->op.as<ComputeOpNode>();
-
-  std::unordered_set<IterVar> red_axis;
-  Array<IterVar> new_axis;
-  std::unordered_map<IterVar, Range> dom_map;
-
-  std::unordered_map<const VarNode*, PrimExpr> vsub;
-  std::unordered_map<const VarNode*, PrimExpr> vsub2newvar;
-  std::vector<PrimExpr> predicates;
-
-  PrepareAxisMapping(orig_stage, compute,
-    &red_axis, &new_axis, &dom_map, &vsub, &vsub2newvar, &predicates);
-
-  PrimExpr body;
-  Array<PrimExpr> body_list;
-  const ir::ReduceNode* first_reduce = nullptr;
-  for (auto cbody : compute->body) {
-    body = VarReplacer(vsub)(cbody);
-    body = InjectPredicate(predicates, body);
-    body = VarReplacer(vsub2newvar)(body);
-    // Reduce nodes in ONE ComputeOp must be the same except for value_index.
-    // This holds only if the original body ensures the Reduce nodes are the same.
-    if (body->IsInstance<ir::ReduceNode>()) {
-      const ir::ReduceNode* reduce_body = body.as<ir::ReduceNode>();
-      if (first_reduce != nullptr) {
-        CHECK(ReduceEqual(reduce_body, first_reduce));
-        body = ir::ReduceNode::make(first_reduce->combiner,
-                                first_reduce->source,
-                                first_reduce->axis,
-                                first_reduce->condition,
-                                reduce_body->value_index);
-      } else {
-        first_reduce = reduce_body;
-      }
-    } else {
-      CHECK(first_reduce == nullptr)
-        << "cannot mix reduce and other node in ONE compute bodys";
-    }
-    body_list.push_back(body);
-  }
-  // The reader args
-  Array<PrimExpr> args;
-  {
-    // cache->compute
-    std::unordered_map<IterVar, PrimExpr> value_map;
-    for (IterVar iv : compute->axis) {
-      value_map[iv] = iv->var;
-    }
-    schedule::PassDownIndex(orig_stage, dom_map, &value_map, true);
-    for (IterVar iv : orig_stage->leaf_iter_vars) {
-      if (red_axis.count(iv)) continue;
-      args.push_back(value_map.at(iv));
-    }
-  }
-  Operation cache_op = ComputeOpNode::make(
-      compute->name + "." + scope, compute->tag, compute->attrs,
-      new_axis, body_list);
-
-  Array<PrimExpr> cache_expr_list;
-  for (size_t i = 0; i < tensor_size; i++) {
-    Tensor cache_tensor = cache_op.output(i);
-    cache_expr_list.push_back(cache_tensor(args));
-  }
-  Operation orig_new_op = ComputeOpNode::make(
-      compute->name, compute->tag, compute->attrs,
-      compute->axis, cache_expr_list);
-  return ReplaceOriginalOp(sch, orig_stage, scope,
-    cache_op, orig_new_op, tensor_size);
-}
-
-
-// for tensor compute op
-Array<Tensor> CacheWriteWithReLayoutTensor(Schedule sch,
-                                           const Array<Tensor>& tensor_array,
-                                           const std::string& scope) {
-  size_t tensor_size = tensor_array.size();
-  sch->InvalidateCache();
-  Tensor tensor = tensor_array[0];
-  Stage orig_stage = sch[tensor->op];
-  const TensorComputeOpNode* tensor_op = orig_stage->op.as<TensorComputeOpNode>();
-  CHECK_EQ(tensor_op->num_outputs(), 1)
-      << "cache write only support single output tensor_compute_op";
-
-  std::unordered_set<IterVar> red_axis;
-  Array<IterVar> new_axis;
-  std::unordered_map<IterVar, Range> dom_map;
-
-  std::unordered_map<const VarNode*, PrimExpr> vsub;
-  std::unordered_map<const VarNode*, PrimExpr> vsub2newvar;
-  std::vector<PrimExpr> predicates;
-
-  PrepareAxisMapping(orig_stage, tensor_op,
-    &red_axis, &new_axis, &dom_map, &vsub, &vsub2newvar, &predicates);
-
-
-  for (int i = tensor_op->schedulable_ndim; i < static_cast<int>(tensor_op->axis.size()); ++i) {
-    IterVar iv = tensor_op->axis[i];
-    IterVar new_iv = IterVarNode::make(
-      iv->dom, iv->var.copy_with_suffix(".c"), iv->iter_type);
-    new_axis.push_back(new_iv);
-  }
-  Array<Region> new_regions;
-  for (Region old_region : tensor_op->input_regions) {
-    Region region;
-    for (Range r : old_region) {
-      PrimExpr min = VarReplacer(vsub2newvar)(r->min);
-      PrimExpr extent = VarReplacer(vsub2newvar)(r->extent);
-      region.push_back(Range::make_by_min_extent(min, extent));
-    }
-    new_regions.push_back(region);
-  }
-
-  Array<PrimExpr> new_scalar_inputs;
-  for (PrimExpr old_input : tensor_op->scalar_inputs) {
-    new_scalar_inputs.push_back(VarReplacer(vsub2newvar)(old_input));
-  }
-
-  Operation cache_op = TensorComputeOpNode::make(
-      tensor_op->name + "." + scope, tensor_op->tag, new_axis,
-      tensor_op->reduce_axis, tensor_op->schedulable_ndim,
-      tensor_op->intrin, tensor_op->inputs, new_regions, new_scalar_inputs);
-
-  // axis will be used in generating compute op
-  Array<IterVar> compute_axis = tensor_op->axis;
-  for (size_t i = tensor_op->schedulable_ndim; i < tensor_op->axis.size(); ++i) {
-    IterVar iv = tensor_op->axis[i];
-    IterVar aiv = IterVarNode::make(iv->dom, iv->var, kDataPar);
-    compute_axis.Set(i, aiv);
-  }
-
-  // The reader args
-  Array<PrimExpr> args;
-  {
-    // cache->compute
-    std::unordered_map<IterVar, PrimExpr> value_map;
-    for (IterVar iv : compute_axis) {
-      value_map[iv] = iv->var;
-    }
-    schedule::PassDownIndex(orig_stage, dom_map, &value_map, true);
-    for (IterVar iv : orig_stage->leaf_iter_vars) {
-      if (red_axis.count(iv)) continue;
-      args.push_back(value_map.at(iv));
-    }
-    // tensorized region axis
-    for (size_t i = tensor_op->schedulable_ndim; i < tensor_op->axis.size(); ++i) {
-      IterVar iv = compute_axis[i];
-      args.push_back(value_map.at(iv));
-    }
-  }
-
-  Array<PrimExpr> cache_expr_list;
-  for (size_t i = 0; i < tensor_size; i++) {
-    Tensor cache_tensor = cache_op.output(i);
-    cache_expr_list.push_back(cache_tensor(args));
-  }
-  Operation orig_new_op = ComputeOpNode::make(
-      tensor_op->name, tensor_op->tag, {},
-      compute_axis, cache_expr_list);
-  return ReplaceOriginalOp(sch, orig_stage, scope,
-    cache_op, orig_new_op, tensor_size);
-}
-
-
-Array<Tensor> Schedule::cache_write(const Array<Tensor>& tensor_array,
-                             const std::string& scope) {
-  (*this)->InvalidateCache();
-  CHECK(tensor_array.size() > 0)
-      << "size of tensor_array must be greater than 0";
-  Tensor tensor = tensor_array[0];
-  Stage orig_stage = operator[](tensor->op);
-  const ComputeOpNode* compute = tensor->op.as<ComputeOpNode>();
-  CHECK(static_cast<size_t>(compute->num_outputs()) == tensor_array.size())
-      << "size of input tensor list must be same as number of stage outputs";
-  for (size_t i = 1; i < tensor_array.size(); i++) {
-    Stage tmp_stage = operator[](tensor_array[i]->op);
-    CHECK(orig_stage.same_as(tmp_stage))
-        << "Input tensor list must be generated by ONE computeOp";
-  }
-  return CacheWriteWithReLayout(*this, tensor_array, scope);
-}
-
-
-Tensor Schedule::cache_write(const Tensor& tensor,
-                             const std::string& scope) {
-  // support original compute and tensor compute both
-  (*this)->InvalidateCache();
-  if (tensor->op.as<ComputeOpNode>()) {
-    return (CacheWriteWithReLayout(*this, {tensor}, scope))[0];
-  } else if (tensor->op.as<TensorComputeOpNode>()) {
-    return (CacheWriteWithReLayoutTensor(*this, {tensor}, scope))[0];
-  } else {
-    LOG(FATAL) << "cache write only take ComputeOp or TensorComputeOp as writers";
-    return Tensor();
-  }
-}
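-
-// Minimal usage sketch (hypothetical tensor B):
-//   Tensor BL = s.cache_write(B, "local");
-// BL computes B's original body in "local" scope, while B's stage is
-// rewritten to copy BL back out; TensorComputeOp writers must have a single
-// output, as checked above.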
-
-
-void RebaseNonZeroMinLoop(const Schedule& sch) {
-  std::unordered_map<IterVar, IterVar> rebase_map;
-  for (Stage s : sch->stages) {
-    if (s->attach_type == kInlinedAlready) continue;
-
-    auto root_iter_vars = s->op->root_iter_vars();
-    ArrayNode* leaf_vars = s->leaf_iter_vars.CopyOnWrite();
-    for (IterVar iv : root_iter_vars) {
-      size_t idx = FindNodeRef(leaf_vars, iv);
-      auto it  = s->iter_var_attrs.find(iv);
-      // don't need to rebase paths that are bound.
-      if (it != s->iter_var_attrs.end() &&
-          (*it).second->bind_thread.defined()) {
-        continue;
-      }
-      if (idx < leaf_vars->data.size()) {
-        // insert rebase
-        IterVar rebased = IterVarNode::make(
-            Range(), iv->var.copy_with_suffix(""), iv->iter_type);
-        s->relations.push_back(RebaseNode::make(iv, rebased));
-        if (s->iter_var_attrs.count(iv)) {
-          s->iter_var_attrs.Set(rebased, s->iter_var_attrs.at(iv));
-        }
-        leaf_vars->data[idx] = rebased;
-        rebase_map[iv] = rebased;
-      }
-    }
-  }
-  // remap the parent relation
-  for (Stage s : sch->stages) {
-    if (s->attach_type != kScope) continue;
-    if (rebase_map.count(s->attach_ivar)) {
-      s->attach_ivar = rebase_map.at(s->attach_ivar);
-    }
-  }
-  for (Stage s : sch->groups) {
-    if (s->attach_type != kScope) continue;
-    if (rebase_map.count(s->attach_ivar)) {
-      s->attach_ivar = rebase_map.at(s->attach_ivar);
-    }
-  }
-}
-
-void InjectInline(ScheduleNode* sch) {
-  sch->InvalidateCache();
-
-  std::vector<Array<PrimExpr> > new_body(sch->stages.size());
-  std::vector<bool> changed(sch->stages.size(), false);
-  std::vector<Stmt> new_hybrid_body(sch->stages.size());
-  std::vector<bool> hybrid_changed(sch->stages.size(), false);
-  // inline all the ops
-  for (size_t i = sch->stages.size(); i != 0; --i) {
-    Stage stage = sch->stages[i - 1];
-    if (stage->attach_type == kInline) {
-      stage->attach_type = kInlinedAlready;
-      Array<Var> args;
-      PrimExpr body;
-      {
-        // setup args
-        const ComputeOpNode* compute = stage->op.as<ComputeOpNode>();
-        CHECK(compute)
-            << "can only inline compute op";
-        for (auto iv : compute->axis) {
-          args.push_back(iv->var);
-        }
-        CHECK_EQ(compute->body.size(), 1U)
-            << "can only inline compute op with 1 output";
-        body = compute->body[0];
-      }
-      for (size_t j = i; j < sch->stages.size(); ++j) {
-        Stage s = sch->stages[j];
-        const ComputeOpNode* compute = s->op.as<ComputeOpNode>();
-        const HybridOpNode* hybrid = s->op.as<HybridOpNode>();
-        if (compute) {
-          if (!new_body[j].size()) {
-            new_body[j] = compute->body;
-          }
-          if (new_body[j][0]->IsInstance<ir::ReduceNode>()) {
-            // specially handle reduction inlining for multiple reductions.
-            const ir::ReduceNode* reduce = new_body[j][0].as<ir::ReduceNode>();
-            for (size_t k = 1; k < new_body[j].size(); ++k) {
-              const ir::ReduceNode* reduce_ = new_body[j][k].as<ir::ReduceNode>();
-              CHECK(reduce_);
-              CHECK(ReduceEqual(reduce_, reduce))
-                  << "The Reduce inputs of ComputeOp should "
-                  << "have the same attribute except value_index";
-            }
-            PrimExpr new_value = ir::Inline(ir::EvaluateNode::make(new_body[j][0]),
-                                        stage->op, args, body).as<ir::EvaluateNode>()->value;
-            if (!new_value.same_as(new_body[j][0])) {
-              changed[j] = true;
-              const ir::ReduceNode* r = new_value.as<ir::ReduceNode>();
-              CHECK(r != nullptr);
-              CHECK_EQ(new_body[j].size(), r->source.size());
-              for (size_t k = 0; k < new_body[j].size(); ++k) {
-                auto n = make_object<ir::ReduceNode>(*r);
-                n->value_index = static_cast<int>(k);
-                n->dtype = r->source[k].dtype();
-                new_body[j].Set(k, PrimExpr(n));
-              }
-            }
-          } else {
-            for (size_t k = 0; k < new_body[j].size(); ++k) {
-              PrimExpr new_value = ir::Inline(ir::EvaluateNode::make(new_body[j][k]),
-                                          stage->op, args, body).as<ir::EvaluateNode>()->value;
-              if (!new_value.same_as(new_body[j][k])) {
-                new_body[j].Set(k, new_value);
-                changed[j] = true;
-              }
-            }
-          }
-        } else if (hybrid) {
-          if (!new_hybrid_body[j].defined()) {
-            new_hybrid_body[j] = hybrid->body;
-          }
-          Stmt new_stmt = ir::Inline(new_hybrid_body[j], stage->op, args, body);
-          if (!new_stmt.same_as(new_hybrid_body[j])) {
-            new_hybrid_body[j] = new_stmt;
-            hybrid_changed[j] = true;
-          }
-        }
-      }
-    }
-  }
-  std::unordered_map<Tensor, Tensor> repl;
-  // rewrite dataflow
-  for (size_t i = 0; i < sch->stages.size(); ++i) {
-    Stage s = sch->stages[i];
-    if (s->attach_type == kInlinedAlready) continue;
-    if (new_body[i].size()) {
-      // Logics from ReplaceDataFlow
-      const ComputeOpNode* compute = sch->stages[i]->op.as<ComputeOpNode>();
-      CHECK(compute);
-      Operation op = s->op;
-      if (changed[i]) {
-        op = ComputeOpNode::make(
-            compute->name, compute->tag, compute->attrs,
-            compute->axis, new_body[i]);
-      }
-      op = op->ReplaceInputs(op, repl);
-      if (!op.same_as(s->op)) {
-        for (int idx = 0; idx < s->op->num_outputs(); ++idx) {
-          repl[s->op.output(idx)] = op.output(idx);
-        }
-        s->op = op;
-      }
-    } else if (hybrid_changed[i]) {
-      const HybridOpNode* hybrid = sch->stages[i]->op.as<HybridOpNode>();
-      CHECK(hybrid);
-      Operation op = HybridOpNode::make(
-              hybrid->name, hybrid->tag, hybrid->attrs, hybrid->inputs,
-              hybrid->outputs, new_hybrid_body[i]);
-      op = op->ReplaceInputs(op, repl);
-      for (int idx = 0; idx < s->op->num_outputs(); ++idx) {
-        repl[s->op.output(idx)] = op.output(idx);
-      }
-      s->op = op;
-    } else {
-      Operation op = s->op->ReplaceInputs(s->op, repl);
-      if (!op.same_as(s->op)) {
-        for (int j = 0; j < op->num_outputs(); ++j) {
-          repl[s->op.output(j)] = op.output(j);
-        }
-        s->op = op;
-      }
-    }
-  }
-}
-
-Schedule Schedule::normalize() {
-  Schedule sn = copy();
-  InjectInline(sn.operator->());
-  RebaseNonZeroMinLoop(sn);
-  return sn;
-}
-
-// Handle reduction factor.
-Array<Tensor> Schedule::rfactor(const Tensor& tensor,
-                                const IterVar& axis,
-                                int factor_axis) {
-  (*this)->InvalidateCache();
-  using ir::ReduceNode;
-  CHECK_EQ(axis->iter_type, kCommReduce)
-      << "Can only factor reduction axis";
-  Stage reduce_stage = operator[](tensor->op);
-  const ComputeOpNode* compute_op = reduce_stage->op.as<ComputeOpNode>();
-  CHECK(compute_op) << "Can only factor ComputeOp";
-  ArrayNode* leaf_vars = reduce_stage->leaf_iter_vars.CopyOnWrite();
-  {
-    size_t axis_pos = FindNodeRef(leaf_vars, axis);
-    CHECK_NE(axis_pos, leaf_vars->data.size())
-        << "Cannot find IterVar " << axis << " in leaf iter vars";
-  }
-  // Find touched reduction axis.
-  std::unordered_map<IterVar, int> touch_map;
-  touch_map[axis] = 1;
-  schedule::PassUpBitMaskOr(reduce_stage, &touch_map, true);
-  schedule::PassDownBitMaskOr(reduce_stage, &touch_map, true);
-  // skip reduction iteration.
-  std::unordered_set<IterVar> skip_bound_check;
-  // Verify normal axis are not touched.
-  for (IterVar iv : compute_op->axis) {
-    CHECK(!touch_map.count(iv))
-        << "Factor axis touches normal axis.";
-    skip_bound_check.insert(iv);
-  }
-  // get analyzer.
-  arith::Analyzer analyzer;
-  // Get the replace index
-  std::unordered_map<IterVar, Range> dom_map;
-  std::unordered_map<IterVar, PrimExpr> value_map;
-  for (IterVar iv : compute_op->reduce_axis) {
-    if (touch_map.count(iv)) {
-      dom_map[iv] = iv->dom;
-    } else {
-      skip_bound_check.insert(iv);
-    }
-    analyzer.Bind(iv->var, iv->dom);
-  }
-  schedule::PassDownDomain(reduce_stage, &dom_map, &analyzer, true);
-  for (IterVar iv : reduce_stage->leaf_iter_vars) {
-    if (touch_map.count(iv)) {
-      Range dom = dom_map.at(iv);
-      if (is_one(dom->extent)) {
-        value_map[iv] = dom->min;
-      } else {
-        value_map[iv] = iv->var;
-      }
-    }
-  }
-  schedule::PassUpIndex(reduce_stage, dom_map, &value_map, true);
-  std::vector<PrimExpr> predicates = schedule::MakeBoundCheck(
-      reduce_stage, dom_map, value_map, true, skip_bound_check);
-
-  // Get the factored op node.
-  const int factor_axis_pos = \
-      factor_axis >= 0 ? factor_axis : static_cast<int>(compute_op->axis.size() + 1) + factor_axis;
-  CHECK_LE(factor_axis_pos, compute_op->axis.size());
-  auto n = make_object<ComputeOpNode>();
-  n->name = compute_op->name + ".rf";
-  {
-    // axis replacement.
-    auto iv_node = make_object<IterVarNode>();
-    iv_node->dom = dom_map.at(axis);
-    CHECK(is_zero(iv_node->dom->min))
-        << "Can only factor reduction domain starting from 0";
-    iv_node->var = axis->var;
-    iv_node->iter_type = kDataPar;
-
-    const int size = compute_op->axis.size();
-    for (int idx = 0; idx < size; ++idx) {
-      if (factor_axis_pos == idx) {
-        n->axis.push_back(IterVar(iv_node));
-      }
-      n->axis.push_back(compute_op->axis[idx]);
-    }
-    if (factor_axis_pos == size) {
-      n->axis.push_back(IterVar(iv_node));
-    }
-  }
-  // predicate generation, copy not touched axis.
-  int idx = tensor->value_index;
-  const ReduceNode* reduce = compute_op->body[idx].as<ReduceNode>();
-  CHECK(reduce) << "Can only rfactor non-inline reductions";
-  predicates.push_back(reduce->condition);
-  PrimExpr predicate = likely(arith::ComputeReduce<ir::AndNode>(predicates, PrimExpr()));
-
-  std::unordered_map<const VarNode*, PrimExpr> vsub;
-
-  for (IterVar iv : compute_op->reduce_axis) {
-    if (!touch_map.count(iv)) {
-      n->reduce_axis.push_back(iv);
-    } else {
-      CHECK(value_map.count(iv));
-      PrimExpr index = value_map.at(iv);
-      vsub[iv->var.get()] = index;
-    }
-  }
-
-  // Copy touched axis.
-  for (IterVar iv : reduce_stage->leaf_iter_vars) {
-    if (touch_map.count(iv) && !iv.same_as(axis)) {
-      CHECK_EQ(iv->iter_type, kCommReduce);
-      auto ncpy = make_object<IterVarNode>(*iv.operator->());
-      ncpy->dom = dom_map.at(iv);
-      n->reduce_axis.push_back(IterVar(ncpy));
-    }
-  }
-  VarReplacer replacer(vsub);
-  Array<PrimExpr> new_source = ir::UpdateArray(reduce->source,
-    [&replacer] (const PrimExpr& e) { return replacer(e); });
-
-  PrimExpr new_pred = replacer(predicate);
-
-  std::vector<PrimExpr> body;
-  for (size_t idx = 0; idx < reduce->source.size(); ++idx) {
-    body.emplace_back(ReduceNode::make(reduce->combiner,
-                                   new_source,
-                                   n->reduce_axis,
-                                   new_pred,
-                                   idx));
-  }
-  n->body = Array<PrimExpr>(body);
-  // refresh relations, keep the un-touched relations.
-  Array<IterVarRelation> rels;
-  for (IterVarRelation rel : reduce_stage->relations) {
-    bool touched = false;
-    if (const SplitNode* r = rel.as<SplitNode>()) {
-      if (touch_map.count(r->parent)) touched = true;
-    } else if (const FuseNode* r = rel.as<FuseNode>()) {
-      if (touch_map.count(r->fused)) touched = true;
-    } else if (const RebaseNode* r = rel.as<RebaseNode>()) {
-      if (touch_map.count(r->parent)) touched = true;
-    } else {
-      LOG(FATAL) << "unknown relation type";
-    }
-    if (!touched) {
-      rels.push_back(rel);
-    }
-  }
-  // initialize the factored stage.
-  Operation factor_op(n);
-  ArrayNode* stages = (*this)->stages.CopyOnWrite();
-  size_t stage_pos = FindNodeRef(stages, reduce_stage);
-  Stage factor_stage = Stage(factor_op);
-  factor_stage->relations = rels;
-  CHECK_LT(stage_pos, stages->data.size());
-  stages->data.insert(stages->data.begin() + stage_pos,
-                      factor_stage);
-  (*this)->stage_map.Set(factor_op, factor_stage);
-  factor_stage->group = reduce_stage->group;
-  if (factor_stage->group.defined()) {
-    ++factor_stage->group->num_child_stages;
-  }
-  // Replace the old reduction.
-  IterVar repl_red_axis = reduce_axis(
-      dom_map.at(axis), axis->var->name_hint + ".v");
-  Array<Tensor> factor_tensors;
-  Array<Tensor> old_tensors;
-  int size = factor_op->num_outputs();
-  for (int idx = 0; idx < size; ++idx) {
-    factor_tensors.push_back(factor_op.output(idx));
-    old_tensors.push_back(reduce_stage->op.output(idx));
-  }
-  Array<Tensor> repl_tensors = compute(old_tensors[0]->shape,
-    [&](const Array<Var>& i) {
-      Array<PrimExpr> indices;
-      const int idx_size = static_cast<int>(i.size());
-      for (int idx = 0; idx < idx_size; ++idx) {
-        if (factor_axis_pos == idx) {
-          indices.push_back(repl_red_axis->var);
-        }
-        indices.push_back(i[idx]);
-      }
-      if (factor_axis_pos == idx_size) {
-          indices.push_back(repl_red_axis->var);
-      }
-      Array<PrimExpr> factor_exprs;
-      for (int idx = 0; idx < size; ++idx) {
-        factor_exprs.push_back(factor_tensors[idx](indices));
-      }
-      Array<PrimExpr> reductions;
-      Array<IterVar> axis = {repl_red_axis};
-      PrimExpr cond = const_true();
-      for (int idx = 0; idx < size; ++idx) {
-        reductions.push_back(ReduceNode::make(reduce->combiner,
-          factor_exprs, axis, cond, idx));
-      }
-      return reductions;
-    }, reduce_stage->op->name + ".repl");
-
-  std::unordered_map<Tensor, Tensor> vmap;
-  std::unordered_map<Tensor, Tensor> rvmap;
-  for (int idx = 0; idx < size; ++idx) {
-    vmap[old_tensors[idx]] = repl_tensors[idx];
-    rvmap[repl_tensors[idx]] = old_tensors[idx];
-  }
-  ReplaceDataFlow((*this)->stages, &vmap, &rvmap);
-  // revamp the reduction stage.
-  reduce_stage->op = repl_tensors[0]->op;
-  reduce_stage->all_iter_vars = repl_tensors[0]->op->root_iter_vars();
-  reduce_stage->leaf_iter_vars = reduce_stage->all_iter_vars;
-  reduce_stage->relations = Array<IterVarRelation>();
-  return factor_tensors;
-}
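-
-// Minimal usage sketch of rfactor (hypothetical sum reduction B over a
-// reduction axis k):
-//   IterVar ko, ki;
-//   s[B->op].split(k, 16, &ko, &ki);
-//   Array<Tensor> BF = s.rfactor(B, ki, 0);
-// BF[0] holds the partial results over ki; B's stage is rewritten to reduce
-// BF[0] over a fresh reduction axis (ki's name with suffix ".v").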
-
-}  // namespace tvm
diff --git a/src/schedule/schedule_lang.cc b/src/schedule/schedule_lang.cc
deleted file mode 100644 (file)
index fe56b66..0000000
+++ /dev/null
@@ -1,854 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file schedule_lang.cc
- */
-#include <tvm/schedule.h>
-#include <tvm/operation.h>
-#include <unordered_set>
-#include "graph.h"
-
-namespace tvm {
-
-namespace {
-
-// find the first occurrence location in leaf
-template<typename T>
-size_t FindNodeRef(ArrayNode* array_node, const T& v) {
-  const Object* n = v.get();
-  for (size_t i = 0; i < array_node->data.size(); ++i) {
-    if (array_node->data[i].get() == n) return i;
-  }
-  return array_node->data.size();
-}
-
-size_t FindLeafVar(ArrayNode* all_vars, ArrayNode* leaf_vars, const IterVar& v) {
-  size_t pos = FindNodeRef(leaf_vars, v);
-  if (pos < leaf_vars->data.size()) return pos;
-
-  if (FindNodeRef(all_vars, v) < all_vars->data.size()) {
-    LOG(FATAL) << "Operate on iter var " << v
-               << "that has already been split";
-  } else {
-    LOG(FATAL) << "Operate on iter var " << v
-               << "that is not part of the schedule";
-  }
-  return 0;
-}
-
-void Split(StageNode* self,
-           IterVar parent,
-           PrimExpr factor,
-           PrimExpr nparts,
-           IterVar* p_outer,
-           IterVar* p_inner) {
-  // Check if split is valid.
-  CHECK(parent->iter_type == kDataPar ||
-        parent->iter_type == kCommReduce ||
-        parent->iter_type == kOrdered)
-      << "Cannot split on " << IterVarType2String(parent->iter_type);
-  IterVar outer = IterVarNode::make(
-      Range(), parent->var.copy_with_suffix(".outer"), parent->iter_type);
-  IterVar inner = IterVarNode::make(
-      Range(), parent->var.copy_with_suffix(".inner"), parent->iter_type);
-  *p_outer = outer;
-  *p_inner = inner;
-  // The splits
-  ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
-  ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
-  size_t pos = FindLeafVar(all_vars, leaf_vars, parent);
-  self->relations.push_back(SplitNode::make(parent, outer, inner, factor, nparts));
-  // add vars to all vars
-  all_vars->data.push_back(outer);
-  all_vars->data.push_back(inner);
-  // replace the position.
-  leaf_vars->data.erase(leaf_vars->data.begin() + pos);
-  leaf_vars->data.insert(leaf_vars->data.begin() + pos, inner);
-  leaf_vars->data.insert(leaf_vars->data.begin() + pos, outer);
-}
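-
-// After Split, the parent leaf is replaced in place (hypothetical names):
-// a leaf list [i, j] split on i becomes [i.outer, i.inner, j], and both new
-// IterVars are also appended to all_iter_vars.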
-
-}  // namespace
-
-Stage::Stage(Operation op) {
-  auto n = make_object<StageNode>();
-  n->op = op;
-  n->origin_op = op;
-  n->all_iter_vars = op->root_iter_vars();
-  // remove opaque var from leaf.
-  Array<IterVar> clean;
-  for (IterVar iv : n->all_iter_vars) {
-    if (iv->iter_type != kOpaque) clean.push_back(iv);
-  }
-  if (clean.size() == n->all_iter_vars.size()) {
-    n->leaf_iter_vars = n->all_iter_vars;
-  } else {
-    n->leaf_iter_vars = clean;
-  }
-  data_ = std::move(n);
-}
-
-bool Stage::is_scheduled() const {
-  const StageNode* n = operator->();
-  return !(n->relations.empty() && n->attach_type == kGroupRoot &&
-           n->all_iter_vars.same_as(n->leaf_iter_vars));
-}
-
-Stage Stage::GetAttachSpec() const {
-  Stage attach_spec = *this;
-  while (attach_spec->attach_type == kGroupRoot &&
-         attach_spec->group.defined()) {
-    attach_spec = attach_spec->group;
-  }
-  return attach_spec;
-}
-
-Stage& Stage::set_scope(std::string scope) {  // NOLINT(*)
-  (*this)->scope = scope;
-  return *this;
-}
-
-Stage& Stage::compute_at(Stage parent, IterVar scope) {   // NOLINT(*)
-  CHECK_NE((*this)->attach_type, kScanUpdate)
-      << "Cannot specify compute_at for scan updates";
-  // Group constraint checking.
-  Stage group = (*this)->group;
-  if (group.defined()) {
-    Stage pg = parent->group;
-    while (pg.defined() && !pg.same_as(group)) {
-      pg = pg->group;
-    }
-    CHECK(pg.same_as(group))
-        << "Can only assign compute_at to stages within the same group";
-  }
-
-  (*this)->attach_type = kScope;
-  (*this)->attach_ivar = scope;
-  (*this)->attach_stage = parent;
-  bool found = false;
-  for (size_t i = 0; i < parent->leaf_iter_vars.size(); ++i) {
-    if (scope == parent->leaf_iter_vars[i]) {
-      found = true; break;
-    }
-  }
-  CHECK(found)
-      << "Cannot find the axis " << scope
-      << " in parent's leaf_iter_vars"
-      << " parent=" << parent;
-  return *this;
-}
-
-Stage& Stage::compute_inline() {   // NOLINT(*)
-  CHECK_NE((*this)->attach_type, kScanUpdate)
-      << "Cannot specify compute_at for scan updates";
-  (*this)->attach_type = kInline;
-  return *this;
-}
-
-Stage& Stage::compute_root() {   // NOLINT(*)
-  CHECK_NE((*this)->attach_type, kScanUpdate)
-      << "Cannot specify compute_at for scan updates";
-  (*this)->attach_type = kGroupRoot;
-  return *this;
-}
-
-Stage& Stage::bind(IterVar ivar, IterVar thread_ivar) {   // NOLINT(*)
-  StageNode* self = operator->();
-  CHECK(ivar->iter_type == kDataPar ||
-        ivar->iter_type == kCommReduce)
-      << "Cannot bind " << IterVarType2String(ivar->iter_type) << " to thread";
-  CHECK(thread_ivar->iter_type == kThreadIndex)
-      << "Cannot rebase by " << IterVarType2String(ivar->iter_type)
-      << ", only thread axis is allowed so far";
-  ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
-  ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
-  FindLeafVar(all_vars, leaf_vars, ivar);
-
-  auto it = self->iter_var_attrs.find(ivar);
-  ObjectPtr<IterVarAttrNode> n;
-  if (it != self->iter_var_attrs.end()) {
-    n = make_object<IterVarAttrNode>(*(*it).second.operator->());
-    if (n->bind_thread.defined() &&
-        !n->bind_thread.same_as(thread_ivar)) {
-      LOG(WARNING) << "Axis " << ivar
-                   << " is already bind to another thread " << n->bind_thread;
-    }
-  } else {
-    n = make_object<IterVarAttrNode>();
-  }
-  n->bind_thread = thread_ivar;
-  self->iter_var_attrs.Set(ivar, IterVarAttr(n));
-  return *this;
-}
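-
-// Usage sketch (illustrative; io and ii are hypothetical axes, e.g. produced
-// by a prior split, and thread_axis is the helper from the public API used
-// again further below in opengl()):
-//   stage.bind(io, thread_axis(Range(nullptr), "blockIdx.x"));
-//   stage.bind(ii, thread_axis(Range(nullptr), "threadIdx.x"));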
-
-Stage& Stage::env_threads(Array<IterVar> threads) {
-  StageNode* self = operator->();
-  CHECK(self->op.defined() && self->op.as<ScanOpNode>())
-      << "env_threads is only valid for composite ops such as ScanOp";
-  CHECK_EQ(self->env_threads.size(), 0U)
-      << "Already set env_threads";
-  ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
-  ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
-  std::vector<ObjectRef> temp;
-  for (IterVar iv : threads) {
-    temp.push_back(iv);
-  }
-  leaf_vars->data.insert(
-      leaf_vars->data.begin(), temp.begin(), temp.end());
-  all_vars->data.insert(
-      all_vars->data.end(), temp.begin(), temp.end());
-  self->env_threads = threads;
-  return *this;
-}
-
-Stage& Stage::set_store_predicate(PrimExpr predicate) {
-  StageNode* self = operator->();
-  self->store_predicate = predicate;
-  return *this;
-}
-
-Stage& Stage::split(
-    IterVar parent, PrimExpr factor, IterVar* p_outer, IterVar* p_inner) {  // NOLINT(*)
-  Split(operator->(), parent, factor, PrimExpr(), p_outer, p_inner);
-  return *this;
-}
-
-Stage& Stage::split_by_nparts(
-    IterVar parent, PrimExpr nparts, IterVar* p_outer, IterVar* p_inner) { // NOLINT(*)
-  Split(operator->(), parent, PrimExpr(), nparts, p_outer, p_inner);
-  return *this;
-}
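-
-// Usage sketch (illustrative; assumes a stage with a data-parallel axis i):
-// splitting by a factor rewrites the leaf order in place, exactly as the
-// Split() helper above does:
-//   IterVar io, ii;
-//   stage.split(i, 8, &io, &ii);
-//   // leaf_iter_vars: [..., i, ...] becomes [..., i.outer, i.inner, ...]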
-
-Stage& Stage::fuse(IterVar outer, IterVar inner, IterVar* p_target) {  // NOLINT(*)
-  StageNode* self = operator->();
-  CHECK(outer->iter_type == kDataPar ||
-        outer->iter_type == kCommReduce ||
-        outer->iter_type == kOrdered)
-      << "Cannot fuse " << IterVarType2String(outer->iter_type);
-  CHECK(inner->iter_type == kDataPar ||
-        inner->iter_type == kCommReduce ||
-        inner->iter_type == kOrdered)
-      << "Cannot fuse " << IterVarType2String(inner->iter_type);
-
-  IterVarType iter_type = outer->iter_type;
-  if (inner->iter_type > iter_type) iter_type = inner->iter_type;
-  std::string fused_name =
-      outer->var->name_hint + "." + inner->var->name_hint + ".fused";
-
-  IterVar fused = IterVarNode::make(
-      Range(), Var(fused_name, outer->var.dtype()), iter_type);
-
-  ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
-  ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
-
-  size_t pos_inner = FindLeafVar(all_vars, leaf_vars, inner);
-  size_t pos_outer = FindLeafVar(all_vars, leaf_vars, outer);
-  if (pos_inner + 1 == pos_outer) {
-    std::swap(outer, inner);
-    std::swap(pos_inner, pos_outer);
-  }
-  self->relations.push_back(FuseNode::make(outer, inner, fused));
-  all_vars->data.push_back(fused);
-  CHECK_EQ(pos_inner, pos_outer + 1)
-      << "Can only fuse iterations that are consecutive between each other";
-  leaf_vars->data.erase(leaf_vars->data.begin() + pos_outer,
-                        leaf_vars->data.begin() + pos_inner + 1);
-  leaf_vars->data.insert(leaf_vars->data.begin() + pos_outer,
-                         fused);
-  *p_target = fused;
-  return *this;
-}
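-
-// Usage sketch (illustrative; io and ii are hypothetical consecutive leaf
-// axes): fusing collapses them into one iteration variable whose name
-// records both parents:
-//   IterVar fused;
-//   stage.fuse(io, ii, &fused);
-//   // leaf order: [..., i.outer, i.inner, ...] -> [..., i.outer.i.inner.fused, ...]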
-
-Stage& Stage::fuse(const Array<IterVar>& axes, IterVar* p_target) {  // NOLINT(*)
-  if (axes.size() != 0) {
-    IterVar fused = axes[0];
-    for (size_t i = 1; i < axes.size(); ++i) {
-      this->fuse(fused, axes[i], &fused);
-    }
-    *p_target = std::move(fused);
-  } else {
-    StageNode* self = operator->();
-    // Special handling for fusing an empty array:
-    // insert a singleton axis at the outermost loop.
-    IterVar singleton = IterVarNode::make(
-        Range::make_by_min_extent(0, 1),
-        Var("singleton", DataType::Int(32)), kDataPar);
-    self->relations.push_back(SingletonNode::make(singleton));
-    ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
-    ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
-    all_vars->data.push_back(singleton);
-    leaf_vars->data.insert(leaf_vars->data.begin(), singleton);
-    *p_target = singleton;
-  }
-  return *this;
-}
-
-Stage& Stage::reorder(const Array<IterVar>& order) {  // NOLINT(*)
-  std::unordered_set<IterVar> seen_var;
-  StageNode* self = operator->();
-  for (IterVar iv : order) {
-    CHECK(iv->iter_type == kDataPar ||
-          iv->iter_type == kCommReduce ||
-          iv->iter_type == kThreadIndex)
-        << "Cannot reorder IterVar("
-        << IterVarType2String(iv->iter_type) << ")";
-
-    CHECK_EQ(seen_var.count(iv), 0)
-        << "Same axis can not appear more than once " << iv;
-    seen_var.insert(iv);
-  }
-  ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
-  ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
-  std::vector<size_t> pos;
-
-  for (size_t i = 0; i < order.size(); ++i) {
-    pos.push_back(FindLeafVar(all_vars, leaf_vars, order[i]));
-  }
-  std::vector<ObjectRef> temp;
-  for (size_t i = 0; i < pos.size(); ++i) {
-    temp.emplace_back(leaf_vars->data[pos[i]]);
-  }
-  std::sort(pos.begin(), pos.end());
-  for (size_t i = 0; i < pos.size(); ++i) {
-    leaf_vars->data[pos[i]] = temp[i];
-  }
-  return *this;
-}
-
-Stage& Stage::tile(IterVar x_parent, IterVar y_parent,
-                   PrimExpr x_factor, PrimExpr y_factor,
-                   IterVar* p_x_outer, IterVar* p_y_outer,
-                   IterVar* p_x_inner, IterVar* p_y_inner) {
-  split(x_parent, x_factor, p_x_outer, p_x_inner);
-  split(y_parent, y_factor, p_y_outer, p_y_inner);
-  reorder(Array<IterVar>({*p_x_outer, *p_y_outer, *p_x_inner, *p_y_inner}));
-  return *this;
-}
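-
-// Usage sketch (illustrative; x, y are hypothetical data-parallel axes):
-// tile is exactly the two splits plus the reorder above:
-//   IterVar xo, yo, xi, yi;
-//   stage.tile(x, y, 8, 8, &xo, &yo, &xi, &yi);
-//   // leaf order becomes [x.outer, y.outer, x.inner, y.inner]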
-
-template<typename FUpdate>
-inline void UpdateIterVarAttr(StageNode* self,
-                              IterVar var,
-                              FUpdate fupdate,
-                              bool need_leaf = true) {
-  if (need_leaf) {
-    ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
-    ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
-    FindLeafVar(all_vars, leaf_vars, var);
-  }
-  auto it = self->iter_var_attrs.find(var);
-  ObjectPtr<IterVarAttrNode> n;
-  if (it != self->iter_var_attrs.end()) {
-    n = make_object<IterVarAttrNode>(*(*it).second.operator->());
-  } else {
-    n = make_object<IterVarAttrNode>();
-  }
-  fupdate(n.get());
-  self->iter_var_attrs.Set(var, IterVarAttr(n));
-}
-
-inline void SetAttrIterType(StageNode* self, IterVar var, IterVarType iter_type) {
-  UpdateIterVarAttr(self, var, [iter_type](IterVarAttrNode* n) {
-      n->iter_type = iter_type;
-    });
-}
-
-Stage& Stage::vectorize(IterVar var) {   // NOLINT(*)
-  CHECK(var->iter_type == kDataPar ||
-        var->iter_type == kOpaque ||
-        var->iter_type == kUnrolled ||
-        var->iter_type == kVectorized ||
-        var->iter_type == kTensorized ||
-        var->iter_type == kParallelized)
-      << "Cannot vectorize on " << IterVarType2String(var->iter_type);
-  SetAttrIterType(operator->(), var, kVectorized);
-  return *this;
-}
-
-Stage& Stage::tensorize(IterVar var, TensorIntrin f) {   // NOLINT(*)
-  UpdateIterVarAttr(operator->(), var, [f](IterVarAttrNode* n) {
-      n->iter_type = kTensorized;
-      n->tensor_intrin = f;
-    });
-  return *this;
-}
-
-Stage& Stage::unroll(IterVar var) {   // NOLINT(*)
-  SetAttrIterType(operator->(), var, kUnrolled);
-  return *this;
-}
-
-Stage& Stage::parallel(IterVar var) {   // NOLINT(*)
-  SetAttrIterType(operator->(), var, kParallelized);
-  return *this;
-}
-
-Stage& Stage::pragma(IterVar var,
-                     const std::string& pragma_type,
-                     const PrimExpr& pragma_value) {   // NOLINT(*)
-  if (pragma_type == "unroll") {
-    this->unroll(var);
-  } else if (pragma_type == "vectorize") {
-    this->vectorize(var);
-  } else {
-    UpdateIterVarAttr(
-        operator->(), var, [pragma_type, pragma_value](IterVarAttrNode* n) {
-          n->pragma_keys.push_back(ir::StringImmNode::make(pragma_type));
-          n->pragma_values.push_back(pragma_value);
-        });
-  }
-  return *this;
-}
-
-Stage& Stage::prefetch(const Tensor &tensor, IterVar var, PrimExpr offset) {
-  StageNode *self = operator->();
-  ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
-  ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
-  FindLeafVar(all_vars, leaf_vars, var);
-  auto it = self->iter_var_attrs.find(var);
-  ObjectPtr<IterVarAttrNode> n;
-  if (it != self->iter_var_attrs.end()) {
-    n = make_object<IterVarAttrNode>(*(*it).second.operator->());
-  } else {
-    n = make_object<IterVarAttrNode>();
-  }
-  n->prefetch_data.push_back(tensor);
-  n->prefetch_offset.push_back(offset);
-  self->iter_var_attrs.Set(var, IterVarAttr(n));
-  return *this;
-}
-
-Stage& Stage::storage_align(IterVar axis, int factor, int offset) {
-  StageNode *self = operator->();
-  UpdateIterVarAttr(self, axis, [factor, offset](IterVarAttrNode* n) {
-      n->dim_align_factor = factor;
-      n->dim_align_offset = offset;
-    }, false);
-  return *this;
-}
-
-Stage& Stage::double_buffer() {
-  StageNode *self = operator->();
-  CHECK(!self->is_output) << "Cannot apply double buffer on output";
-  self->double_buffer = true;
-  return *this;
-}
-
-Stage& Stage::opengl() {
-  CHECK(!is_scheduled()) << "Must be a fresh schedule";
-  StageNode *self = operator->();
-
-  auto all_iter_vars = self->all_iter_vars;  // current snapshot of all_iter_vars
-  CHECK(!all_iter_vars.empty()) << "Expected at least one iter var";
-
-  // Fuse all data-parallel dimensions into one.
-  IterVar fused = all_iter_vars[0];
-  for (size_t i = 1; i != all_iter_vars.size(); ++i) {
-    auto iter_var = all_iter_vars[i];
-    switch (iter_var->iter_type) {
-      case IterVarType::kDataPar: {
-        fuse(fused, all_iter_vars[i], &fused);
-        break;
-      }
-      case IterVarType::kThreadIndex: {
-        LOG(ERROR) << "A fresh schedule shouldn't have thread index iter var";
-        break;
-      }
-      case IterVarType::kCommReduce:
-      case IterVarType::kOrdered:
-      case IterVarType::kOpaque: {
-        break;
-      }
-      default: {
-        LOG(ERROR) << "Invalid iter var type "
-                   << IterVarType2String(iter_var->iter_type);
-        break;
-      }
-    }
-  }
-
-  // Bind the only dimension to threadIdx.x.
-  bind(fused, thread_axis(Range(nullptr), "threadIdx.x"));
-
-  // Mark this stage as OpenGL.
-  (*this)->is_opengl = true;
-
-  return *this;
-}
-
-Stage CopyStage(const Stage& s) {
-  ObjectPtr<StageNode> n =
-      make_object<StageNode>(*s.operator->());
-  return Stage(n);
-}
-
-Schedule Schedule::copy() const {
-  // map of stages.
-  const ScheduleNode* self = operator->();
-  std::unordered_map<Stage, Stage, ObjectHash, ObjectEqual> smap;
-  ObjectPtr<ScheduleNode> n = make_object<ScheduleNode>();
-  n->outputs = self->outputs;
-  // Copy the stages.
-  for (Stage s : self->stages) {
-    Stage scopy = CopyStage(s);
-    smap[s] = scopy;
-    n->stages.push_back(scopy);
-  }
-  for (Stage g : self->groups) {
-    Stage gcopy = CopyStage(g);
-    smap[g] = gcopy;
-    n->groups.push_back(gcopy);
-  }
-  // Remaps the reference relations.
-  for (auto kv : self->stage_map) {
-    n->stage_map.Set(kv.first, smap.at(kv.second));
-  }
-  for (Stage s : n->stages) {
-    if (s->attach_stage.defined()) {
-      CHECK(smap.find(s->attach_stage) != smap.end())
-        << s->attach_stage << " not found in " << (*this);
-      s->attach_stage = smap.at(s->attach_stage);
-    }
-    if (s->group.defined()) {
-      CHECK(smap.find(s->group) != smap.end())
-        << s->group << " not found in " << (*this);
-      s->group = smap.at(s->group);
-    }
-  }
-  for (Stage s : n->groups) {
-    if (s->attach_stage.defined()) {
-      CHECK(smap.find(s->attach_stage) != smap.end())
-        << s->attach_stage << " not found in " << (*this);
-      s->attach_stage = smap.at(s->attach_stage);
-    }
-    if (s->group.defined()) {
-      CHECK(smap.find(s->group) != smap.end())
-        << s->group << " not found in " << (*this);
-      s->group = smap.at(s->group);
-    }
-  }
-  return Schedule(n);
-}
-
-Stage Schedule::operator[](const Operation& op) {
-  auto it = (*this)->stage_map.find(op);
-  CHECK(it != (*this)->stage_map.end())
-      << "Cannot find Stage for operator " << op
-      << " in the schedule";
-  return (*it).second;
-}
-
-Stage LeastCommonAncestor(Stage g1, Stage g2) {
-  if (!g1.defined()) return g1;
-  if (!g2.defined()) return g2;
-  if (g1.same_as(g2)) return g1;
-  Stage g = g1;
-  while (g.defined()) {
-    if (g.same_as(g2)) return g2;
-    g = g->group;
-  }
-  g = g2;
-  while (g.defined()) {
-    if (g.same_as(g1)) return g1;
-    g = g->group;
-  }
-  return g;
-}
-
-Array<Tensor> RemapTensor(ScheduleNode* self,
-                          const Array<Tensor>& arr) {
-  self->InitCache();
-  const auto& op2stage_cache = self->op2stage_cache_;
-  Array<Tensor> ret;
-  for (Tensor t : arr) {
-    if (!op2stage_cache.count(t->op.get())) {
-      CHECK(self->stage_map.count(t->op))
-          << "Given tensor is not in the schedule plan";
-      t = self->stage_map[t->op]->op.output(t->value_index);
-    }
-    ret.push_back(t);
-  }
-  return ret;
-}
-
-// Group the schedule stages.
-Stage Schedule::create_group(const Array<Tensor>& outputs,
-                             const Array<Tensor>& inputs,
-                             bool include_inputs) {
-  ScheduleNode* self = operator->();
-  self->InitCache();
-  const auto& op2stage_cache = self->op2stage_cache_;
-  // Get the ops.
-  Array<Operation> ops = schedule::GetSubGraph(
-      RemapTensor(self, outputs),
-      RemapTensor(self, inputs),
-      include_inputs);
-  // Local counter entry, automatically initialized to 0 on creation.
-  struct Entry {
-    int count{0};
-  };
-  // Map of group->touched counter
-  std::unordered_map<Stage, Entry, ObjectHash, ObjectEqual> counter;
-  // The parent group.
-  Stage parent_group;
-  // Detect common parent and child.
-  for (size_t i = 0; i < ops.size(); ++i) {
-    Operation op = ops[i];
-    auto it = op2stage_cache.find(op.get());
-    CHECK(it != op2stage_cache.end());
-    Stage op_group = it->second->group;
-    if (i == 0) {
-      parent_group = op_group;
-    } else {
-      parent_group = LeastCommonAncestor(parent_group, op_group);
-    }
-    if (op_group.defined()) {
-      ++counter[op_group].count;
-    }
-  }
-  // Create the new group stage.
-  Stage gstage(make_object<StageNode>());
-  gstage->group = parent_group;
-  if (parent_group.defined()) {
-    ++parent_group->num_child_stages;
-  }
-  // Propagate the counter statistics upward by checking
-  // whether each subgroup is full.
-  std::vector<Stage> stack;
-  for (auto &kv : counter) {
-    if (!kv.first.same_as(parent_group)) {
-      if (kv.first->num_child_stages == kv.second.count) {
-        stack.push_back(kv.first);
-      }
-    }
-  }
-  while (!stack.empty()) {
-    Stage g = stack.back();
-    stack.pop_back();
-    if (g->group.defined() && !g->group.same_as(parent_group)) {
-      Entry& e = counter[g->group];
-      ++e.count;
-      if (e.count == g->group->num_child_stages) {
-        stack.push_back(g->group);
-      }
-    }
-  }
-  // Verify and remap the subgroups.
-  for (auto &kv : counter) {
-    if (kv.first.same_as(parent_group)) continue;
-    CHECK_EQ(kv.first->num_child_stages, kv.second.count)
-        << "Trying to group region that intersect with an already existed group";
-    if (kv.first->group.same_as(parent_group)) {
-      Stage s = kv.first;
-      s->group = gstage;
-      ++gstage->num_child_stages;
-      if (parent_group.defined()) {
-        --parent_group->num_child_stages;
-      }
-    }
-  }
-  // Remap the group of op stages.
-  for (Operation op : ops) {
-    auto it = op2stage_cache.find(op.get());
-    CHECK(it != op2stage_cache.end());
-    Stage s = it->second;
-    if (s->group.same_as(parent_group)) {
-      s->group = gstage;
-      ++gstage->num_child_stages;
-      if (parent_group.defined()) {
-        --parent_group->num_child_stages;
-      }
-    }
-  }
-  // Correct the attach to keep everything in group.
-  for (Operation op : ops) {
-    auto it = op2stage_cache.find(op.get());
-    CHECK(it != op2stage_cache.end());
-    Stage s = it->second;
-    if (s->attach_type == kScope) {
-      Stage cg = LeastCommonAncestor(s->attach_stage->group, gstage);
-      if (!cg.same_as(gstage)) {
-        LOG(WARNING) << "group invalidates some previous compute_at relation "
-                     << " and keeps things to be computed inside the group";
-        s.compute_root();
-      }
-    }
-  }
-
-  self->groups.push_back(gstage);
-  return gstage;
-}
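-
-// Usage sketch (illustrative; tensor and axis names are hypothetical):
-// grouping the stages between some inputs and outputs so they can be
-// attached as a single unit:
-//   Stage g = sch.create_group({mid_tensor}, {input_tensor}, /*include_inputs=*/false);
-//   g.compute_at(sch[out_tensor], axis);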
-
-void ScheduleNode::InvalidateCache() {
-  op2stage_cache_.clear();
-}
-
-void ScheduleNode::InitCache() {
-  if (op2stage_cache_.size() == stages.size()) return;
-  InvalidateCache();
-  for (Stage s : stages) {
-    if (s->op.defined()) {
-      op2stage_cache_[s->op.get()] = s;
-    }
-  }
-  CHECK_EQ(op2stage_cache_.size(), stages.size());
-}
-
-bool ScheduleNode::Contain(const Operation& op) const {
-  return stage_map.find(op) != stage_map.end();
-}
-
-Schedule ScheduleNode::make(Array<Operation> ops) {
-  auto n = make_object<ScheduleNode>();
-  Schedule sch(n);
-  n->outputs = ops;
-  auto g = schedule::CreateReadGraph(n->outputs);
-  Array<Operation> post_order = schedule::PostDFSOrder(n->outputs, g);
-  // output set.
-  std::unordered_set<Operation> output_set;
-  for (Operation x : ops) {
-    output_set.insert(x);
-  }
-  for (Operation op : post_order) {
-    Stage stage(op);
-    stage->is_output = output_set.count(op) != 0;
-    n->stages.push_back(stage);
-    n->stage_map.Set(op, stage);
-    // mark scan updates.
-    if (const ScanOpNode* scan = op.as<ScanOpNode>()) {
-      Array<Tensor> inputs;
-      for (Tensor t : scan->state_placeholder) {
-        inputs.push_back(t);
-      }
-      for (Tensor t : scan->inputs) {
-        inputs.push_back(t);
-      }
-      // Create the scan group.
-      Stage scan_group = sch.create_group(scan->update, inputs, false);
-      scan_group->attach_type = kScanUpdate;
-      scan_group->attach_stage = stage;
-
-      for (size_t i = 0; i < scan->update.size(); ++i) {
-        Stage s = n->stage_map[scan->update[i]->op];
-        CHECK(scan_group.same_as(s->group));
-      }
-    }
-  }
-  return sch;
-}
-
-IterVarRelation SplitNode::make(IterVar parent,
-                                IterVar outer,
-                                IterVar inner,
-                                PrimExpr factor,
-                                PrimExpr nparts) {
-  auto n = make_object<SplitNode>();
-  n->parent = parent;
-  n->outer = outer;
-  n->inner = inner;
-  n->factor = factor;
-  n->nparts = nparts;
-  return IterVarRelation(n);
-}
-
-IterVarRelation FuseNode::make(
-    IterVar outer, IterVar inner, IterVar fused) {
-  auto n = make_object<FuseNode>();
-  n->outer = outer;
-  n->inner = inner;
-  n->fused = fused;
-  return IterVarRelation(n);
-}
-
-IterVarRelation RebaseNode::make(IterVar parent, IterVar rebased) {
-  auto n = make_object<RebaseNode>();
-  n->parent = parent;
-  n->rebased = rebased;
-  return IterVarRelation(n);
-}
-
-IterVarRelation SingletonNode::make(IterVar iter) {
-  auto n = make_object<SingletonNode>();
-  n->iter = iter;
-  return IterVarRelation(n);
-}
-
-TVM_REGISTER_NODE_TYPE(StageNode);
-TVM_REGISTER_NODE_TYPE(IterVarAttrNode);
-TVM_REGISTER_NODE_TYPE(SplitNode);
-TVM_REGISTER_NODE_TYPE(FuseNode);
-TVM_REGISTER_NODE_TYPE(RebaseNode);
-TVM_REGISTER_NODE_TYPE(SingletonNode);
-TVM_REGISTER_NODE_TYPE(ScheduleNode);
-
-// Printer
-TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<StageNode>([](const ObjectRef& node, NodePrinter* p) {
-    auto* op = static_cast<const StageNode*>(node.get());
-    if (op->op.defined()) {
-      p->stream << "stage(" << op->origin_op->name << ", " << op << ")";
-    } else {
-      p->stream << "group-stage(" << op << ")";
-    }
-})
-.set_dispatch<IterVarAttrNode>([](const ObjectRef& node, NodePrinter* p) {
-    auto* op = static_cast<const IterVarAttrNode*>(node.get());
-    p->stream << IterVarType2String(op->iter_type);
-})
-.set_dispatch<SplitNode>([](const ObjectRef& node, NodePrinter* p) {
-    auto* op = static_cast<const SplitNode*>(node.get());
-    p->stream << "split(parent=";
-    p->Print(op->parent);
-    p->stream << ", outer=";
-    p->Print(op->outer);
-    p->stream << ", inner=";
-    p->Print(op->inner);
-    p->stream << ')';
-})
-.set_dispatch<FuseNode>([](const ObjectRef& node, NodePrinter* p) {
-    auto* op = static_cast<const FuseNode*>(node.get());
-    p->stream << "split(";
-    p->stream << "outer=";
-    p->Print(op->outer);
-    p->stream << ", inner=";
-    p->Print(op->inner);
-    p->stream << ", fused=";
-    p->Print(op->fused);
-    p->stream << ')';
-})
-.set_dispatch<RebaseNode>([](const ObjectRef& node, NodePrinter* p) {
-    auto* op = static_cast<const RebaseNode*>(node.get());
-    p->stream << "rebase(";
-    p->stream << "parent=";
-    p->Print(op->parent);
-    p->stream << ", rebased=";
-    p->Print(op->rebased);
-    p->stream << ')';
-})
-.set_dispatch<SingletonNode>([](const ObjectRef& node, NodePrinter* p) {
-    auto* op = static_cast<const SingletonNode*>(node.get());
-    p->stream << "singleton(";
-    p->Print(op->iter);
-    p->stream << ')';
-})
-.set_dispatch<ScheduleNode>([](const ObjectRef& node, NodePrinter* p) {
-    auto* op = static_cast<const ScheduleNode*>(node.get());
-    p->stream << "schedule(" << op << ")";
-  });
-}  // namespace tvm
diff --git a/src/schedule/schedule_ops.cc b/src/schedule/schedule_ops.cc
deleted file mode 100644 (file)
index 1eb595c..0000000
+++ /dev/null
@@ -1,427 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file schedule_ops.cc
- */
-#include <tvm/ir.h>
-#include <tvm/ir_pass.h>
-#include <tvm/ir_functor_ext.h>
-#include <tvm/operation.h>
-#include <tvm/schedule_pass.h>
-#include <utility>
-#include <unordered_map>
-#include <unordered_set>
-#include "graph.h"
-#include "../op/op_util.h"
-#include "../pass/ir_util.h"
-
-namespace tvm {
-namespace schedule {
-
-using namespace ir;
-
-Stmt MakePipeline(const Stage& s,
-                  const std::unordered_map<IterVar, Range>& dom_map,
-                  Stmt consumer,
-                  bool debug_keep_trivial_loop) {
-  Stmt producer = s->op->BuildProvide(s, dom_map, debug_keep_trivial_loop);
-  if (producer.defined()) {
-    producer = ProducerConsumerNode::make(s->op, true, producer);
-  }
-  if (s->double_buffer) {
-    producer = AttrStmtNode::make(
-        s->op, ir::attr::double_buffer_scope, 1, producer);
-  }
-  Stmt pipeline = producer;
-
-  if (consumer.defined() && !is_no_op(consumer)) {
-    consumer = ProducerConsumerNode::make(s->op, false, consumer);
-    pipeline = SeqStmt({producer, consumer});
-  }
-  pipeline = s->op->BuildRealize(s, dom_map, pipeline);
-  // use attribute to mark scope of the operation.
-  pipeline = AttrStmtNode::make(
-      s->op, ir::attr::realize_scope,
-      StringImmNode::make(s->scope),
-      pipeline);
-
-  if (s->is_opengl) {
-    pipeline = AttrStmtNode::make(
-        s->op, ir::attr::opengl_stage_scope, StringImmNode::make(""), pipeline);
-  }
-  return pipeline;
-}
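-
-// Illustrative sketch of the statement MakePipeline builds (pseudo-IR,
-// shapes and scope are placeholders):
-//   // attr [op] realize_scope = <scope>
-//   realize T(<bounds>) {
-//     produce T { <producer body> }
-//     <consumer body, wrapped as the consume side>
-//   }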
-
-// Inject the operator's realization into the stmt at its attach point.
-class InjectAttach : public StmtMutator {
- public:
-  InjectAttach(const Stage& stage,
-               const Stage& attach_spec,
-               const std::unordered_map<IterVar, Range>& dom_map,
-               bool debug_keep_trivial_loop)
-      : stage_(stage), attach_spec_(attach_spec), dom_map_(dom_map),
-        debug_keep_trivial_loop_(debug_keep_trivial_loop) {}
-
-  Stmt VisitStmt(const Stmt& input_stmt) final {
-    CHECK(input_stmt.defined());
-    auto stmt = StmtMutator::VisitStmt(input_stmt);
-    const AttrStmtNode* op = stmt.as<AttrStmtNode>();
-    if (op != nullptr &&
-        op->attr_key == attr::loop_scope) {
-      if (attach_spec_->attach_type == kScope &&
-          op->node == attach_spec_->attach_ivar) {
-        CHECK(!found_attach)
-            << "Find IterVar" << attach_spec_->attach_ivar
-            << " in multiple places in the IR";
-        found_attach = true;
-        stmt = AttrStmtNode::make(
-            op->node, op->attr_key, op->value,
-            MakePipeline(stage_, dom_map_, op->body, debug_keep_trivial_loop_));
-      }
-    }
-    return stmt;
-  }
-  // whether attach point is found
-  bool found_attach{false};
-
- private:
-  // The stage.
-  const Stage& stage_;
-  // The attach spec, may not contain op.
-  const Stage& attach_spec_;
-  // domain map
-  const std::unordered_map<IterVar, Range>& dom_map_;
-  // Whether keep trivial loops with extent of 1 during lowering.
-  // This is a debug feature for dataflow/axis analysis
-  bool debug_keep_trivial_loop_;
-};
-
-// Inject the scan step's realization into the stmt.
-class InjectScanStep : public StmtMutator {
- public:
-  InjectScanStep(const Stage& stage,
-                 const Operation& scan_op,
-                 const std::unordered_map<IterVar, Range>& dom_map,
-                 bool is_init,
-                 bool debug_keep_trivial_loop)
-      : stage_(stage), scan_op_(scan_op),
-        dom_map_(dom_map), is_init_(is_init), debug_keep_trivial_loop_(debug_keep_trivial_loop) {}
-
-  Stmt VisitStmt(const Stmt& input_stmt) final {
-    CHECK(input_stmt.defined());
-    auto stmt = StmtMutator::VisitStmt(input_stmt);
-    // update
-    const AttrStmtNode* op = stmt.as<AttrStmtNode>();
-    if (op != nullptr &&
-        ((op->attr_key == attr::scan_update_scope && !is_init_) ||
-         (op->attr_key == attr::scan_init_scope && is_init_))) {
-      if (op->node.same_as(scan_op_)) {
-        found_attach = true;
-        stmt = AttrStmtNode::make(
-            op->node, op->attr_key, op->value,
-            MakePipeline(stage_, dom_map_, op->body, debug_keep_trivial_loop_));
-      }
-    }
-    return stmt;
-  }
-
-  // whether attach point is found
-  bool found_attach{false};
-
- private:
-  // the stage to be injected and the scan op it belongs to
-  const Stage& stage_;
-  const Operation& scan_op_;
-  // domain map
-  const std::unordered_map<IterVar, Range>& dom_map_;
-  // whether it is init.
-  bool is_init_;
-  // Whether keep trivial loops with extent of 1 during lowering.
-  // This is a debug feature for dataflow/axis analysis
-  bool debug_keep_trivial_loop_;
-};
-
-// Post-processing of the scheduled op:
-// replace the init and update expressions with the scan's buffer.
-class SchedulePostProc : public StmtExprMutator {
- public:
-  Stmt VisitStmt_(const ProducerConsumerNode* op) final {
-    auto it = replace_op_.find(op->func.get());
-    if (it != replace_op_.end()) {
-      Stmt body = this->VisitStmt(op->body);
-      if (it->second.defined()) {
-        return ProducerConsumerNode::make(
-            it->second, op->is_producer, body);
-      } else {
-        return body;
-      }
-    } else {
-      return StmtExprMutator::VisitStmt_(op);
-    }
-  }
-  Stmt VisitStmt_(const LetStmtNode* op) final {
-    if (!HasSideEffect(op->value)) {
-      var_value_[op->var.get()] = this->VisitExpr(op->value);
-      return this->VisitStmt(op->body);
-    } else {
-      return StmtExprMutator::VisitStmt_(op);
-    }
-  }
-
-  Stmt VisitStmt_(const AttrStmtNode* op) final {
-    if (op->attr_key == attr::loop_scope ||
-        op->attr_key == attr::scan_init_scope) {
-      return this->VisitStmt(op->body);
-    } else if (op->attr_key == attr::scan_update_scope) {
-      const ScanOpNode* scan = op->node.as<ScanOpNode>();
-      CHECK(scan);
-      var_value_[scan->scan_axis->var.get()] = op->value;
-      return this->VisitStmt(op->body);
-    } else if (op->attr_key == attr::thread_extent) {
-      // delete duplicated thread extent attr
-      auto it = thread_extent_scope_.find(op->node.get());
-      if (it != thread_extent_scope_.end()) {
-        CHECK(is_zero(ir::Simplify(it->second - op->value)));
-        return this->VisitStmt(op->body);
-      } else {
-        thread_extent_scope_[op->node.get()] = op->value;
-        Stmt ret = StmtExprMutator::VisitStmt_(op);
-        thread_extent_scope_.erase(op->node.get());
-        return ret;
-      }
-    } else if (op->attr_key == ir::attr::realize_scope ||
-               op->attr_key == ir::attr::double_buffer_scope) {
-      auto it = replace_op_.find(op->node.get());
-      if (it != replace_op_.end()) {
-        if (it->second.defined()) {
-          Stmt ret = AttrStmtNode::make(
-              it->second, op->attr_key, op->value, op->body);
-          return this->VisitStmt(ret);
-        } else {
-          return this->VisitStmt(op->body);
-        }
-      }
-    } else if (op->attr_key == ir::attr::buffer_bind_scope) {
-      Array<ObjectRef> tuple = Downcast<Array<ObjectRef> >(op->node);
-      Tensor tensor = Downcast<Tensor>(tuple[1]);
-      auto it = replace_op_.find(tensor->op.get());
-      if (it != replace_op_.end()) {
-        if (it->second.defined()) {
-          return AttrStmtNode::make(
-              Array<ObjectRef>{tuple[0], it->second.output(tensor->value_index)},
-              op->attr_key, op->value, this->VisitStmt(op->body));
-        } else {
-          return this->VisitStmt(op->body);
-        }
-      }
-    } else if (op->attr_key == ir::attr::buffer_dim_align) {
-      Tensor tensor = Downcast<Tensor>(op->node);
-      auto it = replace_op_.find(tensor->op.get());
-      if (it != replace_op_.end()) {
-        if (it->second.defined()) {
-          return AttrStmtNode::make(
-              it->second.output(tensor->value_index),
-              op->attr_key, op->value, this->VisitStmt(op->body));
-        } else {
-          return this->VisitStmt(op->body);
-        }
-      }
-    }
-    return StmtExprMutator::VisitStmt_(op);
-  }
-
-  Stmt VisitStmt_(const RealizeNode* op) final {
-    TensorKey key{op->func, op->value_index};
-    auto it = replace_realize_.find(key);
-    if (it != replace_realize_.end()) {
-      if (it->second.defined()) {
-        Stmt ret = RealizeNode::make(
-            it->second->op, it->second->value_index,
-            op->dtype, op->bounds, op->condition, op->body);
-        return this->VisitStmt(ret);
-      } else {
-        return this->VisitStmt(op->body);
-      }
-    } else {
-      return StmtExprMutator::VisitStmt_(op);
-    }
-  }
-
-  Stmt VisitStmt_(const ProvideNode* op) final {
-    TensorKey key{op->func, op->value_index};
-    auto it = replace_buffer_.find(key);
-    if (it != replace_buffer_.end()) {
-      const Tensor& dst = it->second;
-      Stmt ret = ProvideNode::make(
-          dst->op, dst->value_index, op->value, op->args);
-      return this->VisitStmt(ret);
-    } else {
-      return StmtExprMutator::VisitStmt_(op);
-    }
-  }
-
-  PrimExpr VisitExpr_(const CallNode* op) final {
-    if (op->call_type == CallNode::Halide) {
-      TensorKey key{op->func, op->value_index};
-      auto it = replace_buffer_.find(key);
-      if (it != replace_buffer_.end()) {
-        const Tensor& dst = it->second;
-        PrimExpr ret = CallNode::make(
-            op->dtype, dst->op->name, op->args,
-            op->call_type, dst->op, dst->value_index);
-        return this->VisitExpr(ret);
-      }
-    }
-    return StmtExprMutator::VisitExpr_(op);
-  }
-
-  PrimExpr VisitExpr_(const VarNode* op) final {
-    auto it = var_value_.find(op);
-    if (it != var_value_.end()) {
-      return it->second;
-    } else {
-      return GetRef<PrimExpr>(op);
-    }
-  }
-
-  void Init(const Schedule& sch) {
-    for (Stage s : sch->stages) {
-      for (auto kv : s->iter_var_attrs) {
-        // Update bind thread information.
-        if (kv.second->bind_thread.defined()) {
-          const Var& from = kv.first->var;
-          const Var& to = kv.second->bind_thread->var;
-          CHECK(!var_value_.count(from.get()));
-          var_value_[from.get()] = to;
-        }
-      }
-      // This must be checked for all ops, including scan.
-      if (!s->op.same_as(s->origin_op)) {
-        for (int i = 0; i < s->op->num_outputs(); ++i) {
-          Tensor target = s->origin_op.output(i);
-          AddReplace(s->op.output(i), target,
-                     target, s->origin_op);
-        }
-      }
-      // Specially add replacements for scan op.
-      if (const ScanOpNode* scan = s->op.as<ScanOpNode>()) {
-        for (size_t i = 0; i < scan->update.size(); ++i) {
-          Tensor t = s->origin_op.output(i);
-          AddReplace(scan->init[i], t);
-          AddReplace(scan->update[i], t);
-          AddReplace(scan->state_placeholder[i], t);
-        }
-      }
-    }
-  }
-
- private:
-  void AddReplace(Tensor src,
-                  Tensor dst,
-                  Tensor repl_realize = Tensor(),
-                  Operation repl_op = Operation()) {
-    TensorKey key{src->op, src->value_index};
-    replace_buffer_[key] = dst;
-    replace_realize_[key] = repl_realize;
-    replace_op_[src->op.get()] = repl_op;
-  }
-  // The thread extent scope.
-  std::unordered_map<const Object*, PrimExpr> thread_extent_scope_;
-  // The remapped values of variables.
-  std::unordered_map<const VarNode*, PrimExpr> var_value_;
-  // buffer replacement
-  std::unordered_map<TensorKey, Tensor> replace_buffer_;
-  // buffer realization to be replaced
-  std::unordered_map<TensorKey, Tensor> replace_realize_;
-  // replace producer consumer.
-  std::unordered_map<const Object*, Operation> replace_op_;
-};
-
-Stmt ScheduleOps(
-    Schedule sch, Map<IterVar, Range> dom_map_, bool debug_keep_trivial_loop) {
-  Stmt body = Stmt();
-  std::unordered_map<IterVar, Range> dom_map = as_unordered_map(dom_map_);
-  // scan init and scan updates
-  std::unordered_map<Operation, Operation> scan_init;
-  for (Stage s : sch->stages) {
-    const ScanOpNode* scan = s->op.as<ScanOpNode>();
-    if (!scan) continue;
-    for (Tensor t : scan->init) {
-      if (scan_init.count(t->op)) {
-        CHECK(scan_init.at(t->op).same_as(s->op))
-            << "Scan init tensor can only belong to one scan";
-      } else {
-        scan_init[t->op] = s->op;
-      }
-    }
-  }
-  // verify correctness of group.
-  for (Stage g : sch->groups) {
-    CHECK(!g->op.defined());
-    CHECK_EQ(g->leaf_iter_vars.size(), 0U);
-  }
-  // reverse the post DFS order.
-  for (size_t i = sch->stages.size(); i != 0; --i) {
-    Stage s = sch->stages[i - 1];
-    CHECK_NE(s->attach_type, kInline)
-        << "call schedule.normalize before scheduleops";
-    CHECK(s->op.defined());
-    // no need to lower placeholder ops.
-    if (s->op.as<PlaceholderOpNode>()) continue;
-    // Remove grouping sugar, get the real attach spec.
-    Stage attach_spec = s.GetAttachSpec();
-
-    if (scan_init.count(s->op)) {
-      CHECK(body.defined());
-      InjectScanStep mu(s, scan_init.at(s->op), dom_map, true, debug_keep_trivial_loop);
-      body = mu(std::move(body));
-      CHECK(mu.found_attach)
-          << "did not find attachment point for scan.init";
-    } else if (attach_spec->attach_type == kScanUpdate) {
-      // Handle scan update
-      CHECK(body.defined());
-      InjectScanStep mu(s, attach_spec->attach_stage->op, dom_map, false, debug_keep_trivial_loop);
-      body = mu(std::move(body));
-      CHECK(mu.found_attach)
-          << "did not find attachment point for scan.update";
-    } else if (attach_spec->attach_type == kInlinedAlready) {
-      // do nothing
-    } else if (attach_spec->attach_type == kGroupRoot) {
-      CHECK(!s->group.defined());
-      body = MakePipeline(s, dom_map, body, debug_keep_trivial_loop);
-    } else {
-      CHECK_EQ(attach_spec->attach_type, kScope);
-      CHECK(body.defined());
-      InjectAttach mutator(s, attach_spec, dom_map, debug_keep_trivial_loop);
-      body = mutator(std::move(body));
-      CHECK(mutator.found_attach)
-          << "did not find attachment point for " << s << " in "
-          << attach_spec->attach_stage->op  << " x " << attach_spec->attach_ivar
-          << ", body:\n"
-          << body;
-    }
-  }
-  SchedulePostProc post_proc;
-  post_proc.Init(sch);
-  return post_proc(std::move(body));
-}
-
-}  // namespace schedule
-}  // namespace tvm
diff --git a/src/top/operation/compute_op.cc b/src/top/operation/compute_op.cc
new file mode 100644 (file)
index 0000000..a8c2328
--- /dev/null
@@ -0,0 +1,642 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \brief Compute Op.
+ * \file compute_op.cc
+ */
+#include <tvm/top/operation.h>
+#include <tvm/arith/analyzer.h>
+#include <tvm/ir.h>
+#include <tvm/ir_pass.h>
+#include <tvm/ir_functor_ext.h>
+#include <unordered_set>
+#include <string>
+#include <utility>
+#include "compute_op.h"
+#include "op_util.h"
+#include "../schedule/message_passing.h"
+#include "../../arith/compute_expr.h"
+#include "../../arith/interval_set.h"
+
+namespace tvm {
+namespace top {
+using namespace ir;
+
+TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
+.set_dispatch<ComputeOpNode>([](const ObjectRef& node, NodePrinter* p) {
+    auto* op = static_cast<const ComputeOpNode*>(node.get());
+    p->stream << "compute(" << op->name << ", " << op << ")";
+});
+
+TVM_REGISTER_NODE_TYPE(ComputeOpNode);
+
+/// Verify if ComputeOp is valid with respect to Reduce operations.
+static void VerifyComputeOp(const ComputeOpNode *op);
+
+inline bool ReduceEqual(const ir::ReduceNode* a, const ir::ReduceNode* b) {
+  return (a->combiner.same_as(b->combiner)) &&
+         (a->source.same_as(b->source)) &&
+         (a->axis.same_as(b->axis)) &&
+         (a->condition.same_as(b->condition));
+}
+
+int ComputeOpNode::num_outputs() const {
+  return body.size();
+}
+
+Array<IterVar> BaseComputeOpNode::root_iter_vars() const {
+  if (reduce_axis.size() == 0) return axis;
+  Array<IterVar> ret = axis;
+  for (IterVar iv : reduce_axis) {
+    ret.push_back(iv);
+  }
+  return ret;
+}
+
+DataType ComputeOpNode::output_dtype(size_t idx) const {
+  CHECK_LT(idx, num_outputs());
+  return body[idx].dtype();
+}
+
+Array<PrimExpr> BaseComputeOpNode::output_shape(size_t idx) const {
+  CHECK_LT(idx, num_outputs());
+  // for now, all outputs of a BaseComputeOp have the same shape
+  Array<PrimExpr> shape;
+  for (const auto& ivar : this->axis) {
+    const Range& r = ivar->dom;
+    shape.push_back(r->extent);
+  }
+  return shape;
+}
+
+Tensor compute(Array<PrimExpr> shape,
+               FCompute fcompute,
+               std::string name,
+               std::string tag,
+               Map<std::string, ObjectRef> attrs) {
+  auto op_node = make_object<ComputeOpNode>();
+  // compute dimension.
+  size_t ndim = shape.size();
+  std::vector<IterVar> axis;
+  std::vector<Var> args;
+  for (size_t i = 0; i < ndim; ++i) {
+    std::ostringstream os;
+    os << "ax" << i;
+    axis.emplace_back(IterVarNode::make(
+        Range(0, shape[i]), Var(os.str(), shape[i].dtype()), kDataPar));
+    args.push_back(axis.back()->var);
+  }
+
+  return ComputeOpNode::make(
+      name, tag, attrs, axis, {fcompute(args)}).output(0);
+}
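+
+// Usage sketch (illustrative; placeholder is the usual companion API, and
+// m, n are assumed to be in scope):
+//   Tensor A = placeholder({m, n}, DataType::Float(32), "A");
+//   Tensor B = compute({m, n}, [&](const Array<Var>& i) {
+//       return A(i[0], i[1]) + 1.0f;
+//     }, "B");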
+
+Array<Tensor> compute(Array<PrimExpr> shape,
+                      FBatchCompute fcompute,
+                      std::string name,
+                      std::string tag,
+                      Map<std::string, ObjectRef> attrs) {
+  auto op_node = make_object<ComputeOpNode>();
+  // compute dimension.
+  size_t ndim = shape.size();
+  std::vector<IterVar> axis;
+  std::vector<Var> args;
+  for (size_t i = 0; i < ndim; ++i) {
+    std::ostringstream os;
+    os << "ax" << i;
+    axis.emplace_back(IterVarNode::make(
+        Range(0, shape[i]), Var(os.str(), shape[i].dtype()), kDataPar));
+    args.push_back(axis.back()->var);
+  }
+
+  Operation op = ComputeOpNode::make(name, tag, attrs, axis, fcompute(args));
+  Array<Tensor> outputs;
+  for (int idx = 0; idx < op->num_outputs(); ++idx) {
+    outputs.push_back(op.output(idx));
+  }
+  return outputs;
+}
+
+Operation ComputeOpNode::make(std::string name,
+                              std::string tag,
+                              Map<std::string, ObjectRef> attrs,
+                              Array<IterVar> axis,
+                              Array<PrimExpr> body) {
+  if (!attrs.defined()) {
+    attrs = Map<std::string, ObjectRef>();
+  }
+  auto n = make_object<ComputeOpNode>();
+  n->name = std::move(name);
+  n->tag = std::move(tag);
+  n->attrs = std::move(attrs);
+  n->axis = std::move(axis);
+  n->body = std::move(body);
+  if (n->body[0]->IsInstance<ir::ReduceNode>()) {
+    const ir::ReduceNode* reduce = n->body[0].as<ir::ReduceNode>();
+    n->reduce_axis = reduce->axis;
+  }
+  VerifyComputeOp(n.get());
+  return Operation(n);
+}
+
+// The schedule related logics
+Array<Tensor> ComputeOpNode::InputTensors() const {
+  Array<Tensor> ret;
+  std::unordered_set<Tensor> visited;
+  for (auto& e : body) {
+    ir::PostOrderVisit(e, [&ret, &visited](const ObjectRef& n) {
+        const ir::CallNode *call = n.as<ir::CallNode>();
+        if (call != nullptr && call->func.defined()) {
+          Tensor t = Downcast<Operation>(call->func).output(call->value_index);
+          if (!visited.count(t)) {
+            ret.push_back(t);
+            visited.insert(t);
+          }
+        }
+      });
+  }
+  return ret;
+}
+
+Operation ComputeOpNode::ReplaceInputs(
+    const Operation& self,
+    const std::unordered_map<Tensor, Tensor>& rmap) const {
+  CHECK_EQ(self.operator->(), this);
+  VerifyComputeOp(this);
+  Array<PrimExpr> arr;
+  if (this->body[0]->IsInstance<ir::ReduceNode>()) {
+    // Specially handle reduce so the replaced op
+    // still shares all the components
+    PrimExpr new_reduce = top::ReplaceTensor(this->body[0], rmap);
+    if (!new_reduce.same_as(this->body[0])) {
+      const ir::ReduceNode* r = new_reduce.as<ir::ReduceNode>();
+      for (size_t k = 0; k < this->body.size(); ++k) {
+        auto n = make_object<ir::ReduceNode>(*r);
+        n->value_index = static_cast<int>(k);
+        n->dtype = r->source[k].dtype();
+        arr.push_back(PrimExpr(n));
+      }
+    } else {
+      arr = this->body;
+    }
+  } else {
+    arr = UpdateArray(this->body, [&rmap] (const PrimExpr& e) {
+        return top::ReplaceTensor(e, rmap);
+      });
+  }
+  if (!arr.same_as(this->body)) {
+    return ComputeOpNode::make(
+        this->name, this->tag, this->attrs, this->axis, arr);
+  } else {
+    return self;
+  }
+}
+
+void ComputeOpNode::PropBoundToInputs(
+    const Operation& self,
+    arith::Analyzer* analyzer,
+    const std::unordered_map<const VarNode*, IntSet>& dom_map,
+    std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
+  CHECK_EQ(self.operator->(), this);
+  auto fvisit = [&dom_map, out_dom_map, analyzer](const ObjectRef& n) {
+    auto *call = n.as<ir::CallNode>();
+    if (call != nullptr && call->func.defined()) {
+      Tensor t = Downcast<Operation>(call->func).output(call->value_index);
+      if (t->op.defined() && out_dom_map->count(t)) {
+        TensorDom& dom = out_dom_map->at(t);
+        for (size_t i = 0; i < t.ndim(); ++i) {
+          // We assume that the value of the argument cannot be out of bounds (otherwise it is
+          // undefined behaviour), so we can intersect the estimated set of the argument with the
+          // range expected by the tensor. However, intersection may result in overly complex
+          // expressions, so we perform a more relaxed form of intersection.
+          IntSet arg_intset = EvalSet(call->args[i], dom_map);
+          const arith::IntervalSetNode* arg_interval = arg_intset.as<arith::IntervalSetNode>();
+          if (arg_interval) {
+            PrimExpr shape_i_min_value = make_zero(t->shape[i].dtype());
+            PrimExpr shape_i_max_value = t->shape[i] - 1;
+            PrimExpr min_value = arg_interval->min_value;
+            PrimExpr max_value = arg_interval->max_value;
+            // Prefer the shape bounds only when we can prove they are tighter.
+            if (arith::is_neg_inf(min_value) ||
+                analyzer->CanProve(shape_i_min_value >= min_value)) {
+              min_value = shape_i_min_value;
+            }
+            if (arith::is_pos_inf(max_value) ||
+                analyzer->CanProve(shape_i_max_value <= max_value)) {
+              max_value = shape_i_max_value;
+            }
+            dom.data[i].push_back(IntSet::interval(min_value, max_value));
+          } else {
+            dom.data[i].push_back(arg_intset);
+          }
+        }
+      }
+    }
+  };
+  for (auto& e : body) ir::PostOrderVisit(e, fvisit);
+}
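+
+// Worked example (illustrative; names are hypothetical): for
+// B(i) = A(i + 1) with A of shape [n] and i in [0, n-2], the estimated
+// argument set is [1, n-1]. The relaxed intersection above keeps
+// min = 1 (0 >= 1 is not provable) and proves max = n-1, so the bound
+// propagated to A's dimension 0 is [1, n-1].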
+
+void BaseComputeOpNode::GatherBound(
+    const Operation& self,
+    const std::unordered_map<Tensor, TensorDom>& tensor_dom,
+    std::unordered_map<IterVar, Range>* out_dom_map) const {
+  CHECK_EQ(self.operator->(), this);
+  const TensorDom& tdom = tensor_dom.at(self.output(0));
+  for (size_t i = 0; i < this->axis.size(); ++i) {
+    Range r = arith::Union(tdom.data.at(i)).cover_range(this->axis[i]->dom);
+    CHECK(!out_dom_map->count(this->axis[i]));
+    (*out_dom_map)[this->axis[i]] = r;
+  }
+  for (size_t i = 0; i < this->reduce_axis.size(); ++i) {
+    CHECK(!out_dom_map->count(this->reduce_axis[i]));
+    (*out_dom_map)[this->reduce_axis[i]] = this->reduce_axis[i]->dom;
+  }
+}
+
+Stmt BaseComputeOpNode::BuildRealize(
+    const Stage& stage,
+    const std::unordered_map<IterVar, Range>& realize_map,
+    const Stmt& body) const {
+  CHECK_EQ(stage->op.get(), this);
+  Region bounds;
+  for (IterVar iv : this->axis) {
+    bounds.push_back(realize_map.at(iv));
+  }
+  Stmt realize = body;
+  for (int i = this->num_outputs(); i > 0; --i) {
+    Tensor t = stage->op.output(i-1);
+    realize = ir::RealizeNode::make(t->op, t->value_index,
+      t->dtype, bounds, const_true(), realize);
+    // alignment requirement, only useful for compute
+    for (size_t i = 0; i < num_schedulable_dims(); ++i) {
+      auto it = stage->iter_var_attrs.find(this->axis[i]);
+      if (it != stage->iter_var_attrs.end()) {
+        IterVarAttr attr = (*it).second;
+        if (attr->dim_align_factor != 0) {
+          Array<PrimExpr> tuple = {static_cast<int>(i),
+                                   attr->dim_align_factor,
+                                   attr->dim_align_offset};
+          realize = ir::AttrStmtNode::make(
+              t, ir::attr::buffer_dim_align,
+              CallNode::make(DataType::Handle(),
+                             ir::intrinsic::tvm_tuple,
+                             tuple, CallNode::Intrinsic),
+              realize);
+        }
+      }
+    }
+  }
+  return realize;
+}
+
+size_t ComputeOpNode::num_schedulable_dims() const {
+  return axis.size();
+}
+
+// Build a reduction body.
+void MakeReduction(const ComputeOpNode* op,
+                   const Array<Tensor>& tensors,
+                   Stmt* init,
+                   Stmt* provide) {
+  Array<PrimExpr> args;
+  for (IterVar iv : op->axis) {
+    args.push_back(iv->var);
+  }
+  std::vector<Stmt> inits, provides;
+
+  size_t size = op->body.size();
+  const ReduceNode* reduce = op->body[0].as<ReduceNode>();
+  CHECK(reduce);
+  const CommReducerNode* combiner = reduce->combiner.as<CommReducerNode>();
+  CHECK(combiner);
+  Array<PrimExpr> lhs;
+  for (size_t i = 0; i < size; ++i) {
+    lhs.push_back(tensors[i](args));
+  }
+  Array<PrimExpr> init_value = combiner->identity_element;
+  Array<PrimExpr> update_value = (*combiner)(lhs, reduce->source);
+  for (size_t i = 0; i < size; ++i) {
+    Tensor t = tensors[i];
+    inits.emplace_back(ProvideNode::make(
+          t->op, t->value_index, init_value[i], args));
+    provides.emplace_back(ProvideNode::make(
+          t->op, t->value_index, update_value[i], args));
+  }
+  *init = SeqStmt::Flatten(inits);
+  *provide = SeqStmt::Flatten(provides);
+  if (!is_one(reduce->condition)) {
+    *provide = IfThenElseNode::make(reduce->condition, *provide);
+  }
+}
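+
+// Illustrative sketch of the two statements built above for a sum
+// reduction C(i) = sum(A(i, k)) (names hypothetical):
+//   init:    C(i) = 0                  // combiner identity element
+//   provide: C(i) = C(i) + A(i, k)     // combiner applied to lhs and source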
+
+// Normal computation.
+Stmt MakeProvide(const ComputeOpNode* op,
+                 const Tensor& t) {
+  Array<PrimExpr> args;
+  for (IterVar iv : op->axis) {
+    args.push_back(iv->var);
+  }
+  return ProvideNode::make(t->op, t->value_index, op->body[t->value_index], args);
+}
+
+Stmt MakeComputeStmt(const ComputeOpNode* self,
+                     const Stage& stage,
+                     const std::unordered_map<IterVar, Range>& dom_map,
+                     bool debug_keep_trivial_loop) {
+  // grab the nest structure
+  ComputeLoopNest n = ComputeLoopNest::make(self, stage, dom_map, debug_keep_trivial_loop);
+  // Normal loop structure
+  n.init_nest.emplace_back(MakeIfNest(n.init_predicates));
+  n.main_nest.emplace_back(MakeIfNest(n.main_predicates));
+  if (self->reduce_axis.size() != 0) {
+    // make reduction.
+    Stmt init, provide;
+    Array<Tensor> source;
+    for (size_t i = 0; i < self->body.size(); ++i) {
+      source.push_back(stage->op.output(i));
+    }
+    MakeReduction(self, source, &init, &provide);
+    init = MergeNest(n.init_nest, init);
+    init = Substitute(init, n.init_vmap);
+    // common nest
+    std::vector<std::vector<Stmt> > common(
+        n.main_nest.begin(), n.main_nest.begin() + n.num_common_loop + 1);
+    std::vector<std::vector<Stmt> > reduce(
+        n.main_nest.begin() + n.num_common_loop + 1, n.main_nest.end());
+    provide = MergeNest(reduce, provide);
+    if (debug_keep_trivial_loop) {
+      provide = MergeNest(common, provide);
+    } else {
+      provide = MergeNest(common, SeqStmt::Flatten(init, provide));
+    }
+    // Run substitution on the full nest, because the loop condition
+    // could depend on outer loops.
+    return Substitute(provide, n.main_vmap);
+  } else {
+    std::vector<Stmt> provides;
+    for (size_t i = 0; i < self->body.size(); ++i) {
+      provides.emplace_back(MakeProvide(self, stage->op.output(i)));
+    }
+    Stmt provide = SeqStmt::Flatten(provides);
+    provide = MergeNest(n.main_nest, provide);
+    // Run substitution on the full nest, because the loop condition
+    // could depend on outer loops.
+    return Substitute(provide, n.main_vmap);
+  }
+}
+
+enum class ComputeType {
+  kNormal,
+  kCrossThreadReduction,
+  kTensorize
+};
+
+ComputeType DetectComputeType(const ComputeOpNode* self,
+                              const Stage& stage) {
+  // Verify correctness of leaf nest.
+  int normal_red = 0, thread_red = 0, tensorize = 0;
+
+  for (IterVar iv : stage->leaf_iter_vars) {
+    IterVarAttr attr;
+    auto it = stage->iter_var_attrs.find(iv);
+    if (it != stage->iter_var_attrs.end()) {
+      attr = (*it).second;
+    }
+    if (attr.defined() && attr->iter_type == kTensorized) {
+      ++tensorize;
+    }
+    if (iv->iter_type == kCommReduce) {
+      if (attr.defined() && attr->bind_thread.defined()) {
+        ++thread_red;
+      } else {
+        ++normal_red;
+      }
+    } else {
+      CHECK_EQ(thread_red, 0)
+          << "Cross thread reduce cannot swap with normal data axis";
+    }
+  }
+  if (tensorize != 0) {
+    CHECK(thread_red == 0)
+        << "Cannot mix cross thread reduction with Tensorize";
+    return ComputeType::kTensorize;
+  }
+  CHECK(normal_red == 0 || thread_red == 0)
+      << "Cannot mix normal reduction with thread reduce";
+  if (thread_red != 0) {
+    return ComputeType::kCrossThreadReduction;
+  } else {
+    return ComputeType::kNormal;
+  }
+}
+
+// implement the provide utility.
+Stmt ComputeOpNode::BuildProvide(
+    const Stage& stage,
+    const std::unordered_map<IterVar, Range>& dom_map,
+    bool debug_keep_trivial_loop) const {
+  CHECK_EQ(stage->op.operator->(), this);
+  ComputeType ctype = DetectComputeType(this, stage);
+  if (ctype == ComputeType::kCrossThreadReduction) {
+    // specially handle cross thread reduction.
+    return MakeCrossThreadReduction(this, stage, dom_map, debug_keep_trivial_loop);
+  } else if (ctype == ComputeType::kTensorize) {
+    return MakeTensorize(this, stage, dom_map, debug_keep_trivial_loop);
+  } else {
+    return MakeComputeStmt(this, stage, dom_map, debug_keep_trivial_loop);
+  }
+}
+
+ComputeLoopNest ComputeLoopNest::make(
+    const BaseComputeOpNode* self,
+    const Stage& stage,
+    const std::unordered_map<IterVar, Range>& dom_map,
+    bool debug_keep_trivial_loop) {
+  CHECK_EQ(stage->op.operator->(), self);
+  ComputeLoopNest ret;
+  // make main loop nest
+  ret.main_nest = MakeLoopNest(
+      stage, dom_map, 0, false, std::unordered_set<IterVar>(), &ret.main_vmap,
+      debug_keep_trivial_loop);
+  ret.main_predicates = MakeBoundCheck(
+      stage, dom_map, ret.main_vmap, false,
+      std::unordered_set<IterVar>());
+  for (auto& e : ret.main_predicates) {
+    e = likely(e);
+  }
+  if (stage->store_predicate.defined()) {
+    ret.main_predicates.push_back(stage->store_predicate);
+  }
+  if (self->reduce_axis.size() != 0) {
+    // try to find the location to insert the initialization.
+    // Fuse the initialization and provide loop when possible.
+    std::unordered_map<IterVar, int> update_state;
+    for (IterVar iv : self->reduce_axis) {
+      update_state[iv] = 2;
+    }
+    for (size_t i = 0; i < self->num_schedulable_dims(); ++i) {
+      update_state[self->axis[i]] = 1;
+    }
+    // find which iter var is related to reduction and which is related to axis.
+    top::PassDownBitMaskOr(stage, &update_state);
+    auto leaf_iter_vars = stage->leaf_iter_vars;
+    // find the first loop that is related to reduction.
+    size_t begin_loop = leaf_iter_vars.size();
+    for (size_t i = 0; i < leaf_iter_vars.size(); ++i) {
+      auto iv = leaf_iter_vars[i];
+      int flag = update_state.at(iv);
+      if ((flag & 2) != 0) {
+        begin_loop = i; break;
+      }
+      ret.init_vmap[iv] = ret.main_vmap.at(iv);
+    }
+    ret.num_common_loop = begin_loop;
+    // skip loops that are related to reduction and are unrelated to axis.
+    std::unordered_set<IterVar> skip_iter;
+    for (auto kv : update_state) {
+      int flag = kv.second;
+      if (flag == 2) skip_iter.insert(kv.first);
+    }
+    ret.init_nest = MakeLoopNest(
+        stage, dom_map, begin_loop, true,
+        skip_iter, &(ret.init_vmap), debug_keep_trivial_loop);
+    ret.init_predicates = MakeBoundCheck(
+        stage, dom_map, ret.init_vmap, true, skip_iter);
+    for (auto& e : ret.init_predicates) {
+      e = likely(e);
+    }
+  } else {
+    CHECK_EQ(ret.main_nest.size(), stage->leaf_iter_vars.size() + 1);
+    ret.num_common_loop = stage->leaf_iter_vars.size();
+  }
+  // copy elision here.
+  return ret;
+}
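+// Worked example (hypothetical schedule): with axis = [i], reduce_axis = [k]
+// and leaf_iter_vars = [i, k], update_state seeds k with flag 2 and i with
+// flag 1; PassDownBitMaskOr propagates the flags to the leaves, the first
+// leaf carrying bit 2 is k, so begin_loop = 1 and the loop over i is shared
+// between the init and the update nest, while the init nest itself starts
+// at k's position with the pure-reduction loops skipped.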
+
+namespace {
+/*!
+ * \brief Verify if ComputeOp is valid with respect to Reduce operations.
+ *
+ *  The following two properties are verified:
+ *  (1) All Reduce operations must appear at the top level.
+ *  (2) For a list of operations, if one is a Reduce, then the others
+ *      must be Reduce as well, and their inputs should have the
+ *      same attributes except value_index.
+ */
+class ComputeVerifier final : protected ir::ExprVisitor {
+ public:
+  /// Special member functions
+  //@{
+  explicit ComputeVerifier(const ComputeOpNode* compute)
+      : compute_(compute), reduce_(compute->body[0].as<ir::ReduceNode>()) {}
+  virtual ~ComputeVerifier() = default;
+  ComputeVerifier(const ComputeVerifier&) = delete;
+  ComputeVerifier(ComputeVerifier&&) = delete;
+  ComputeVerifier& operator=(const ComputeVerifier&) = delete;
+  ComputeVerifier& operator=(ComputeVerifier&&) = delete;
+  //@}
+
+  /// Interface to perform compute verification
+  void Run() {
+    for (const PrimExpr e : compute_->body) {
+      // Check for consistency of top level reductions
+      const ir::ReduceNode* reduce = e.as<ir::ReduceNode>();
+      CHECK((reduce && reduce_) || (!reduce && !reduce_))
+          << "All ComputeOp should be consistent "
+          << "with being Reduce operation or not.";
+
+      if (reduce && reduce_) {
+        CHECK(ReduceEqual(reduce, reduce_))
+            << "The Reduce inputs of ComputeOp should "
+            << "have the same attribute except value_index";
+      }
+
+      level_ = 0;
+      ExprVisitor::VisitExpr(e);
+    }
+  }
+
+ protected:
+  /// Visitor implementation
+  //@{
+  void VisitExpr(const PrimExpr& n) final {
+    ++level_;
+    ExprVisitor::VisitExpr(n);
+    --level_;
+  }
+
+  void VisitExpr_(const ir::ReduceNode* op) final {
+    // Check for non top level reductions
+    CHECK(0 == level_)
+        << "Reductions are only allowed at the top level of compute. "
+        << "Please create another tensor for further composition.";
+  }
+  //@}
+
+ private:
+  const ComputeOpNode* compute_{nullptr};  ///< ComputeOpNode to verify
+  const ir::ReduceNode* reduce_{nullptr};  ///< Top level Reduce operation
+  int level_{0};                           ///< Level of op being processed
+};
+}  // namespace
+
+/// Verify if ComputeOp is valid with respect to Reduce operations.
+static void VerifyComputeOp(const ComputeOpNode* op) {
+  ComputeVerifier v(op);
+  v.Run();
+}
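+// Illustrative example (hypothetical compute body): `B[i] = sum(A[i, k], k) + 1`
+// fails this verification, because the ReduceNode sits below the Add and is
+// therefore visited at level_ > 0; a reduction must form the entire
+// top-level expression of the compute body.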
+
+Stmt TransformUpdate(const Stage& stage,
+                     const std::unordered_map<IterVar, Range>& dom_map,
+                     const ComputeLoopNest& n,
+                     Stmt body,
+                     Stmt update) {
+  Array<PrimExpr> conds;
+  std::unordered_set<const VarNode*> banned;
+  for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) {
+    IterVar iv = stage->leaf_iter_vars[i];
+    auto iit = stage->iter_var_attrs.find(iv);
+    if (iit != stage->iter_var_attrs.end()) {
+      const IterVarAttr& attr = (*iit).second;
+      if (attr->iter_type == kTensorized) {
+        break;
+      }
+    }
+    if (iv->iter_type == kCommReduce) {
+      auto vit = dom_map.find(iv);
+      CHECK(vit != dom_map.end());
+      const Range& vrange = vit->second;
+      conds.push_back(likely(iv->var > vrange->min));
+      banned.insert(iv->var.get());
+    }
+  }
+  for (const PrimExpr& pred : n.main_predicates) {
+    if (ir::ExprUseVar(pred, banned)) {
+      LOG(FATAL) << "Tensorize update transform failed, the condition "
+                 << pred << " has a conflict with the reset condition";
+    }
+  }
+
+  return IfThenElseNode::make(arith::ComputeReduce<ir::OrNode>(conds, const_true(1)),
+                          update, body);
+}
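+// Illustrative sketch of the result (single reduce axis k over [0, n)):
+//   if (likely(k > 0)) { update } else { body }
+// i.e. the first reduction iteration runs the intrinsic's body, which also
+// initializes the output, while every later iteration runs the init-free
+// update.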
+
+}  // namespace top
+}  // namespace tvm
diff --git a/src/top/operation/compute_op.h b/src/top/operation/compute_op.h
new file mode 100644 (file)
index 0000000..093dd22
--- /dev/null
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \brief Helper utilities to implement compute_op.
+ * \file compute_op.h
+ */
+#ifndef TVM_TOP_OPERATION_COMPUTE_OP_H_
+#define TVM_TOP_OPERATION_COMPUTE_OP_H_
+
+#include <tvm/ir.h>
+#include <tvm/expr.h>
+#include <tvm/top/operation.h>
+#include <vector>
+#include <unordered_map>
+
+namespace tvm {
+namespace top {
+// Loop nest structure for general compute.
+// This is the loop nest structure used in compute;
+// it does not include the loop body.
+struct ComputeLoopNest {
+  // The common number of loops between init and main
+  size_t num_common_loop;
+  // predicates for the initialize loop
+  std::vector<PrimExpr> init_predicates;
+  // Initialization nest involved.
+  std::vector<std::vector<Stmt> > init_nest;
+  // Value map for the init code
+  std::unordered_map<IterVar, PrimExpr> init_vmap;
+  // Predicates for the main update loop
+  std::vector<PrimExpr> main_predicates;
+  // The general loop nest
+  std::vector<std::vector<Stmt> > main_nest;
+  // Value map for the IterVar.
+  std::unordered_map<IterVar, PrimExpr> main_vmap;
+
+  /*!
+   * \brief Constructor to build a ComputeLoopNest.
+   * \param self The pointer to the compute op.
+   * \param stage The schedule stage.
+   * \param dom_map The domain map.
+   * \param debug_keep_trivial_loop Whether to keep trivial loops with extent of 1
+   * \return The constructed loop nest
+   */
+  static ComputeLoopNest make(
+      const BaseComputeOpNode* self,
+      const Stage& stage,
+      const std::unordered_map<IterVar, Range>& dom_map,
+      bool debug_keep_trivial_loop);
+};
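+// Usage sketch (illustrative, mirroring MakeComputeStmt in compute_op.cc):
+//   ComputeLoopNest n = ComputeLoopNest::make(self, stage, dom_map, false);
+//   provide = MergeNest(n.main_nest, provide);
+//   return Substitute(provide, n.main_vmap);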
+
+/*!
+ * \brief Build body of compute for cross thread reduction pattern.
+ * \param self The pointer to ComputeOpNode
+ * \param stage The schedule stage.
+ * \param dom_map The domain map.
+ * \param debug_keep_trivial_loop Whether to keep trivial loops with extent of 1
+ * \return The created statement.
+ */
+Stmt MakeCrossThreadReduction(
+    const ComputeOpNode* self,
+    const Stage& stage,
+    const std::unordered_map<IterVar, Range>& dom_map,
+    bool debug_keep_trivial_loop);
+
+/*!
+ * \brief Build body of compute for tensorization.
+ * \param self The pointer to ComputeOpNode
+ * \param stage The schedule stage.
+ * \param dom_map The domain map.
+ * \param debug_keep_trivial_loop Whether to keep trivial loops with extent of 1
+ * \return The created statement.
+ */
+Stmt MakeTensorize(const ComputeOpNode* self,
+                   const Stage& stage,
+                   const std::unordered_map<IterVar, Range>& dom_map,
+                   bool debug_keep_trivial_loop);
+
+/*!
+ * \brief Transform the update part when there is no init func in tensorizing
+ * \param stage The stage for tensorizing.
+ * \param dom_map The range of each iter var.
+ * \param n The loop nest structured used in compute.
+ * \param body The body func in tensorize intrin
+ * \param update The update func in tensorize intrin
+ * \return Transformed result.
+ */
+Stmt TransformUpdate(const Stage& stage,
+                     const std::unordered_map<IterVar, Range>& dom_map,
+                     const ComputeLoopNest& n,
+                     Stmt body,
+                     Stmt update);
+}  // namespace top
+}  // namespace tvm
+
+#endif  // TVM_TOP_OPERATION_COMPUTE_OP_H_
diff --git a/src/top/operation/cross_thread_reduction.cc b/src/top/operation/cross_thread_reduction.cc
new file mode 100644 (file)
index 0000000..bf5c9b1
--- /dev/null
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \brief Logic related to cross thread reduction, used by ComputeOpNode.
+ * \file cross_thread_reduction.cc
+ */
+#include <tvm/ir_pass.h>
+#include "compute_op.h"
+#include "op_util.h"
+
+namespace tvm {
+namespace top {
+using namespace ir;
+
+Stmt MakeCrossThreadReduction(
+    const ComputeOpNode* self,
+    const Stage& stage,
+    const std::unordered_map<IterVar, Range>& dom_map,
+    bool debug_keep_trivial_loop) {
+  Array<PrimExpr> args;
+  for (IterVar iv : self->axis) {
+    args.push_back(iv->var);
+  }
+  std::unordered_map<IterVar, PrimExpr> value_map;
+  auto nest = MakeLoopNest(
+      stage, dom_map, 0, false, std::unordered_set<IterVar>(), &value_map, debug_keep_trivial_loop);
+  auto conds = MakeBoundCheck(
+      stage, dom_map, value_map, false,
+      std::unordered_set<IterVar>());
+
+  size_t size = self->body.size();
+  CHECK_GT(size, 0);
+  std::vector<const ReduceNode*> reduces(size);
+  for (size_t i = 0; i < size; ++i) {
+    const ReduceNode* reduce = self->body[i].as<ReduceNode>();
+    CHECK(reduce);
+    reduces[i] = reduce;
+  }
+  PrimExpr cond = reduces[0]->condition;
+  for (PrimExpr v : conds) {
+    cond = cond && v;
+  }
+  Array<PrimExpr> freduce_args;
+  freduce_args.push_back(make_const(DataType::UInt(32), static_cast<uint32_t>(size)));
+  for (size_t i = 0; i < size; ++i) {
+    freduce_args.push_back(reduces[0]->source[i]);
+  }
+  freduce_args.push_back(cond);
+  std::vector<Var> res_handles(size);
+  for (size_t idx = 0; idx < size; ++idx) {
+    res_handles[idx] = Var("reduce_temp" + std::to_string(idx), DataType::Handle());
+    freduce_args.push_back(res_handles[idx]);
+  }
+
+  for (IterVar iv : stage->leaf_iter_vars) {
+    if (iv->iter_type == kCommReduce) {
+      auto it = stage->iter_var_attrs.find(iv);
+      if (it != stage->iter_var_attrs.end() &&
+          (*it).second->bind_thread.defined()) {
+        IterVar tv = (*it).second->bind_thread;
+        freduce_args.push_back(tv->var);
+      }
+    }
+  }
+  // Checks for the thread.
+  std::vector<PrimExpr> thread_head_check;
+  if (stage->store_predicate.defined()) {
+    thread_head_check.emplace_back(stage->store_predicate);
+  }
+
+  Stmt reduce_body = EvaluateNode::make(CallNode::make(
+      DataType::Handle(),
+      ir::intrinsic::tvm_thread_allreduce,
+      freduce_args, CallNode::Intrinsic));
+  reduce_body = AttrStmtNode::make(
+      reduces[0]->combiner,
+      attr::reduce_scope,
+      make_zero(DataType::Handle()),
+      reduce_body);
+  std::vector<Stmt> assigns(size);
+  for (size_t idx = 0; idx < size; ++idx) {
+    DataType t = reduces[idx]->dtype;
+    assigns[idx] = ProvideNode::make(
+      stage->op, idx,
+      LoadNode::make(t, res_handles[idx], 0, const_true(t.lanes())), args);
+  }
+  Stmt assign_body = SeqStmt::Flatten(assigns);
+  assign_body = MergeNest(MakeIfNest(thread_head_check), assign_body);
+  assign_body = MergeNest(MakeIfNest(conds), assign_body);
+  Stmt body = SeqStmt::Flatten(reduce_body, assign_body);
+  for (size_t idx = size; idx != 0; --idx) {
+    body = AllocateNode::make(
+      res_handles[idx - 1], reduces[idx - 1]->dtype, {1}, const_true(), body);
+    body = AttrStmtNode::make(
+      res_handles[idx - 1], attr::storage_scope, StringImmNode::make("local"), body);
+  }
+  body = Substitute(body, value_map);
+  return MergeNest(nest, body);
+}
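+// Illustrative sketch of the IR produced above (single output, one
+// thread-bound reduce axis; names hypothetical):
+//   // attr [reduce_temp0] storage_scope = "local"
+//   allocate reduce_temp0[1]
+//   // attr [combiner] reduce_scope
+//   tvm_thread_allreduce(1, source, cond, reduce_temp0, threadIdx.x)
+//   if (<bound checks && store predicate>) B[i] = reduce_temp0[0]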
+}  // namespace top
+}  // namespace tvm
diff --git a/src/top/operation/extern_op.cc b/src/top/operation/extern_op.cc
new file mode 100644 (file)
index 0000000..3fc73dc
--- /dev/null
@@ -0,0 +1,186 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \brief External computation rule.
+ * \file extern_op.cc
+ */
+#include <tvm/top/operation.h>
+#include <tvm/arith/analyzer.h>
+#include <tvm/ir.h>
+#include <unordered_set>
+#include "op_util.h"
+
+namespace tvm {
+namespace top {
+using namespace ir;
+// ExternOpNode
+TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
+.set_dispatch<ExternOpNode>([](const ObjectRef& node, NodePrinter* p) {
+    auto* op = static_cast<const ExternOpNode*>(node.get());
+    p->stream << "extern(" << op->name << ", " << op << ")";
+  });
+
+TVM_REGISTER_NODE_TYPE(ExternOpNode);
+
+int ExternOpNode::num_outputs() const {
+  return static_cast<int>(output_placeholders.size());
+}
+
+Array<IterVar> ExternOpNode::root_iter_vars() const {
+  return {};
+}
+
+DataType ExternOpNode::output_dtype(size_t i) const {
+  return output_placeholders[i]->dtype;
+}
+
+Array<PrimExpr> ExternOpNode::output_shape(size_t i) const {
+  return output_placeholders[i]->shape;
+}
+
+
+Operation ExternOpNode::make(std::string name,
+                             std::string tag,
+                             Map<std::string, ObjectRef> attrs,
+                             Array<Tensor> inputs,
+                             Array<Buffer> input_placeholders,
+                             Array<Buffer> output_placeholders,
+                             Stmt body) {
+  if (!attrs.defined()) {
+    attrs = Map<std::string, ObjectRef>();
+  }
+  auto n = make_object<ExternOpNode>();
+  n->name = std::move(name);
+  n->tag = std::move(tag);
+  n->attrs = std::move(attrs);
+  CHECK_EQ(inputs.size(), input_placeholders.size());
+  for (size_t i = 0; i < inputs.size(); ++i) {
+    CHECK_EQ(inputs[i]->dtype, input_placeholders[i]->dtype);
+    CHECK_EQ(inputs[i]->shape.size(), input_placeholders[i]->shape.size());
+    for (size_t dim = 0; dim < inputs[i]->shape.size(); ++dim) {
+        CHECK(inputs[i]->shape[dim].same_as(input_placeholders[i]->shape[dim]));
+    }
+    CHECK_EQ(input_placeholders[i]->strides.size(), 0U);
+  }
+  n->inputs = std::move(inputs);
+  n->input_placeholders = std::move(input_placeholders);
+  n->output_placeholders = std::move(output_placeholders);
+  n->body = std::move(body);
+  return Operation(n);
+}
+
+Array<Tensor> ExternOpNode::InputTensors() const {
+  return inputs;
+}
+
+Operation ExternOpNode::ReplaceInputs(
+    const Operation& self,
+    const std::unordered_map<Tensor, Tensor>& rmap) const {
+  CHECK_EQ(self.operator->(), this);
+  auto n = make_object<ExternOpNode>(*this);
+  n->body = ReplaceTensor(this->body, rmap);
+  for (size_t i = 0; i < n->inputs.size(); ++i) {
+    Tensor t = n->inputs[i];
+    if (rmap.count(t)) {
+      n->inputs.Set(i, rmap.at(t));
+    }
+  }
+
+  if (body.same_as(n->body) &&
+      inputs.same_as(n->inputs)) {
+    return self;
+  } else {
+    return Operation(n);
+  }
+}
+
+void ExternOpNode::PropBoundToInputs(
+    const Operation& self,
+    arith::Analyzer* analyzer,
+    const std::unordered_map<const VarNode*, IntSet>& dom_map,
+    std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
+  for (Tensor t : this->inputs) {
+    auto it = out_dom_map->find(t);
+    if (it == out_dom_map->end()) continue;
+    TensorDom& dom = it->second;
+    for (size_t i = 0; i < t->shape.size(); ++i) {
+      dom.data[i].emplace_back(IntSet::range(
+          Range::make_by_min_extent(
+              make_const(t->shape[i].dtype(), 0), t->shape[i])));
+    }
+  }
+}
+
+void ExternOpNode::GatherBound(
+    const Operation& self,
+    const std::unordered_map<Tensor, TensorDom>& tensor_dom,
+    std::unordered_map<IterVar, Range>* out_dom_map) const {
+}
+
+Stmt ExternOpNode::BuildRealize(
+    const Stage& stage,
+    const std::unordered_map<IterVar, Range>& realize_map,
+    const Stmt& body) const {
+  CHECK_EQ(stage->op.get(), this);
+  Stmt realize_body = body;
+  for (int k = 0; k < num_outputs(); ++k) {
+    Tensor t = stage->op.output(k);
+    Region bounds;
+    for (size_t i = 0; i < t->shape.size(); ++i) {
+      bounds.push_back(
+          Range::make_by_min_extent(
+              make_const(t->shape[i].dtype(), 0), t->shape[i]));
+    }
+    realize_body = ir::RealizeNode::make(
+        t->op, t->value_index, t->dtype,
+        bounds, const_true(), realize_body);
+  }
+  return realize_body;
+}
+
+Stmt ExternOpNode::BuildProvide(
+    const Stage& stage,
+    const std::unordered_map<IterVar, Range>& dom_map,
+    bool debug_keep_trivial_loop) const {
+  CHECK_EQ(stage->op.operator->(), this);
+  Stmt ret = AttrStmtNode::make(make_zero(DataType::Int(32)), attr::extern_scope, 0, this->body);
+  auto f_push_bind = [&ret](Buffer buffer, Tensor tensor) {
+    Array<ObjectRef> bind_spec;
+    Array<PrimExpr> tuple;
+    bind_spec.push_back(buffer);
+    bind_spec.push_back(tensor);
+    for (size_t k = 0; k < buffer->shape.size(); ++k) {
+      tuple.push_back(make_const(buffer->shape[k].dtype(), 0));
+      tuple.push_back(buffer->shape[k]);
+    }
+    ret = AttrStmtNode::make(
+        bind_spec, attr::buffer_bind_scope,
+        CallNode::make(DataType::Handle(), intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), ret);
+  };
+  for (size_t i = output_placeholders.size(); i != 0; --i) {
+    f_push_bind(output_placeholders[i - 1], stage->op.output(i - 1));
+  }
+  for (size_t i = inputs.size(); i != 0; --i) {
+    f_push_bind(input_placeholders[i - 1], inputs[i - 1]);
+  }
+  return ret;
+}
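+// Illustrative note: f_push_bind wraps the current statement each time it
+// is called, so the binds emitted last end up outermost. The result nests
+// roughly as
+//   attr [input binds] buffer_bind_scope = tvm_tuple(0, shape...)
+//     attr [output binds] buffer_bind_scope = tvm_tuple(0, shape...)
+//       attr extern_scope { body }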
+}  // namespace top
+}  // namespace tvm
diff --git a/src/top/operation/hybrid_op.cc b/src/top/operation/hybrid_op.cc
new file mode 100644 (file)
index 0000000..d959826
--- /dev/null
@@ -0,0 +1,510 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \brief Hybrid computation rule.
+ * \file hybrid_op.cc
+ */
+#include <tvm/top/operation.h>
+#include <tvm/arith/analyzer.h>
+#include <tvm/ir.h>
+#include <tvm/ir_functor_ext.h>
+#include <tvm/ir_pass.h>
+#include <tvm/expr_operator.h>
+#include <unordered_set>
+#include <string>
+#include <utility>
+#include "op_util.h"
+#include "hybrid_op.h"
+
+namespace tvm {
+namespace top {
+using namespace ir;
+// HybridOpNode
+TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
+.set_dispatch<HybridOpNode>([](const ObjectRef& node, NodePrinter* p) {
+    auto* op = static_cast<const HybridOpNode*>(node.get());
+    p->stream << "hybrid(" << op->name << ", " << op << ")";
+  });
+
+TVM_REGISTER_NODE_TYPE(HybridOpNode);
+
+int HybridOpNode::num_outputs() const {
+  return static_cast<int>(outputs.size());
+}
+
+Array<IterVar> HybridOpNode::root_iter_vars() const {
+  return this->axis;
+}
+
+DataType HybridOpNode::output_dtype(size_t i) const {
+  return outputs[i]->dtype;
+}
+
+Array<PrimExpr> HybridOpNode::output_shape(size_t i) const {
+  return outputs[i]->shape;
+}
+
+
+Operation HybridOpNode::make(std::string name,
+                             std::string tag,
+                             Map<std::string, ObjectRef> attrs,
+                             Array<Tensor> inputs,
+                             Array<Tensor> outputs,
+                             Stmt body) {
+  if (!attrs.defined()) {
+    attrs = Map<std::string, ObjectRef>();
+  }
+  auto n = make_object<HybridOpNode>();
+  n->name = std::move(name);
+  n->tag = std::move(tag);
+  n->attrs = std::move(attrs);
+  n->inputs = std::move(inputs);
+  n->outputs = std::move(outputs);
+  n->axis = top::GatherLoopVars(body);
+  n->body = std::move(body);
+  Operation res = Operation(n);
+  return res;
+}
+
+Array<Tensor> HybridOpNode::InputTensors() const {
+  // Because input tensors could potentially be inlined into hybrid scripts,
+  // we need to check which input tensors are actually used in the body.
+  std::unordered_set<Tensor> orig_inputs;
+  for (auto t : inputs) {
+    orig_inputs.insert(t);
+  }
+  std::unordered_set<Tensor> visited;
+  Array<Tensor> curr_inputs;
+  ir::PostOrderVisit(body, [&curr_inputs, &orig_inputs, &visited](const ObjectRef& n) {
+      const ir::CallNode *call = n.as<ir::CallNode>();
+      if (call != nullptr && call->func.defined()) {
+        Tensor t = Downcast<Operation>(call->func).output(call->value_index);
+        if (orig_inputs.count(t) && !visited.count(t)) {
+          curr_inputs.push_back(t);
+          visited.insert(t);
+        }
+      }
+  });
+  return curr_inputs;
+}
+
+Operation HybridOpNode::ReplaceInputs(
+    const Operation &self,
+    const std::unordered_map<Tensor, Tensor> &rmap) const {
+  CHECK_EQ(self.operator->(), this);
+  auto n = make_object<HybridOpNode>(*this);
+  n->body = top::ReplaceTensor(this->body, rmap);
+  for (size_t i = 0; i < n->inputs.size(); ++i) {
+    Tensor t = n->inputs[i];
+    if (rmap.count(t)) {
+      n->inputs.Set(i, rmap.at(t));
+    }
+  }
+
+  if (body.same_as(n->body) &&
+      inputs.same_as(n->inputs)) {
+    return self;
+  } else {
+    return Operation(n);
+  }
+}
+
+void HybridOpNode::PropBoundToInputs(
+    const Operation &self,
+    arith::Analyzer* analyzer,
+    const std::unordered_map<const VarNode*, IntSet> &dom_map,
+    std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
+  auto curr_inputs = InputTensors();
+  for (Tensor t : curr_inputs) {
+    auto it = out_dom_map->find(t);
+    if (it == out_dom_map->end()) continue;
+    TensorDom &dom = it->second;
+    for (size_t i = 0; i < t->shape.size(); ++i) {
+      dom.data[i].emplace_back(IntSet::range(
+          Range::make_by_min_extent(
+              make_const(t->shape[i].dtype(), 0), t->shape[i])));
+    }
+  }
+}
+
+void HybridOpNode::GatherBound(
+    const Operation &self,
+    const std::unordered_map<Tensor, TensorDom> &tensor_dom,
+    std::unordered_map<IterVar, Range>* out_dom_map) const {
+  for (auto iter_var : axis) {
+    CHECK(!out_dom_map->count(iter_var));
+    out_dom_map->operator[](iter_var) = iter_var->dom;
+  }
+}
+
+Stmt HybridOpNode::BuildRealize(
+    const Stage &stage,
+    const std::unordered_map<IterVar, Range> &realize_map,
+    const Stmt &body) const {
+  // TODO(@were): Add attribute inject here and remove it from hybrid parser.
+  CHECK_EQ(stage->op.get(), this);
+  Stmt realize_body = body;
+  for (int k = 0; k < num_outputs(); ++k) {
+    Tensor t = stage->op.output(k);
+    Region bounds;
+    for (size_t i = 0; i < t->shape.size(); ++i) {
+      bounds.push_back(
+          Range::make_by_min_extent(
+              make_const(t->shape[i].dtype(), 0), t->shape[i]));
+    }
+    realize_body = ir::RealizeNode::make(
+        t->op, t->value_index, t->dtype,
+        bounds, const_true(), realize_body);
+  }
+  return realize_body;
+}
+
+Stmt HybridOpNode::BuildProvide(
+    const Stage &stage,
+    const std::unordered_map<IterVar, Range> &dom_map,
+    bool debug_keep_trivial_loop) const {
+  CHECK_EQ(stage->op.operator->(), this);
+  Stmt ret = AttrStmtNode::make(make_zero(DataType::Int(32)), attr::extern_scope, 0, this->body);
+  std::unordered_map<Tensor, Tensor> rmap;
+  for (int i = 0; i < this->num_outputs(); ++i) {
+    rmap[outputs[i]] = stage->op.output(i);
+  }
+  auto n = make_object<HybridOpNode>(*this);
+  /* This story is a little bit complicated.
+   * The following two lines of code replace the usage of the output tensors.
+   * This is the simplest way I (@were) could come up with to glue
+   * the hybrid operation node to the TVM op system.
+   * In a hybrid script all the tensors, especially the output tensors,
+   * have their own names defined by the users. However, in TVM
+   * conventional ops:
+   *   1. Output tensors refer to the corresponding op node, so the output
+   *      tensors share the name of the operation that produces them.
+   *   2. Once an OpNode is wrapped up by an Operation node, it is finalized.
+   *      Later access will be through a const OpNode*.
+   * This is a chicken-and-egg paradox: it is impossible to put the output
+   * tensors into the function body without forming the op node, and the
+   * function body is immutable after the node is formed.
+   *
+   * Finally, I decided to resolve this issue "lazily". This stage is a
+   * very early point in the compilation pipeline, technically before
+   * Phase 0, so the actual tensors are replaced here.
+   * Thus, the operation body is slightly different from the Phase 0 body.
+   * This is a major difference between HybridOpNode and ExternOpNode.
+   * */
+  ret = top::ReplaceTensor(ret, rmap);
+  ret = top::ReplaceProvideTensor(ret, rmap);
+
+  ret = top::ApplySchedule(stage, dom_map, ret);
+  return ret;
+}
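+// Illustrative note: if a hybrid script writes into a user-named tensor c,
+// the two Replace* calls above rewrite every read of c (Call nodes) and
+// every write to it (Provide nodes) to target stage->op.output(0), making
+// the body consistent with the rest of the op system.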
+
+Stmt ApplyLoopShapes(const Stage &stage,
+                     const std::unordered_map<IterVar, Range> &dom_map, Stmt stmt) {
+  class LoopSpliter : public StmtExprMutator {
+    PrimExpr factor;
+    const VarNode *parent;
+    IterVar inner, outer;
+
+   public:
+    bool splitted;
+    LoopSpliter(const SplitNode *split,
+                const std::unordered_map<IterVar, Range> &dom_map) :
+      factor(split->factor), splitted(false) {
+      parent = split->parent->var.get();
+
+      auto &inner_ = split->inner;
+      CHECK(dom_map.count(inner_));
+      auto &inner_dom = dom_map.find(inner_)->second;
+      CHECK(is_const_int(inner_dom->min, 0));
+
+      auto &outer_ = split->outer;
+      CHECK(dom_map.count(outer_));
+      auto &outer_dom = dom_map.find(outer_)->second;
+      CHECK(is_const_int(outer_dom->min, 0));
+
+      inner = IterVarNode::make(inner_dom, inner_->var, inner_->iter_type);
+      outer = IterVarNode::make(outer_dom, outer_->var, outer_->iter_type);
+    }
+
+    Stmt VisitStmt_(const ForNode *op) final {
+      if (op->loop_var.get() == parent) {
+        std::unordered_map<const VarNode *, PrimExpr> rmap;
+        rmap[op->loop_var.get()] = inner + outer * factor;
+        Stmt ret = ir::Substitute(op->body, rmap);
+        PrimExpr cond = likely(outer * factor < (op->extent - inner));
+        ret = IfThenElseNode::make(cond, ret);
+        ret = ForNode::make(inner->var, PrimExpr(0), inner->dom->extent,
+                        IterVarTypeToForType(inner->iter_type), op->device_api, ret);
+        ret = ForNode::make(outer->var, PrimExpr(0), outer->dom->extent,
+                        IterVarTypeToForType(outer->iter_type), op->device_api, ret);
+        splitted = true;
+        return ret;
+      }
+      return StmtExprMutator::VisitStmt_(op);
+    }
+  };
+
+  class LoopFuser : public StmtExprMutator {
+    const IterVar &parent;
+    const VarNode *inner;
+    const VarNode *outer;
+    bool under_outer;
+    PrimExpr extent;
+
+   public:
+    bool fused;
+    explicit LoopFuser(const FuseNode *fuse_)
+      : parent(fuse_->fused), inner(fuse_->inner->var.get()),
+        outer(fuse_->outer->var.get()), under_outer(false),
+        extent(0), fused(false) {}
+
+    // TODO(@were): Handle imperfect loops
+    Stmt VisitStmt_(const ForNode* op) final {
+      if (op->loop_var.get() == inner) {
+        CHECK(under_outer);
+        std::unordered_map<const VarNode *, PrimExpr> rmap;
+        rmap[op->loop_var.get()] = indexmod(parent, op->extent);
+        extent = op->extent;
+        fused = true;
+        return ir::Substitute(op->body, rmap);
+      } else if (op->loop_var.get() == outer) {
+        under_outer = true;
+        Stmt body = this->VisitStmt(op->body);
+        std::unordered_map<const VarNode *, PrimExpr> rmap;
+        rmap[op->loop_var.get()] = indexdiv(parent, extent);
+        body = ir::Substitute(body, rmap);
+        under_outer = false;
+        return ForNode::make(parent->var, PrimExpr(0), extent * op->extent,
+                         op->for_type, op->device_api, body);
+      } else if (under_outer) {
+        Stmt body = this->VisitStmt(op->body);
+        std::unordered_map<const VarNode *, PrimExpr> rmap;
+        rmap[op->loop_var.get()] = indexmod(indexdiv(parent, extent), op->extent);
+        body = ir::Substitute(body, rmap);
+        extent = extent * op->extent;
+        return body;
+      }
+      return StmtExprMutator::VisitStmt_(op);
+    }
+  };
+
+  for (auto &rel : stage->relations) {
+    if (const SplitNode *split = rel.as<SplitNode>()) {
+      LoopSpliter Spliter(split, dom_map);
+      stmt = Spliter(stmt);
+      CHECK(Spliter.splitted);
+    } else if (const FuseNode *fuse = rel.as<FuseNode>()) {
+      LoopFuser Fuser(fuse);
+      stmt = Fuser(stmt);
+      CHECK(Fuser.fused);
+    }
+  }
+
+  return stmt;
+}
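+// Worked example (hypothetical loop): splitting
+//   for (i, 0, 10) { body(i) }
+// by factor 4 yields, via LoopSpliter,
+//   for (i.outer, 0, 3)
+//     for (i.inner, 0, 4)
+//       if (likely(i.outer*4 < 10 - i.inner)) body(i.inner + i.outer*4)
+// while LoopFuser performs the inverse rewrite using indexdiv/indexmod.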
+
+Stmt ApplyLoopAnnotations(const Stage &stage,
+                          const std::unordered_map<IterVar, IterVar> &rebased, Stmt stmt) {
+  class LoopAnnotator : public StmtMutator {
+    const VarNode *var;
+    const IterVarAttr &attr;
+
+   public:
+    LoopAnnotator(const VarNode *var_, const IterVarAttr &attr_) : var(var_), attr(attr_) {}
+
+    Stmt VisitStmt_(const ForNode *op) final {
+      if (op->loop_var.get() == var) {
+        if (attr->bind_thread.defined()) {
+          const auto &iter_var = attr->bind_thread;
+          if (iter_var->dom.defined()) {
+            CHECK(is_const_int(iter_var->dom->min, 0));
+            CHECK(Equal(iter_var->dom->extent, op->extent))
+              << "Thread extent and loop extent mismatch!\n";
+          }
+          std::unordered_map<const VarNode *, PrimExpr> rmap;
+          rmap[op->loop_var.get()] = iter_var;
+          Stmt body = ir::Substitute(op->body, rmap);
+          return AttrStmtNode::make(iter_var, "thread_extent", op->extent, body);
+        } else {
+          return ForNode::make(op->loop_var, op->min, op->extent,
+                           IterVarTypeToForType(attr->iter_type), op->device_api, op->body);
+        }
+      }
+      return StmtMutator::VisitStmt_(op);
+    }
+  };
+
+  for (auto &iter_var : stage->leaf_iter_vars) {
+    bool need_change = false;
+    int found = 0;
+
+    const IterVar &actual = rebased.count(iter_var) ? rebased.find(iter_var)->second : iter_var;
+    const VarNode *var = actual->var.get();
+    ForType expected = IterVarTypeToForType(iter_var->iter_type);
+    IterVarAttr attr;
+    if (stage->iter_var_attrs.count(iter_var)) {
+      attr = stage->iter_var_attrs[iter_var];
+      expected = IterVarTypeToForType(attr->iter_type);
+    }
+
+    PostOrderVisit(stmt,
+    [&found, &var, &attr, &expected, &need_change](const ObjectRef& node) {
+      if (const ForNode *op = node.as<ForNode>()) {
+        if (op->loop_var.get() == var) {
+          ++found;
+          need_change = expected != op->for_type || (attr.defined() && attr->bind_thread.defined());
+        }
+      }
+    });
+
+    CHECK_EQ(found, 1) << " iter var should be found exactly once!";
+    if (need_change) {
+      stmt = LoopAnnotator(var, attr)(std::move(stmt));
+    }
+  }
+  return stmt;
+}
+
+Stmt ApplyLoopOrder(const Stage &stage,
+                    const std::unordered_map<IterVar, Range> &dom_map,
+                    const std::unordered_map<IterVar, IterVar> &rebased, Stmt stmt) {
+  std::vector<const VarNode*> current_order;
+  PostOrderVisit(stmt, [&current_order](const ObjectRef& node) {
+    if (const ForNode *op = node.as<ForNode>())
+      current_order.push_back(op->loop_var.get());
+  });
+  std::reverse(current_order.begin(), current_order.end());
+  auto &required_ord = stage->leaf_iter_vars;
+  CHECK_EQ(current_order.size(), required_ord.size()) << "Cannot reorder the loops!";
+  std::unordered_map<const VarNode *, IterVar> reorder;
+  bool need_reorder = false;
+  for (size_t i = 0; i < current_order.size(); ++i) {
+    auto &current = current_order[i];
+    const IterVar &iter_var = required_ord[i];
+    const IterVar &required = rebased.count(iter_var) ? rebased.find(iter_var)->second : iter_var;
+    CHECK(required->dom.defined() || dom_map.count(required)) << required << "\n";
+    reorder[current] = required;
+    if (current != required->var.get()) {
+      need_reorder = true;
+    }
+  }
+
+  class LoopReorder : public StmtMutator {
+    const Stage &stage;
+    const std::unordered_map<IterVar, Range> &dom_map;
+    const std::unordered_map<const VarNode *, IterVar> &reorder;
+
+   public:
+    LoopReorder(const Stage &stage,
+                const std::unordered_map<IterVar, Range> &dom_map,
+                const std::unordered_map<const VarNode*, IterVar> &reorder)
+      : stage(stage), dom_map(dom_map), reorder(reorder) {}
+
+    Stmt VisitStmt_(const ForNode* op) final {
+      // Reorder from inner to outer
+      Stmt body_ = this->VisitStmt(op->body);
+      CHECK(reorder.count(op->loop_var.get()));
+      auto target = reorder.find(op->loop_var.get())->second;
+      if (body_.same_as(op->body) && op->loop_var.get() == target->var.get())
+        return GetRef<Stmt>(op);
+      const Stmt &body = op->body.same_as(body_) ? op->body : body_;
+      ForType for_type = IterVarTypeToForType(target->iter_type);
+      if (stage->iter_var_attrs.count(target)) {
+        for_type = IterVarTypeToForType(stage->iter_var_attrs[target]->iter_type);
+      }
+      const Range &range = target->dom.defined() ? target->dom : dom_map.find(target)->second;
+      return ForNode::make(target->var, range->min, range->extent,
+                       for_type, DeviceAPI::None, body);
+    }
+  };
+
+  if (need_reorder)
+    return LoopReorder(stage, dom_map, reorder)(stmt);
+
+  return stmt;
+}
+
+Stmt ApplySchedule(const Stage &stage,
+                   const std::unordered_map<IterVar, Range> &dom_map, Stmt stmt) {
+  // TODO(@were): Eliminate loop rebase in script parser and move the burden here
+  // Gather rebased variables
+  std::unordered_map<IterVar, IterVar> rebased;
+  for (auto rel : stage->relations) {
+    if (const auto* rebase = rel.as<RebaseNode>()) {
+      rebased[rebase->rebased] = rebase->parent;
+      CHECK(rebase->parent->dom.defined());
+      CHECK(dom_map.count(rebase->rebased));
+    }
+  }
+  stmt = ApplyLoopShapes(stage, dom_map, stmt);
+  stmt = ApplyLoopOrder(stage, dom_map, rebased, stmt);
+  stmt = ApplyLoopAnnotations(stage, rebased, stmt);
+  return stmt;
+}
+
+std::vector<IterVar> GatherLoopVars(Stmt stmt) {
+  // TODO(@were): Write a comprehensive pass to analyze iter var types
+  std::vector<IterVar> res_;
+  PostOrderVisit(stmt, [&res_](const ObjectRef& node) {
+    if (const ForNode *op = node.as<ForNode>()) {
+      Var loop_var(op->loop_var);
+      Range dom = Range::make_by_min_extent(op->min, op->extent);
+      res_.push_back(IterVarNode::make(dom, loop_var, ForTypeToIterVarType(op->for_type)));
+    }
+  });
+  std::reverse(res_.begin(), res_.end());
+  return res_;
+}
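+// Illustrative note: PostOrderVisit reaches the innermost loop first, so
+// after the reverse a nest of loops is returned outermost first, matching
+// the textual order of the loops in the script.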
+
+// replacer to replace tensors' usage in Provide
+class ProviderReplacer : public ir::StmtMutator {
+ public:
+  explicit ProviderReplacer(const std::unordered_map<Tensor, Tensor> &vmap)
+      : vmap_(vmap) {}
+
+  Stmt VisitStmt_(const ir::ProvideNode* op) final {
+    Tensor t = Downcast<Operation>(op->func).output(op->value_index);
+    auto it = vmap_.find(t);
+    if (it != vmap_.end()) {
+      Stmt ret = ir::ProvideNode::make(
+        it->second->op, it->second->value_index, op->value, op->args);
+      found = true;
+      return this->VisitStmt(ret);
+    }
+    return StmtMutator::VisitStmt_(op);
+  }
+
+  // whether it is found.
+  bool found{false};
+
+ private:
+  const std::unordered_map<Tensor, Tensor> &vmap_;
+};
+
+Stmt ReplaceProvideTensor(Stmt stmt,
+                   const std::unordered_map<Tensor, Tensor> &replace) {
+  ProviderReplacer repl(replace);
+  Stmt ret = repl(stmt);
+  return repl.found ? ret : stmt;
+}
+}  // namespace top
+}  // namespace tvm
diff --git a/src/top/operation/hybrid_op.h b/src/top/operation/hybrid_op.h
new file mode 100644 (file)
index 0000000..c4586cb
--- /dev/null
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \brief Helper utilities to implement hybrid_op.
+ * \file hybrid_op.h
+ */
+#ifndef TVM_TOP_OPERATION_HYBRID_OP_H_
+#define TVM_TOP_OPERATION_HYBRID_OP_H_
+
+#include <tvm/expr.h>
+#include <tvm/top/schedule.h>
+
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+#include "../schedule/message_passing.h"
+#include "../../pass/ir_util.h"
+#include "../../pass/arg_binder.h"
+
+namespace tvm {
+namespace top {
+
+/*!
+ * \brief Find all the iteration variables in the given statement body.
+ * \param stmt The body to be inspected.
+ */
+std::vector<IterVar> GatherLoopVars(Stmt stmt);
+
+/*!
+ * \brief Replace the tensor reference (especially in Provide's) in stmt by the replace map.
+ * \param stmt The statement to be processed.
+ * \param replace The replacement rule.
+ */
+Stmt ReplaceProvideTensor(Stmt stmt,
+                          const std::unordered_map<Tensor, Tensor>& replace);
+
+/*!
+ * \brief Apply the schedule manipulation on the function body.
+ * \param stmt The statement to be processed.
+ * \param dom_map The extents of the iteration variables that may be used.
+ * \param stage The schedule information to be applied.
+ */
+Stmt ApplySchedule(const Stage& stage,
+                   const std::unordered_map<IterVar, Range>& dom_map, Stmt stmt);
+
+/*!
+ * \brief Apply loop splits and fuses in the schedule on the function body.
+ * \param stage The schedule information to be applied.
+ * \param dom_map The extents of the iteration variables that may be used.
+ * \param stmt The statement to be processed.
+ */
+Stmt ApplyLoopShapes(const Stage &stage,
+                     const std::unordered_map<IterVar, Range>& dom_map, Stmt stmt);
+
+
+/*!
+ * \brief Apply loop annotation in the schedule on the function body.
+ * \param stage The schedule information to be applied.
+ * \param rebased The map specifying the rebase, a.k.a. rename, relationship of these variables.
+ * \param stmt The statement to be processed.
+ */
+Stmt ApplyLoopAnnotations(const Stage &stage,
+                          const std::unordered_map<IterVar, IterVar>& rebased, Stmt stmt);
+
+/*!
+ * \brief Apply loop order in the schedule on the function body.
+ * \param stage The schedule information to be applied.
+ * \param dom_map The extents of the iteration variables that may be used.
+ * \param rebased The map specifying the rebase, a.k.a. rename, relationship of these variables.
+ * \param stmt The statement to be processed.
+ */
+Stmt ApplyLoopOrder(const Stage &stage,
+                    const std::unordered_map<IterVar, Range> &dom_map,
+                    const std::unordered_map<IterVar, IterVar> &rebased, Stmt stmt);
+
+}  // namespace top
+}  // namespace tvm
+
+#endif  // TVM_TOP_OPERATION_HYBRID_OP_H_
diff --git a/src/top/operation/op_util.cc b/src/top/operation/op_util.cc
new file mode 100644 (file)
index 0000000..fcf8318
--- /dev/null
@@ -0,0 +1,270 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \brief Utility to make loop nest.
+ * \file op_util.cc
+ */
+#include <tvm/ir.h>
+#include <tvm/ir_pass.h>
+#include <tvm/ir_functor_ext.h>
+#include <tvm/top/operation.h>
+#include <string>
+#include "op_util.h"
+#include "../schedule/message_passing.h"
+#include "../../arith/compute_expr.h"
+
+namespace tvm {
+namespace top {
+
+using namespace arith;
+using namespace ir;
+
+std::vector<std::vector<Stmt> >
+MakeLoopNest(const Stage& stage,
+             const std::unordered_map<IterVar, Range>& dom_map,
+             size_t begin_iter_pos,
+             bool new_loop_var,
+             const std::unordered_set<IterVar>& skip_iter,
+             std::unordered_map<IterVar, PrimExpr>* p_value_map,
+             bool debug_keep_trivial_loop) {
+  auto leaf_iter_vars = stage->leaf_iter_vars;
+  Stmt no_op = EvaluateNode::make(0);
+  // create the loop nest
+  std::vector<std::vector<Stmt> > nest;
+  nest.resize(leaf_iter_vars.size() + 1);
+  std::unordered_map<IterVar, PrimExpr>& value_map = *p_value_map;
+
+  for (size_t i = begin_iter_pos; i < leaf_iter_vars.size(); ++i) {
+    auto iv = leaf_iter_vars[i];
+    if (skip_iter.count(iv) || iv->iter_type == kOpaque) {
+      // skip this iteration.
+      value_map[iv] = iv->var;
+      continue;
+    }
+    // Bind iv could be another thread.
+    IterVar bind_iv = iv;
+    if (stage->iter_var_attrs.count(iv)) {
+      IterVar bind_thread = stage->iter_var_attrs[iv]->bind_thread;
+      if (bind_thread.defined()) bind_iv = bind_thread;
+    }
+
+    Range dom = dom_map.at(iv);
+
+    // initialize the offset and loop_level
+    Var var = bind_iv->var;
+
+    // Mark the iter var in the IR, to remember the point
+    if (bind_iv->thread_tag.length() == 0) {
+      // Only generate new loop if we're not bound to a thread.
+      if (new_loop_var) {
+        var = Var(iv->var->name_hint + ".init", bind_iv->var.dtype());
+      }
+
+      ForType for_type = ForType::Serial;
+      IterVarAttr it_attr;
+      if (stage->iter_var_attrs.count(iv)) {
+        it_attr = stage->iter_var_attrs[iv];
+      }
+      if (it_attr.defined()) {
+        switch (it_attr->iter_type) {
+          case kUnrolled: for_type = ForType::Unrolled; break;
+          case kVectorized: for_type = ForType::Vectorized; break;
+          case kParallelized: for_type = ForType::Parallel; break;
+          case kDataPar: break;
+          case kTensorized: break;
+          default: LOG(FATAL) << "Unknown iter type"
+                              << it_attr->iter_type
+                              << " in the iter_var_attrs";
+        }
+        CHECK_EQ(it_attr->pragma_keys.size(), it_attr->pragma_values.size());
+        for (size_t k = 0; k < it_attr->pragma_keys.size(); ++k) {
+          const std::string& pkey = it_attr->pragma_keys[k].as<StringImmNode>()->value;
+          PrimExpr pvalue = it_attr->pragma_values[k];
+          if (!pvalue.defined()) {
+            pvalue = make_const(DataType::Int(32), 1);
+          }
+          nest[i + 1].emplace_back(
+              AttrStmtNode::make(iv, ir::attr::pragma_scope_prefix + pkey, pvalue, no_op));
+        }
+      }
+      if (!debug_keep_trivial_loop && is_one(dom->extent)) {
+        nest[i + 1].emplace_back(
+            LetStmtNode::make(var, dom->min, no_op));
+        value_map[iv] = dom->min;
+      } else if (is_zero(dom->min)) {
+        nest[i + 1].emplace_back(
+            ForNode::make(var, 0, dom->extent,
+                      for_type, DeviceAPI::None, no_op));
+        value_map[iv] = var;
+      } else {
+        Var idx(bind_iv->var->name_hint + ".idx", bind_iv->var.dtype());
+        nest[i + 1].emplace_back(
+            ForNode::make(idx, 0, dom->extent,
+                      for_type, DeviceAPI::None, no_op));
+        PrimExpr new_value = dom->min + idx;
+        value_map[iv] = new_value;
+        nest[i + 1].emplace_back(
+            LetStmtNode::make(var, new_value, no_op));
+      }
+      if (it_attr.defined() && it_attr->prefetch_data.size() != 0) {
+        CHECK(!is_one(dom->extent))
+            << "Cannot prefetch on trivial loop with extent=1";
+        CHECK_EQ(it_attr->prefetch_data.size(),
+                 it_attr->prefetch_offset.size());
+        for (size_t j = 0; j < it_attr->prefetch_data.size(); ++j) {
+          nest[i + 1].emplace_back(
+              AttrStmtNode::make(it_attr->prefetch_data[j],
+                             ir::attr::prefetch_scope,
+                             it_attr->prefetch_offset[j], no_op));
+        }
+      }
+    } else if (bind_iv->thread_tag == "vthread" ||
+               bind_iv->thread_tag == "cthread") {
+      // virtual thread
+      // Always restrict threaded IterVar to start from 0.
+      CHECK(is_zero(dom->min));
+      CHECK(is_positive_const(dom->extent));
+      // annotate the extent of the IterVar
+      nest[i + 1].emplace_back(
+          AttrStmtNode::make(bind_iv, ir::attr::virtual_thread, dom->extent, no_op));
+      value_map[iv] = var;
+    } else if (bind_iv->thread_tag == "pipeline") {
+      // pipeline marker.
+      CHECK(is_zero(dom->min));
+      CHECK(is_one(dom->extent));
+      // annotate the extent of the IterVar
+      nest[i + 1].emplace_back(
+          AttrStmtNode::make(bind_iv, ir::attr::pipeline_exec_scope, dom->extent, no_op));
+      value_map[iv] = dom->min;
+    } else {
+      // Always restrict threaded IterVar to start from 0.
+      CHECK(is_zero(dom->min));
+      // annotate the extent of the IterVar
+      nest[i + 1].emplace_back(
+          AttrStmtNode::make(bind_iv, ir::attr::thread_extent, dom->extent, no_op));
+      if (!debug_keep_trivial_loop && is_one(dom->extent)) {
+        value_map[iv] = dom->min;
+      } else {
+        value_map[iv] = var;
+      }
+    }
+    // annotate the extent of the IterVar
+    if (!new_loop_var) {
+      nest[i + 1].emplace_back(
+          AttrStmtNode::make(iv, attr::loop_scope, iv->var, no_op));
+    }
+  }
+  // message passing to get offset of root iter vars.
+  top::PassUpIndex(stage, dom_map, &value_map);
+  return nest;
+}
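+// Illustrative note: nest[i + 1] holds the statements attached to leaf
+// iter var i. For example (names hypothetical), a serial axis with
+// dom = [5, 5 + n) contributes
+//   for (i.idx, 0, n) { let i = 5 + i.idx; ... }
+// a zero-based axis contributes a plain For, and a trivial axis with
+// extent 1 collapses to a LetStmt binding the variable to dom->min.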
+
+std::vector<Stmt> MakeIfNest(const std::vector<PrimExpr>& predicates) {
+  Stmt no_op = EvaluateNode::make(0);
+  std::vector<Stmt> nest;
+  for (const PrimExpr& cond : predicates) {
+    nest.emplace_back(IfThenElseNode::make(cond, no_op));
+  }
+  return nest;
+}
+
+// replacer to replace tensors
+class TensorReplacer : public ir::StmtExprMutator {
+ public:
+  explicit TensorReplacer(const std::unordered_map<Tensor, Tensor>& vmap)
+      : vmap_(vmap) {}
+
+  PrimExpr VisitExpr_(const ir::CallNode* op) final {
+    if (op->call_type == ir::CallNode::Halide) {
+      Tensor t = Downcast<Operation>(op->func).output(op->value_index);
+      auto it = vmap_.find(t);
+      if (it != vmap_.end()) {
+        PrimExpr ret = ir::CallNode::make(
+            op->dtype, it->second->op->name, op->args,
+            op->call_type, it->second->op, it->second->value_index);
+        found = true;
+        return this->VisitExpr(ret);
+      }
+    }
+    return StmtExprMutator::VisitExpr_(op);
+  }
+
+  // whether it is found.
+  bool found{false};
+
+ private:
+  const std::unordered_map<Tensor, Tensor>& vmap_;
+};
+
+Stmt ReplaceTensor(Stmt stmt,
+                   const std::unordered_map<Tensor, Tensor>& replace) {
+  TensorReplacer repl(replace);
+  Stmt ret = repl(stmt);
+  return repl.found ? ret : stmt;
+}
+PrimExpr ReplaceTensor(PrimExpr expr,
+                       const std::unordered_map<Tensor, Tensor>& replace) {
+  TensorReplacer repl(replace);
+  PrimExpr ret = repl(expr);
+  return repl.found ? ret : expr;
+}
+
+
+Stmt Substitute(Stmt s,
+                const std::unordered_map<IterVar, PrimExpr>& value_map) {
+  std::unordered_map<const VarNode*, PrimExpr> init;
+  for (const auto& kv : value_map) {
+    init[kv.first->var.get()] = kv.second;
+  }
+  return ir::Substitute(s, init);
+}
+
+IterVarType ForTypeToIterVarType(ir::ForType for_type) {
+  switch (for_type) {
+  case ForType::Serial:
+    return kDataPar;
+  case ForType::Parallel:
+    return kParallelized;
+  case ForType::Vectorized:
+    return kVectorized;
+  case ForType::Unrolled:
+    return kUnrolled;
+  default:
+    return kDataPar;
+  }
+}
+
+ir::ForType IterVarTypeToForType(IterVarType iter_type) {
+  switch (iter_type) {
+  case kDataPar:
+    return ForType::Serial;
+  case kParallelized:
+    return ForType::Parallel;
+  case kVectorized:
+    return ForType::Vectorized;
+  case kUnrolled:
+    return ForType::Unrolled;
+  default:
+    return ForType::Serial;
+  }
+}
+
+}  // namespace top
+}  // namespace tvm
diff --git a/src/top/operation/op_util.h b/src/top/operation/op_util.h
new file mode 100644 (file)
index 0000000..babdabc
--- /dev/null
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file op_util.h
+ * \brief Common utility used in operator construction.
+ */
+#ifndef TVM_TOP_OPERATION_OP_UTIL_H_
+#define TVM_TOP_OPERATION_OP_UTIL_H_
+
+#include <tvm/expr.h>
+#include <tvm/top/schedule.h>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+#include "../../pass/ir_util.h"
+#include "../../pass/arg_binder.h"
+#include "../schedule/message_passing.h"
+
+namespace tvm {
+namespace top {
+
+using ir::MergeNest;
+
+/*!
+ * \brief Build loop nest for stage.
+ *
+ * \param stage The stage to create a loop nest.
+ * \param dom_map The range of each iter var.
+ * \param begin_iter_pos The beginning position of leaf_iter_vars to generate loop.
+ * \param new_loop_var Whether create new loop variable.
+ * \param skip_iter Whether skip certain iteration.
+ * \param p_value_map The result value of each IterVar.
+ * \param debug_keep_trivial_loop Whether to keep trivial loops with extent of 1
+ */
+std::vector<std::vector<Stmt> >
+MakeLoopNest(const Stage& stage,
+             const std::unordered_map<IterVar, Range>& dom_map,
+             size_t begin_iter_pos,
+             bool new_loop_var,
+             const std::unordered_set<IterVar>& skip_iter,
+             std::unordered_map<IterVar, PrimExpr>* p_value_map,
+             bool debug_keep_trivial_loop);
+
+/*!
+ * \brief Create a nest of if checking the predicates.
+ *
+ * \param predicates The predicates to be checked.
+ * \return List of If nest that checks the predicates.
+ */
+std::vector<Stmt> MakeIfNest(const std::vector<PrimExpr>& predicates);
+
+/*!
+ * \brief Replace the tensor reference (especially in Call's) in stmt by the replace map.
+ * \param stmt The statement to be processed.
+ * \param replace The replacement rule.
+ */
+Stmt ReplaceTensor(Stmt stmt,
+                   const std::unordered_map<Tensor, Tensor>& replace);
+/*!
+ * \brief Replace the tensor reference (especially in Call's) in stmt by the replace map.
+ * \param expr The expression to be processed.
+ * \param replace The replacement rule.
+ */
+PrimExpr ReplaceTensor(PrimExpr expr,
+                   const std::unordered_map<Tensor, Tensor>& replace);
+
+/*!
+ * \brief Substitute the variables of stmt by value map.
+ * \param stmt The statement.
+ * \param value_map The value map.
+ * \return Substituted result.
+ */
+Stmt Substitute(Stmt stmt,
+                const std::unordered_map<IterVar, PrimExpr>& value_map);
+
+/*!
+ * \brief Converts Halide ForType to its corresponding IterVarType
+ * \param for_type The ForType to be converted
+ */
+IterVarType ForTypeToIterVarType(ir::ForType for_type);
+
+/*!
+ * \brief Converts IterVarType to its corresponding Halide ForType
+ * \param iter_type The IterVarType to be converted
+ */
+ir::ForType IterVarTypeToForType(IterVarType iter_type);
+
+}  // namespace top
+}  // namespace tvm
+#endif  // TVM_TOP_OPERATION_OP_UTIL_H_
diff --git a/src/top/operation/placeholder_op.cc b/src/top/operation/placeholder_op.cc
new file mode 100644 (file)
index 0000000..284752b
--- /dev/null
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \brief Placeholder op.
+ * \file placeholder_op.cc
+ */
+#include <tvm/top/operation.h>
+
+namespace tvm {
+namespace top {
+
+// PlaceholderOpNode
+TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
+.set_dispatch<PlaceholderOpNode>([](const ObjectRef& node, NodePrinter* p) {
+    auto* op = static_cast<const PlaceholderOpNode*>(node.get());
+    p->stream << "placeholder(" << op->name << ", " << op << ")";
+});
+
+TVM_REGISTER_NODE_TYPE(PlaceholderOpNode);
+
+int PlaceholderOpNode::num_outputs() const {
+  return 1;
+}
+
+Array<IterVar> PlaceholderOpNode::root_iter_vars() const {
+  return {};
+}
+
+DataType PlaceholderOpNode::output_dtype(size_t i) const {
+  CHECK_EQ(i, 0U);
+  return dtype;
+}
+
+Array<PrimExpr> PlaceholderOpNode::output_shape(size_t i) const {
+  CHECK_EQ(i, 0U);
+  return shape;
+}
+
+Operation PlaceholderOpNode::make(std::string name,
+                                  Array<PrimExpr> shape,
+                                  DataType dtype) {
+  auto n = make_object<PlaceholderOpNode>();
+  n->name = name;
+  n->shape = shape;
+  n->dtype = dtype;
+  return Operation(n);
+}
+
+Tensor placeholder(Array<PrimExpr> shape, DataType dtype, std::string name) {
+  return PlaceholderOpNode::make(name, shape, dtype).output(0);
+}
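+
+// Usage sketch (illustrative): declare a float32 input of shape (m, n),
+// where 'm' and 'n' are symbolic Vars assumed to be in scope.
+//   Var m("m"), n("n");
+//   Tensor A = placeholder({m, n}, DataType::Float(32), "A");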
+
+Array<Tensor> PlaceholderOpNode::InputTensors() const {
+  return {};
+}
+
+Operation PlaceholderOpNode::ReplaceInputs(
+    const Operation& self,
+    const std::unordered_map<Tensor, Tensor>& rmap) const {
+  return self;
+}
+
+void PlaceholderOpNode::PropBoundToInputs(
+    const Operation& self,
+    arith::Analyzer* analyzer,
+    const std::unordered_map<const VarNode*, IntSet>& dom_map,
+    std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
+}
+
+void PlaceholderOpNode::GatherBound(
+    const Operation& self,
+    const std::unordered_map<Tensor, TensorDom>& tensor_dom,
+    std::unordered_map<IterVar, Range>* out_dom_map) const {
+}
+
+Stmt PlaceholderOpNode::BuildRealize(
+    const Stage& stage,
+    const std::unordered_map<IterVar, Range>& realize_map,
+    const Stmt& body) const {
+  return body;
+}
+
+Stmt PlaceholderOpNode::BuildProvide(
+    const Stage& stage,
+    const std::unordered_map<IterVar, Range>& dom_map,
+    bool debug_keep_trivial_loop) const {
+  return Stmt();
+}
+}  // namespace top
+}  // namespace tvm
diff --git a/src/top/operation/scan_op.cc b/src/top/operation/scan_op.cc
new file mode 100644 (file)
index 0000000..8f54872
--- /dev/null
@@ -0,0 +1,308 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \brief Scan Operator.
+ * \file scan_op.cc
+ */
+#include <tvm/top/operation.h>
+#include <tvm/ir.h>
+#include <tvm/ir_pass.h>
+#include "op_util.h"
+#include "../schedule/graph.h"
+
+namespace tvm {
+namespace top {
+using namespace ir;
+
+TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
+.set_dispatch<ScanOpNode>([](const ObjectRef& node, NodePrinter* p) {
+    auto* op = static_cast<const ScanOpNode*>(node.get());
+    p->stream << "scan(" << op->name << ", " << op << ")";
+});
+TVM_REGISTER_NODE_TYPE(ScanOpNode);
+
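+// Returns true when lhs - rhs simplifies to zero, i.e. the two
+// expressions can be proven equal.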
+inline bool prove_equal(PrimExpr lhs, PrimExpr rhs) {
+  return is_zero(ir::Simplify(lhs - rhs));
+}
+
+int ScanOpNode::num_outputs() const {
+  return static_cast<int>(update.size());
+}
+Array<IterVar> ScanOpNode::root_iter_vars() const {
+  Array<IterVar> ret{scan_axis};
+  for (IterVar iv : spatial_axis_) {
+    ret.push_back(iv);
+  }
+  return ret;
+}
+
+DataType ScanOpNode::output_dtype(size_t i) const {
+  return update[i]->dtype;
+}
+
+Array<PrimExpr> ScanOpNode::output_shape(size_t i) const {
+  CHECK_LT(i, state_placeholder.size());
+  return state_placeholder[i]->shape;
+}
+
+Operation ScanOpNode::make(std::string name,
+                           std::string tag,
+                           Map<std::string, ObjectRef> attrs,
+                           IterVar axis,
+                           Array<Tensor> init,
+                           Array<Tensor> update,
+                           Array<Tensor> state_placeholder,
+                           Array<Tensor> inputs) {
+  if (!attrs.defined()) {
+    attrs = Map<std::string, ObjectRef>();
+  }
+  auto n = make_object<ScanOpNode>();
+  CHECK_EQ(init.size(), update.size());
+  CHECK_EQ(init.size(), state_placeholder.size());
+
+  for (size_t i = 0; i < init.size(); ++i) {
+    CHECK_EQ(init[i]->dtype, state_placeholder[i]->dtype);
+    CHECK_EQ(init[i]->dtype, update[i]->dtype);
+    CHECK(prove_equal(init[i]->shape[0], axis->dom->min))
+        << "init.shape[0] needs to match scan_axis.dom.min";
+    CHECK(prove_equal(
+        state_placeholder[i]->shape[0], axis->dom->min + axis->dom->extent))
+        << "state_placeholder.shape[0] needs to match"
+        << " scan_axis.dom.min + scan_axis.dom.extent";
+    CHECK_EQ(state_placeholder[i].ndim(), init[i].ndim())
+        << "The dimension of init needs to match state_placeholder";
+    CHECK_EQ(update[i].ndim(), state_placeholder[i].ndim())
+        << "update.ndim needs to match state_placeholder.ndim";
+    for (size_t k = 0; k < update[i].ndim(); ++k) {
+      CHECK(prove_equal(
+          update[i]->shape[k], state_placeholder[i]->shape[k]));
+      if (k != 0) {
+        // setup spatial axis
+        std::ostringstream spatial_name;
+        spatial_name << name << ".out" << i << ".i" << k;
+        n->spatial_axis_.push_back(
+            IterVarNode::make(
+                Range::make_by_min_extent(0, update[i]->shape[k]),
+                Var(spatial_name.str()), kOpaque));
+      }
+    }
+
+    for (size_t k = 1; k < init[i].ndim(); ++k) {
+      CHECK(prove_equal(
+          init[i]->shape[k], state_placeholder[i]->shape[k]));
+    }
+  }
+  n->name = std::move(name);
+  n->tag = std::move(tag);
+  n->attrs = std::move(attrs);
+  n->scan_axis = std::move(axis);
+  n->init = std::move(init);
+  n->update = std::move(update);
+  n->state_placeholder = std::move(state_placeholder);
+  n->inputs = std::move(inputs);
+  return Operation(n);
+}
+
+Array<Tensor> scan(Array<Tensor> init,
+                   Array<Tensor> update,
+                   Array<Tensor> state_placeholder,
+                   Array<Tensor> inputs,
+                   std::string name,
+                   std::string tag,
+                   Map<std::string, ObjectRef> attrs) {
+  IterVar scan_axis =
+      IterVarNode::make(
+          Range::make_by_min_extent(
+              init[0]->shape[0], update[0]->shape[0] - init[0]->shape[0]),
+          Var(name + ".idx"), kOrdered);
+  Operation op = ScanOpNode::make(
+      name, tag, attrs, scan_axis,
+      init, update, state_placeholder, inputs);
+  Array<Tensor> res;
+  for (int i = 0; i < op->num_outputs(); ++i) {
+    res.push_back(op.output(i));
+  }
+  return res;
+}
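+
+// Usage sketch (illustrative): a cumulative sum along axis 0 of an (m, n)
+// tensor X; the compute bodies are elided and all names are assumed to be
+// declared by the caller.
+//   Tensor state = placeholder({m, n});  // carried state s[t, i]
+//   Tensor init = ...;    // (1, n) compute: s[0, i] = X[0, i]
+//   Tensor update = ...;  // (m, n) compute: s[t, i] = s[t-1, i] + X[t, i]
+//   Array<Tensor> res = scan({init}, {update}, {state}, {X});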
+
+Array<Tensor> ScanOpNode::InputTensors() const {
+  Array<Tensor> ret;
+  for (Tensor t : init) {
+    ret.push_back(t);
+  }
+  for (Tensor t : update) {
+    ret.push_back(t);
+  }
+  return ret;
+}
+
+Operation ScanOpNode::ReplaceInputs(
+    const Operation& self,
+    const std::unordered_map<Tensor, Tensor>& rmap) const {
+  CHECK_EQ(self.operator->(), this);
+  auto n = make_object<ScanOpNode>(*this);
+  for (size_t i = 0; i < n->init.size(); ++i) {
+    if (rmap.count(n->init[i])) {
+      n->init.Set(i, rmap.at(n->init[i]));
+    }
+    if (rmap.count(n->update[i])) {
+      n->update.Set(i, rmap.at(n->update[i]));
+    }
+  }
+  if (!n->init.same_as(init) ||
+      !n->update.same_as(update)) {
+    return Operation(n);
+  } else {
+    return self;
+  }
+}
+
+void ScanOpNode::PropBoundToInputs(
+    const Operation& self,
+    arith::Analyzer* analyzer,
+    const std::unordered_map<const VarNode*, IntSet>& dom_map,
+    std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
+  CHECK_EQ(self.operator->(), this);
+  for (size_t i = 0, sp_idx = 0; i < this->init.size(); ++i) {
+    TensorDom* init_dom = nullptr;
+    TensorDom* update_dom = nullptr;
+    if (out_dom_map->count(this->init[i])) {
+      init_dom = &out_dom_map->at(this->init[i]);
+    }
+    if (out_dom_map->count(this->update[i])) {
+      update_dom = &out_dom_map->at(this->update[i]);
+    }
+    // first dimension, always needed.
+    if (init_dom) {
+      init_dom->data[0].push_back(IntSet::range(
+          Range::make_by_min_extent(0, this->init[i]->shape[0])));
+    }
+    if (update_dom) {
+      update_dom->data[0].push_back(dom_map.at(this->scan_axis->var.get()));
+    }
+    // The update dimensions
+    for (size_t k = 1; k < this->update[i]->shape.size(); ++k, ++sp_idx) {
+      IterVar sp_ax = this->spatial_axis_[sp_idx];
+      if (init_dom) {
+        init_dom->data[k].push_back(dom_map.at(sp_ax->var.get()));
+      }
+      if (update_dom) {
+        update_dom->data[k].push_back(dom_map.at(sp_ax->var.get()));
+      }
+    }
+  }
+}
+
+void ScanOpNode::GatherBound(
+    const Operation& self,
+    const std::unordered_map<Tensor, TensorDom>& tensor_dom,
+    std::unordered_map<IterVar, Range>* out_dom_map) const {
+  CHECK_EQ(self.operator->(), this);
+  CHECK(!out_dom_map->count(this->scan_axis));
+  std::vector<Tensor> output(this->num_outputs());
+  for (size_t i = 0; i < output.size(); ++i) {
+    output[i] = self.output(i);
+  }
+  // Update for time axis.
+  std::vector<IntSet> time_dom;
+  for (size_t i = 0; i < output.size(); ++i) {
+    const TensorDom& d = tensor_dom.at(output[i]);
+    time_dom.insert(time_dom.end(), d.data[0].begin(), d.data[0].end());
+  }
+  CHECK(!out_dom_map->count(this->scan_axis));
+  Range sdom = this->scan_axis->dom;
+  Range r = arith::Union(time_dom).cover_range(sdom);
+  (*out_dom_map)[this->scan_axis] = Range::make_by_min_extent(
+      sdom->min, ir::Simplify(r->extent + r->min - sdom->min));
+  Map<IterVar, PrimExpr> fix_pt = ScanFixPointAnalysis(self);
+  // Update for spatial axis.
+  size_t sp_idx = 0;
+  for (size_t i = 0; i < output.size(); ++i) {
+    const TensorDom& d = tensor_dom.at(output[i]);
+    for (size_t k = 1; k < this->update[i]->shape.size(); ++k, ++sp_idx) {
+      IterVar sp_ax = this->spatial_axis_[sp_idx];
+      CHECK(!out_dom_map->count(sp_ax));
+      CHECK(fix_pt.count(sp_ax));
+      if (fix_pt[sp_ax].as<ir::IntImmNode>()->value) {
+        // fix point, we can slice it.
+        (*out_dom_map)[sp_ax] = arith::Union(d.data[k]).cover_range(sp_ax->dom);
+      } else {
+        // not a fix point, need to include everything.
+        (*out_dom_map)[sp_ax] = sp_ax->dom;
+      }
+    }
+  }
+}
+
+Stmt ScanOpNode::BuildRealize(
+    const Stage& stage,
+    const std::unordered_map<IterVar, Range>& dom_map,
+    const Stmt& body) const {
+  CHECK_EQ(stage->op.get(), this);
+  Range sdom = dom_map.at(this->scan_axis);
+  Range tdom = Range::make_by_min_extent(
+      0, ir::Simplify(sdom->extent + sdom->min));
+  Stmt ret = body;
+  size_t sp_idx = 0;
+  for (size_t i = 0; i < update.size(); ++i) {
+    Tensor t = stage->op.output(i);
+    CHECK_EQ(static_cast<size_t>(t->value_index), i);
+    Region bounds;
+    bounds.push_back(tdom);
+    for (size_t k = 1; k < this->update[i]->shape.size(); ++k, ++sp_idx) {
+      IterVar sp_ax = this->spatial_axis_[sp_idx];
+      bounds.push_back(dom_map.at(sp_ax));
+    }
+    ret = ir::RealizeNode::make(t->op, t->value_index, t->dtype,
+                            bounds, const_true(), ret);
+  }
+  return ret;
+}
+
+Stmt ScanOpNode::BuildProvide(
+    const Stage& stage,
+    const std::unordered_map<IterVar, Range>& dom_map,
+    bool debug_keep_trivial_loop) const {
+  CHECK_EQ(stage->op.operator->(), this);
+  Stmt provide = AttrStmtNode::make(
+      stage->op, attr::scan_update_scope, this->scan_axis->var,
+      EvaluateNode::make(0));
+  Stmt init = AttrStmtNode::make(
+      stage->op, attr::scan_init_scope, 0,
+      EvaluateNode::make(0));
+  size_t begin_scan = 0;
+  for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) {
+    if (stage->leaf_iter_vars[i]->iter_type == kThreadIndex) {
+      CHECK_EQ(begin_scan, i);
+      begin_scan = i + 1;
+    }
+  }
+  std::unordered_map<IterVar, PrimExpr> vmap;
+  std::unordered_set<IterVar> empty;
+  auto nest = MakeLoopNest(
+      stage, dom_map, 0, false, empty, &vmap, debug_keep_trivial_loop);
+  nest[begin_scan].push_back(init);
+  nest.push_back(
+      MakeIfNest(
+          MakeBoundCheck(stage, dom_map, vmap, false, empty)));
+  return MergeNest(nest, provide);
+}
+}  // namespace top
+}  // namespace tvm
diff --git a/src/top/operation/tensor_compute_op.cc b/src/top/operation/tensor_compute_op.cc
new file mode 100644 (file)
index 0000000..49b00fc
--- /dev/null
@@ -0,0 +1,269 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \brief Tensor Compute Op.
+ * \file tensor_compute_op.cc
+ */
+#include <tvm/top/operation.h>
+#include <tvm/arith/analyzer.h>
+#include <tvm/ir.h>
+#include <tvm/ir_pass.h>
+#include <unordered_set>
+#include "./op_util.h"
+#include "./compute_op.h"
+#include "../../arith/compute_expr.h"
+
+namespace tvm {
+namespace top {
+using namespace ir;
+// TensorComputeOpNode
+TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
+.set_dispatch<TensorComputeOpNode>([](const ObjectRef& node, NodePrinter* p) {
+    auto* op = static_cast<const TensorComputeOpNode*>(node.get());
+    p->stream << "tensor_compute_op(" << op->name << ", " << op << ")";
+  });
+
+TVM_REGISTER_NODE_TYPE(TensorComputeOpNode);
+
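+// The intrin's buffer list is laid out as [inputs..., outputs...], so the
+// outputs are the buffers that remain after the op's inputs.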
+int TensorComputeOpNode::num_outputs() const {
+  return static_cast<int>(this->intrin->buffers.size() - this->inputs.size());
+}
+
+DataType TensorComputeOpNode::output_dtype(size_t i) const {
+  return this->intrin->buffers[this->inputs.size() + i]->dtype;
+}
+
+Operation TensorComputeOpNode::make(std::string name,
+                                    std::string tag,
+                                    Array<IterVar> axis,
+                                    Array<IterVar> reduce_axis,
+                                    int schedulable_ndim,
+                                    TensorIntrin intrin,
+                                    Array<Tensor> tensors,
+                                    Array<Region> regions,
+                                    Array<PrimExpr> scalar_inputs) {
+  auto n = make_object<TensorComputeOpNode>();
+  n->name = std::move(name);
+  n->tag = std::move(tag);
+  n->axis = std::move(axis);
+  n->reduce_axis = std::move(reduce_axis);
+  n->schedulable_ndim = std::move(schedulable_ndim);
+  n->intrin = std::move(intrin);
+  n->inputs = std::move(tensors);
+  n->input_regions = std::move(regions);
+  n->scalar_inputs = std::move(scalar_inputs);
+  return Operation(n);
+}
+
+Array<Tensor> TensorComputeOpNode::InputTensors() const {
+  return inputs;
+}
+
+Operation TensorComputeOpNode::ReplaceInputs(
+    const Operation& self,
+    const std::unordered_map<Tensor, Tensor>& rmap) const {
+  CHECK_EQ(self.operator->(), this);
+  auto n = make_object<TensorComputeOpNode>(*this);
+  auto intrin = make_object<TensorIntrinNode>(*(this->intrin.operator->()));
+  intrin->body = ReplaceTensor(this->intrin->body, rmap);
+  if (intrin->reduce_init.defined()) {
+    intrin->reduce_init = ReplaceTensor(this->intrin->reduce_init, rmap);
+  }
+  if (intrin->reduce_update.defined()) {
+    intrin->reduce_update = ReplaceTensor(this->intrin->reduce_update, rmap);
+  }
+  for (size_t i = 0; i < n->inputs.size(); ++i) {
+    Tensor t = n->inputs[i];
+    if (rmap.count(t)) {
+      n->inputs.Set(i, rmap.at(t));
+    }
+  }
+
+  if (intrin->body.same_as(n->intrin->body) &&
+      intrin->reduce_init.same_as(n->intrin->reduce_init) &&
+      intrin->reduce_update.same_as(n->intrin->reduce_update) &&
+      inputs.same_as(n->inputs)) {
+    return self;
+  } else {
+    n->intrin = TensorIntrin(intrin);
+    return Operation(n);
+  }
+}
+
+void TensorComputeOpNode::PropBoundToInputs(
+    const Operation& self,
+    arith::Analyzer* analyzer,
+    const std::unordered_map<const VarNode*, IntSet>& dom_map,
+    std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
+  for (size_t i = 0; i < this->inputs.size(); ++i) {
+    Tensor t = this->inputs[i];
+    Region region = input_regions[i];
+
+    auto it = out_dom_map->find(t);
+    if (it == out_dom_map->end()) continue;
+    TensorDom& dom = it->second;
+    for (size_t j = 0; j < t.ndim(); ++j) {
+      dom.data[j].emplace_back(EvalSet(region[j], dom_map));
+    }
+  }
+}
+
+size_t TensorComputeOpNode::num_schedulable_dims() const {
+  return schedulable_ndim;
+}
+
+Stmt TensorComputeOpNode::BuildProvide(
+    const Stage& stage,
+    const std::unordered_map<IterVar, Range>& dom_map,
+    bool debug_keep_trivial_loop) const {
+  CHECK_EQ(stage->op.operator->(), this);
+
+  // Start bind data.
+  Stmt nop = EvaluateNode::make(0);
+  std::vector<Stmt> input_bind_nest, output_bind_nest;
+  Array<Tensor> inputs = this->InputTensors();
+
+  // input binding
+  size_t num_inputs = inputs.size();
+  for (size_t i = 0; i < num_inputs; ++i) {
+    Tensor tensor = inputs[i];
+    Region region = this->input_regions[i];
+    Buffer buffer = this->intrin->buffers[i];
+    Array<ObjectRef> bind_spec{buffer, tensor};
+
+    Array<PrimExpr> tuple;
+    for (size_t j = 0; j < region.size(); ++j) {
+      tuple.push_back(region[j]->min);
+      tuple.push_back(region[j]->extent);
+    }
+    input_bind_nest.emplace_back(AttrStmtNode::make(
+        bind_spec, ir::attr::buffer_bind_scope,
+        CallNode::make(DataType::Handle(),
+                       ir::intrinsic::tvm_tuple,
+                       tuple, CallNode::Intrinsic), nop));
+  }
+
+  // output binding
+  for (int i = 0; i < this->num_outputs(); ++i) {
+    Tensor tensor = stage->op.output(i);
+    Buffer buffer = this->intrin->buffers[num_inputs + i];
+    Array<ObjectRef> bind_spec{buffer, tensor};
+
+    Array<PrimExpr> tuple;
+    for (size_t j = 0; j < this->axis.size(); ++j) {
+      auto ivar = this->axis[j];
+      if (j < static_cast<size_t>(this->schedulable_ndim)) {
+        tuple.push_back(ivar->var);
+        tuple.push_back(1);
+      } else {
+        Range dom = ivar->dom;
+        tuple.push_back(dom->min);
+        tuple.push_back(dom->extent);
+      }
+    }
+
+    output_bind_nest.emplace_back(AttrStmtNode::make(
+        bind_spec, ir::attr::buffer_bind_scope,
+        CallNode::make(DataType::Handle(),
+                       ir::intrinsic::tvm_tuple,
+                       tuple, CallNode::Intrinsic), nop));
+  }
+
+  // Check variable remap
+  std::unordered_map<const VarNode*, PrimExpr> vmap;
+  ir::ArgBinder binder(&vmap);
+
+  // Map the expressions passed in the call to the TensorIntrin, to the placeholder
+  // variables
+  Array<PrimExpr> user_expr = this->scalar_inputs;
+  Array<Var> scalar_params = this->intrin->scalar_params;
+  Array<PrimExpr> sp_expr;
+  for (auto sp : scalar_params) {
+    PrimExpr esp = sp;
+    sp_expr.push_back(esp);
+  }
+  CHECK_EQ(sp_expr.size(), user_expr.size());
+  // TODO(jdavies-huawei): what name should be used here?
+  binder.BindArray(sp_expr, user_expr, this->name);
+
+  size_t tloc = stage->leaf_iter_vars.size();
+  ComputeLoopNest n = ComputeLoopNest::make(this, stage, dom_map, debug_keep_trivial_loop);
+
+  if (this->reduce_axis.size() == 0) {
+    std::vector<std::vector<Stmt> > nest(
+        n.main_nest.begin(), n.main_nest.begin() + tloc + 1);
+    nest.emplace_back(MakeIfNest(n.main_predicates));
+    CHECK_EQ(n.init_predicates.size(), 0U);
+    CHECK(this->intrin->body.defined())
+        << "Normal store op for intrin " << this << " is not defined";
+    Stmt body = MergeNest(output_bind_nest, this->intrin->body);
+    body = MergeNest(input_bind_nest, body);
+    body = ir::Substitute(body, vmap);
+    body = MergeNest(binder.asserts(), body);
+    body = top::Substitute(body, n.main_vmap);
+    Stmt ret = MergeNest(nest, body);
+    return ret;
+  } else {
+    // Need to split reduction
+    CHECK(this->intrin->reduce_update.defined())
+        << "Reduction update op is not defined";
+    // Need init and update steps
+    CHECK_NE(this->reduce_axis.size(), 0U);
+    std::vector<std::vector<Stmt> > common(
+        n.main_nest.begin(), n.main_nest.begin() + n.num_common_loop + 1);
+    std::vector<std::vector<Stmt> > update_nest(
+        n.main_nest.begin() + n.num_common_loop + 1, n.main_nest.begin() + tloc + 1);
+    update_nest.emplace_back(MakeIfNest(n.main_predicates));
+
+    if (this->intrin->reduce_init.defined()) {
+      // init nest
+      std::vector<std::vector<Stmt> > init_nest(
+          n.init_nest.begin(), n.init_nest.begin() + tloc + 1);
+      init_nest.emplace_back(MakeIfNest(n.init_predicates));
+      Stmt init = MergeNest(output_bind_nest, this->intrin->reduce_init);
+      init = top::Substitute(init, n.init_vmap);
+      init = MergeNest(init_nest, init);
+      // The update
+      Stmt update = MergeNest(output_bind_nest, this->intrin->reduce_update);
+      update = MergeNest(input_bind_nest, update);
+      update = ir::Substitute(update, vmap);
+      update = MergeNest(binder.asserts(), update);
+      update = top::Substitute(update, n.main_vmap);
+      update = MergeNest(update_nest, update);
+      return MergeNest(common, SeqStmt::Flatten(init, update));
+    } else {
+      // When init op is not available, use body op for reset in the first iter.
+      CHECK(this->intrin->body.defined())
+          << "Normal body op is not defined";
+      Stmt update = TransformUpdate(stage, dom_map, n,
+                                    this->intrin->body,
+                                    this->intrin->reduce_update);
+      update = MergeNest(output_bind_nest, update);
+      update = MergeNest(input_bind_nest, update);
+      update = ir::Substitute(update, vmap);
+      update = MergeNest(binder.asserts(), update);
+      update = top::Substitute(update, n.main_vmap);
+      update = MergeNest(update_nest, update);
+      return MergeNest(common, update);
+    }
+  }
+}
+}  // namespace top
+}  // namespace tvm
diff --git a/src/top/operation/tensorize.cc b/src/top/operation/tensorize.cc
new file mode 100644 (file)
index 0000000..413bb42
--- /dev/null
@@ -0,0 +1,537 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \brief Logic related to tensorize, used by ComputeOpNode.
+ * \file tensorize.cc
+ */
+#include <tvm/ir.h>
+#include <tvm/ir_functor_ext.h>
+#include <tvm/ir_pass.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/packed_func_ext.h>
+
+#include "op_util.h"
+#include "compute_op.h"
+#include "../schedule/message_passing.h"
+
+namespace tvm {
+namespace top {
+
+using namespace ir;
+
+// Detect the regions of the inputs and output to be tensorized.
+// out_dom: the domain of root iter vars in output op
+// in_region: region of each input tensor.
+// return The location of the tensorized scope start.
+size_t InferTensorizeRegion(
+    const ComputeOpNode* self,
+    const Stage& stage,
+    const std::unordered_map<IterVar, Range>& dom_map,
+    std::unordered_map<IterVar, Range>* out_dom,
+    std::unordered_map<Tensor, Array<Range> >* in_region) {
+  // Get the bound of the tensorized scope.
+  bool found_point = false;
+  size_t loc_scope = 0;
+  std::unordered_map<IterVar, IntSet> up_state;
+  // Loop over the leaf iteration variables.
+  for (size_t i = stage->leaf_iter_vars.size(); i != 0; --i) {
+    IterVar iv = stage->leaf_iter_vars[i - 1];
+    CHECK(iv->iter_type == kDataPar ||
+          iv->iter_type == kCommReduce);
+    auto vit = dom_map.find(iv);
+    CHECK(vit != dom_map.end());
+    const Range& vrange = vit->second;
+    if (is_one(vrange->extent)) {
+      up_state[iv] = IntSet::single_point(vrange->min);
+    } else if (found_point) {
+      CHECK(is_zero(vrange->min));
+      up_state[iv] = IntSet::single_point(iv->var);
+    } else {
+      up_state[iv] = IntSet::range(vrange);
+    }
+    auto iit = stage->iter_var_attrs.find(iv);
+    if (iit != stage->iter_var_attrs.end()) {
+      const IterVarAttr& attr = (*iit).second;
+      if (!found_point) {
+        CHECK(!attr->bind_thread.defined())
+            << "Do not allow thread in tensorize scope";
+      }
+      if (attr->iter_type == kTensorized) {
+        CHECK(!found_point) << "Do not allow two tensorized points";
+        found_point = true;
+        loc_scope = i - 1;
+      }
+    }
+  }
+  CHECK(found_point);
+  // Get domain of the tensorized scope.
+  top::PassUpDomain(stage, dom_map, &up_state);
+  // Get domains of inputs.
+  std::unordered_map<Tensor, TensorDom> in_dom;
+  std::unordered_map<const VarNode*, IntSet> temp_dmap;
+  arith::Analyzer analyzer;
+  Array<Tensor> inputs = self->InputTensors();
+  for (Tensor t : inputs) {
+    in_dom.emplace(t, TensorDom(t.ndim()));
+  }
+  for (IterVar iv : self->root_iter_vars()) {
+    IntSet iset = up_state.at(iv);
+    Range iv_range = iset.cover_range(dom_map.at(iv));
+    (*out_dom)[iv] = iv_range;
+    analyzer.Bind(iv->var, iv_range);
+    temp_dmap[iv->var.get()] = iset;
+  }
+  // Input domains
+  self->PropBoundToInputs(stage->op, &analyzer, temp_dmap, &in_dom);
+  Range none;
+  for (const auto& kv : in_dom) {
+    Array<Range> vec;
+    const Tensor& t = kv.first;
+    for (size_t i = 0; i < t.ndim(); ++i) {
+      Range r = arith::Union(kv.second.data.at(i)).cover_range(none);
+      CHECK(r.defined()) << "cannot deduce region of tensorized scope for input " << t;
+      vec.push_back(std::move(r));
+    }
+    (*in_region)[t] = std::move(vec);
+  }
+  return loc_scope;
+}
+
+void VerifyTensorizeLoopNest(const ComputeOpNode* self,
+                             const Stage& stage,
+                             const ComputeLoopNest& n,
+                             size_t tloc) {
+  // Verification step.
+  std::unordered_set<const VarNode*> banned;
+  CHECK_EQ(n.main_nest.size(), stage->leaf_iter_vars.size() + 1);
+  CHECK(n.init_nest.size() == stage->leaf_iter_vars.size() + 1 ||
+        n.init_nest.size() == 0);
+  auto f_push_banned = [&banned](const Stmt& s) {
+    if (const ForNode* op = s.as<ForNode>()) {
+        banned.insert(op->loop_var.get());
+    } else if (const AttrStmtNode* op = s.as<AttrStmtNode>()) {
+      if (const IterVarNode* iv = op->node.as<IterVarNode>()) {
+        banned.insert(iv->var.get());
+      }
+    } else if (const LetStmtNode* op = s.as<LetStmtNode>()) {
+      banned.insert(op->var.get());
+    }
+  };
+  for (size_t i = tloc; i < stage->leaf_iter_vars.size(); ++i) {
+    for (const Stmt& s : n.main_nest[i + 1]) {
+      f_push_banned(s);
+    }
+    if (n.init_nest.size() != 0) {
+      for (const Stmt& s : n.init_nest[i + 1]) {
+        f_push_banned(s);
+      }
+    }
+  }
+  for (const PrimExpr& pred : n.main_predicates) {
+    if (ir::ExprUseVar(pred, banned)) {
+      LOG(FATAL) << "Tensorize failed, split condition "
+                 << pred << " relies on var defined inside tensorize scope";
+    }
+  }
+  for (const PrimExpr& pred : n.init_predicates) {
+    if (ir::ExprUseVar(pred, banned)) {
+      LOG(FATAL) << "Tensorize failed, split condition "
+                 << pred << " relies on var defined inside tensorize scope";
+    }
+  }
+}
+
+// Remap tensor placeholders and indices so the stage body can be matched
+// against the intrinsic's declared body.
+class TensorIntrinMatcher final : public StmtExprMutator {
+ public:
+  PrimExpr VisitExpr_(const CallNode* op) final {
+    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
+    op = expr.as<CallNode>();
+    if (op->call_type == CallNode::Halide) {
+      Tensor t = Downcast<Operation>(op->func).output(op->value_index);
+      auto it = in_remap_.find(t);
+      if (it != in_remap_.end()) {
+        const InputEntry& e = it->second;
+        CHECK_EQ(op->args.size(), e.region.size());
+        Array<PrimExpr> args;
+        for (size_t i = e.start; i < e.region.size(); ++i) {
+          args.push_back(op->args[i] - e.region[i]->min);
+        }
+        return CallNode::make(
+            op->dtype, e.tensor->op->name, args,
+            op->call_type, e.tensor->op, e.tensor->value_index);
+      }
+    }
+    return expr;
+  }
+
+  PrimExpr VisitExpr_(const VarNode* op) final {
+    auto it = var_remap_.find(op);
+    if (it != var_remap_.end()) {
+      return it->second;
+    } else {
+      return GetRef<PrimExpr>(op);
+    }
+  }
+
+  PrimExpr VisitExpr_(const ReduceNode* op) final {
+    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
+    op = expr.as<ReduceNode>();
+    Array<IterVar> axis;
+    for (size_t i = 0; i < op->axis.size(); ++i) {
+      auto it = axis_remap_.find(op->axis[i]);
+      if (it != axis_remap_.end()) {
+        axis.push_back(it->second);
+      }
+    }
+    return ReduceNode::make(
+        op->combiner, op->source, axis, op->condition, op->value_index);
+  }
+
+  void Init(const ComputeOpNode* self,
+            const Stage& stage,
+            const std::unordered_map<IterVar, Range>& dom_map,
+            const std::unordered_map<IterVar, Range>& out_dom,
+            const std::unordered_map<Tensor, Array<Range> >& in_region,
+            const TensorIntrin& intrin,
+            Map<Var, Range>* compute_intrin_iter_space) {
+    CHECK(self == stage->op.get());
+
+    for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) {
+      IterVar iv = stage->leaf_iter_vars[i];
+      auto vit = dom_map.find(iv);
+      if (vit != dom_map.end()) {
+        const Range vrange = vit->second;
+        compute_intrin_iter_space->Set(iv->var, vrange);
+      }
+    }
+
+    // input remap.
+    Array<Tensor> inputs = self->InputTensors();
+    CHECK_EQ(inputs.size(), intrin->inputs.size());
+    for (size_t i = 0; i < inputs.size(); ++i) {
+      InputEntry e;
+      e.tensor = intrin->inputs[i];
+      e.region = Array<Range>(in_region.at(inputs[i]));
+      CHECK_GE(e.region.size(), e.tensor.ndim());
+      // Enable fuzzy matching, to match [1, n, m] to [n, m]
+      e.start = e.region.size() - e.tensor.ndim();
+      for (size_t j = 0; j < e.start; ++j) {
+        auto canonical_extent = Simplify(e.region[j]->extent, *compute_intrin_iter_space);
+        CHECK(is_one(canonical_extent))
+            << "Tensorize " << intrin->name << ":"
+            << " Input dimension mismatch with tensor intrin "
+            << " expected shape=" << e.tensor->shape
+            << ", given region=" << e.region;
+      }
+      in_remap_[inputs[i]] = e;
+    }
+    // output remap
+    const ComputeOpNode* intrin_compute = intrin->op.as<ComputeOpNode>();
+    CHECK(intrin_compute) << "Only support compute intrinsic for now";
+    CHECK_GE(self->axis.size(), intrin_compute->axis.size())
+        << "Tensorize: Output mismatch with tensor intrin ";
+    // Enable fuzzy matching, to match [1, n, m] to [n, m]
+    size_t axis_start = self->axis.size() - intrin_compute->axis.size();
+    for (size_t i = 0; i < axis_start; ++i) {
+      Range r = out_dom.at(self->axis[i]);
+      CHECK(is_one(r->extent))
+          << "Tensorize: Output mismatch with tensor intrin "
+          << " intrin-dim=" << intrin_compute->axis.size()
+          << ", tensorize-dim=" << self->axis.size();
+      var_remap_[self->axis[i]->var.get()] = r->min;
+    }
+    // Assume we tensorize at the region of axis i, [min, min + extent).
+    // The corresponding intrinsic axis is j [0, extent)
+    // Remap index i to j + min
+    for (size_t i = axis_start; i < self->axis.size(); ++i) {
+      IterVar iv = self->axis[i];
+      IterVar target_iv = intrin_compute->axis[i - axis_start];
+      Range r = out_dom.at(iv);
+      var_remap_[iv->var.get()] = target_iv->var + r->min;
+      axis_remap_[iv] = target_iv;
+      compute_intrin_iter_space->Set(target_iv->var, target_iv->dom);
+    }
+    // Remap reduction axis
+    CHECK_GE(self->reduce_axis.size(), intrin_compute->reduce_axis.size())
+        << "Tensorize: Reduction dimension mismatch with tensor intrin";
+    axis_start = self->reduce_axis.size() - intrin_compute->reduce_axis.size();
+    for (size_t i = 0; i < axis_start; ++i) {
+      Range r = out_dom.at(self->reduce_axis[i]);
+      CHECK(is_one(r->extent))
+          << "Tensorize: Reduction mismatch with tensor intrin "
+          << " intrin-dim=" << intrin_compute->reduce_axis.size()
+          << ", tensorize-dim=" << self->reduce_axis.size();
+      var_remap_[self->reduce_axis[i]->var.get()] = r->min;
+    }
+    for (size_t i = axis_start; i < self->reduce_axis.size(); ++i) {
+      IterVar iv = self->reduce_axis[i];
+      IterVar target_iv = intrin_compute->reduce_axis[i - axis_start];
+      Range r = out_dom.at(iv);
+      var_remap_[iv->var.get()] = target_iv->var + r->min;
+      axis_remap_[iv] = target_iv;
+      compute_intrin_iter_space->Set(target_iv->var, target_iv->dom);
+    }
+  }
+
+ private:
+  // Input entry
+  struct InputEntry {
+    Tensor tensor;
+    size_t start;
+    Array<Range> region;
+  };
+  // input data remap
+  std::unordered_map<Tensor, InputEntry> in_remap_;
+  // variable remap.
+  std::unordered_map<const VarNode*, PrimExpr> var_remap_;
+  // IterVar remap.
+  std::unordered_map<IterVar, IterVar> axis_remap_;
+};
+
+// Try to match tensor dataflow of the stage with the intrinsic
+Array<PrimExpr> MatchTensorizeBody(
+    const ComputeOpNode* self,
+    const Stage& stage,
+    const std::unordered_map<IterVar, Range>& dom_map,
+    const std::unordered_map<IterVar, Range>& out_dom,
+    const std::unordered_map<Tensor, Array<Range> >& in_region,
+    const TensorIntrin& intrin,
+    Map<Var, Range>* compute_intrin_iter_space) {
+  TensorIntrinMatcher matcher;
+  matcher.Init(self, stage, dom_map, out_dom, in_region, intrin, compute_intrin_iter_space);
+  Array<PrimExpr> ret;
+  for (PrimExpr expr : self->body) {
+    ret.push_back(matcher(expr));
+  }
+  return ret;
+}
+
+void VerifyTensorizeBody(
+    const ComputeOpNode* self,
+    const Stage& stage,
+    const std::unordered_map<IterVar, Range>& dom_map,
+    const std::unordered_map<IterVar, Range>& out_dom,
+    const std::unordered_map<Tensor, Array<Range> >& in_region,
+    const TensorIntrin& intrin) {
+  Map<Var, Range> compute_intrin_iter_space;
+  Array<PrimExpr> body = MatchTensorizeBody(self, stage, dom_map, out_dom, in_region, intrin,
+                                        &compute_intrin_iter_space);
+  const ComputeOpNode* intrin_compute = intrin->op.as<ComputeOpNode>();
+  CHECK(intrin_compute) << "Only support compute intrinsic for now";
+  CHECK_EQ(body.size(), intrin_compute->body.size())
+      << "Tensorize failed: body size mismatch";
+  for (size_t i = 0; i < body.size(); ++i) {
+    PrimExpr lhs = Simplify(body[i], compute_intrin_iter_space);
+    lhs = CanonicalSimplify(lhs, compute_intrin_iter_space);
+    PrimExpr rhs = Simplify(intrin_compute->body[i], compute_intrin_iter_space);
+    rhs = CanonicalSimplify(rhs, compute_intrin_iter_space);
+    if (lhs.dtype() != rhs.dtype()) {
+      LOG(FATAL)
+          << "Failed to match the data type with TensorIntrin "
+          << intrin->name << "'s declaration "
+          << " provided=" << lhs.dtype()
+          << ", intrin=" << rhs.dtype();
+    }
+    CHECK(Equal(lhs, rhs))
+        << "Failed to match the compute with TensorIntrin "
+        << intrin->name << "'s declaration "
+        << " provided= " << lhs
+        << ", intrin=  " << rhs;
+  }
+}
+
+Stmt MakeTensorize(const ComputeOpNode* self,
+                   const Stage& stage,
+                   const std::unordered_map<IterVar, Range>& dom_map,
+                   bool debug_keep_trivial_loop) {
+  std::unordered_map<IterVar, Range> out_dom;
+  std::unordered_map<Tensor, Array<Range> > in_region;
+  size_t tloc = InferTensorizeRegion(self, stage, dom_map, &out_dom, &in_region);
+  TensorIntrin intrin = stage->iter_var_attrs.at(
+      stage->leaf_iter_vars[tloc])->tensor_intrin;
+  CHECK(intrin.defined());
+  ComputeLoopNest n = ComputeLoopNest::make(self, stage, dom_map, debug_keep_trivial_loop);
+  VerifyTensorizeLoopNest(self, stage, n, tloc);
+  VerifyTensorizeBody(self, stage, dom_map, out_dom, in_region, intrin);
+  // Start bind data.
+  Stmt nop = EvaluateNode::make(0);
+  std::vector<Stmt> input_bind_nest, output_bind_nest;
+  Array<Tensor> inputs = self->InputTensors();
+  CHECK_EQ(inputs.size(), intrin->inputs.size())
+      << "Tensorize failed: input size mismatch ";
+  // input binding
+  for (size_t i = 0; i < intrin->inputs.size(); ++i) {
+    Tensor tensor = inputs[i];
+    Buffer buffer = intrin->buffers[i];
+    Array<ObjectRef> bind_spec{buffer, tensor};
+    auto it = in_region.find(tensor);
+    CHECK(it != in_region.end());
+    const Array<Range>& region = it->second;
+    Array<PrimExpr> tuple;
+    for (const Range r : region) {
+      tuple.push_back(r->min);
+      tuple.push_back(r->extent);
+    }
+    input_bind_nest.emplace_back(AttrStmtNode::make(
+        bind_spec, ir::attr::buffer_bind_scope,
+        CallNode::make(DataType::Handle(),
+                       ir::intrinsic::tvm_tuple,
+                       tuple, CallNode::Intrinsic), nop));
+  }
+  // output binding
+  const ComputeOpNode* intrin_compute = intrin->op.as<ComputeOpNode>();
+  CHECK(intrin_compute) << "Only support compute intrinsic for now";
+  CHECK_EQ(intrin->inputs.size() + intrin_compute->body.size(), intrin->buffers.size());
+  CHECK_EQ(intrin_compute->body.size(), self->body.size());
+  Array<PrimExpr> tuple;
+  for (IterVar iv : self->axis) {
+    auto it = out_dom.find(iv);
+    CHECK(it != out_dom.end());
+    tuple.push_back(it->second->min);
+    tuple.push_back(it->second->extent);
+  }
+  for (size_t i = intrin->inputs.size(); i < intrin->buffers.size(); ++i) {
+    Tensor tensor = stage->op.output(i - intrin->inputs.size());
+    Buffer buffer = intrin->buffers[i];
+    Array<ObjectRef> bind_spec{buffer, tensor};
+    output_bind_nest.emplace_back(AttrStmtNode::make(
+        bind_spec, ir::attr::buffer_bind_scope,
+        CallNode::make(DataType::Handle(),
+                       ir::intrinsic::tvm_tuple,
+                       tuple, CallNode::Intrinsic), nop));
+  }
+  // Check variable remap
+  std::unordered_map<const VarNode*, PrimExpr> vmap;
+  ir::ArgBinder binder(&vmap);
+  CHECK_GE(self->reduce_axis.size(), intrin_compute->reduce_axis.size())
+      << "Tensorization fail: reduction axis size do not match";
+  size_t start = self->reduce_axis.size() - intrin_compute->reduce_axis.size();
+  for (size_t i = 0; i < start; ++i) {
+    IterVar iv = self->reduce_axis[i];
+    auto it = out_dom.find(iv);
+    CHECK(it != out_dom.end());
+    CHECK(is_one(it->second->extent))
+        << "Tensorization fail: reduction axis size do not match";
+  }
+  for (size_t i = start; i < self->reduce_axis.size(); ++i) {
+    IterVar iv = self->reduce_axis[i];
+    IterVar target = intrin_compute->reduce_axis[i - start];
+    auto it = out_dom.find(iv);
+    CHECK(it != out_dom.end());
+    binder.Bind(target->dom->min, make_const(iv->dom->min.dtype(), 0),
+                "tensor_intrin.reduction.min");
+    binder.Bind(target->dom->extent, it->second->extent,
+                "tensor_intrin.reduction.extent");
+  }
+  if (tloc <= n.num_common_loop) {
+    // No need to split the reduction.
+    std::vector<std::vector<Stmt> > nest(
+        n.main_nest.begin(), n.main_nest.begin() + tloc + 1);
+    nest.emplace_back(MakeIfNest(n.main_predicates));
+    CHECK_EQ(n.init_predicates.size(), 0U);
+    CHECK(intrin->body.defined())
+        << "Normal store op for intrin " << intrin << " is not defined";
+    Stmt body = MergeNest(output_bind_nest, intrin->body);
+    body = MergeNest(input_bind_nest, body);
+    body = ir::Substitute(body, vmap);
+    body = MergeNest(binder.asserts(), body);
+    body = top::Substitute(body, n.main_vmap);
+    return MergeNest(nest, body);
+  } else {
+    // Need to split reduction
+    CHECK(intrin->reduce_update.defined())
+        << "Reduction update op for intrin " << intrin << " is not defined";
+    // Need init and update steps
+    CHECK_NE(self->reduce_axis.size(), 0U);
+    std::vector<std::vector<Stmt> > common(
+        n.main_nest.begin(), n.main_nest.begin() + n.num_common_loop + 1);
+    std::vector<std::vector<Stmt> > update_nest(
+        n.main_nest.begin() + n.num_common_loop + 1, n.main_nest.begin() + tloc + 1);
+    update_nest.emplace_back(MakeIfNest(n.main_predicates));
+
+    if (intrin->reduce_init.defined()) {
+      // init nest
+      std::vector<std::vector<Stmt> > init_nest(
+          n.init_nest.begin(), n.init_nest.begin() + tloc + 1);
+      init_nest.emplace_back(MakeIfNest(n.init_predicates));
+      Stmt init = MergeNest(output_bind_nest, intrin->reduce_init);
+      init = top::Substitute(init, n.init_vmap);
+      init = MergeNest(init_nest, init);
+      // The update
+      Stmt update = MergeNest(output_bind_nest, intrin->reduce_update);
+      update = MergeNest(input_bind_nest, update);
+      update = ir::Substitute(update, vmap);
+      update = MergeNest(binder.asserts(), update);
+      update = top::Substitute(update, n.main_vmap);
+      update = MergeNest(update_nest, update);
+      return MergeNest(common, SeqStmt::Flatten(init, update));
+    } else {
+      // When init op is not available, use body op for reset in the first iter.
+      CHECK(intrin->body.defined())
+          << "Normal body op for intrin " << intrin << " is not defined";
+      Stmt update = TransformUpdate(stage, dom_map, n,
+                                    intrin->body,
+                                    intrin->reduce_update);
+      update = MergeNest(output_bind_nest, update);
+      update = MergeNest(input_bind_nest, update);
+      update = ir::Substitute(update, vmap);
+      update = MergeNest(binder.asserts(), update);
+      update = top::Substitute(update, n.main_vmap);
+      update = MergeNest(update_nest, update);
+      return MergeNest(common, update);
+    }
+  }
+}
+
+// Register functions for unittests
+TVM_REGISTER_GLOBAL("test.op.InferTensorizeRegion")
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+    Stage stage = args[0];
+    Map<IterVar, Range> dmap = args[1];
+    std::unordered_map<IterVar, Range> out_dom;
+    std::unordered_map<Tensor, Array<Range> > in_region;
+    CHECK(stage->op.as<ComputeOpNode>());
+    InferTensorizeRegion(stage->op.as<ComputeOpNode>(),
+                         stage,
+                         as_unordered_map(dmap),
+                         &out_dom, &in_region);
+    *ret = Array<ObjectRef>{Map<IterVar, Range>(out_dom),
+                          Map<Tensor, Array<Range> >(in_region)};
+  });
+
+TVM_REGISTER_GLOBAL("test.op.MatchTensorizeBody")
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+    Stage stage = args[0];
+    Map<IterVar, Range> out_dom = args[1];
+    Map<Tensor, Array<Range> > in_region = args[2];
+    TensorIntrin intrin = args[3];
+    Map<Var, Range> vrange;
+    CHECK(stage->op.as<ComputeOpNode>());
+    *ret = MatchTensorizeBody(stage->op.as<ComputeOpNode>(),
+                              stage,
+                              {{}},
+                              as_unordered_map(out_dom),
+                              as_unordered_map(in_region),
+                              intrin,
+                              &vrange);
+  });
+}  // namespace top
+}  // namespace tvm
diff --git a/src/top/schedule/auto_inline_elem_wise.cc b/src/top/schedule/auto_inline_elem_wise.cc
new file mode 100644 (file)
index 0000000..9b08813
--- /dev/null
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file auto_inline_elem_wise.cc
+ * \brief Automatically inline element-wise, broadcast, and injective operations.
+ */
+#include <tvm/top/schedule_pass.h>
+#include <tvm/top/operation.h>
+#include <tvm/ir_functor_ext.h>
+
+namespace tvm {
+namespace top {
+
+using namespace ir;
+
+class ElemWiseDetector : public ir::ExprVisitor {
+ public:
+  explicit ElemWiseDetector(Array<IterVar> axis) : axis_(axis) {}
+
+  void VisitExpr(const PrimExpr& e) final {
+    if (!is_elem_wise_) return;
+    ExprVisitor::VisitExpr(e);
+  }
+
+  void VisitExpr_(const CallNode* op) final {
+    Array<PrimExpr> axis = op->args;
+    if (axis_.size() != axis.size()) {
+      is_elem_wise_ = false;
+      return;
+    }
+
+    for (size_t i = 0; i < axis_.size(); ++i) {
+      if (!axis[i].same_as(axis_[i]->var)) {
+        is_elem_wise_ = false;
+        return;
+      }
+    }
+    ExprVisitor::VisitExpr_(op);
+  }
+
+  bool is_elem_wise_{true};
+
+ private:
+  Array<IterVar> axis_;
+};
+
+
+bool IsElemWise(const Operation& op) {
+  if (const ComputeOpNode* compute = op.as<ComputeOpNode>()) {
+    ElemWiseDetector v = ElemWiseDetector(compute->axis);
+    for (auto& e : compute->body) v(e);
+    return v.is_elem_wise_;
+  }
+  return false;
+}
+
+void AutoInlineElemWise(Schedule sch) {
+  for (Stage s : sch->stages) {
+    if (!s.is_scheduled() && IsElemWise(s->op) && !s->is_output) {
+      s.compute_inline();
+    }
+  }
+}
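+
+// Usage sketch (illustrative; assumes a schedule 'sch' built elsewhere,
+// e.g. via create_schedule): inline every unscheduled, non-output
+// element-wise stage in place.
+//   AutoInlineElemWise(sch);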
+
+bool IsBroadcast(const Operation& op) {
+  if (const ComputeOpNode* compute = op.as<ComputeOpNode>()) {
+    if (compute->reduce_axis.size()) {
+      return false;
+    }
+    // TODO(nicolasvasilache): Implement Me
+  }
+  return false;
+}
+
+void AutoInlineBroadcast(Schedule sch) {
+  for (Stage s : sch->stages) {
+    if (!s.is_scheduled() && IsBroadcast(s->op) && !s->is_output) {
+      s.compute_inline();
+    }
+  }
+}
+
+bool IsInjective(const Operation& op) {
+  if (const ComputeOpNode* compute = op.as<ComputeOpNode>()) {
+    return compute->reduce_axis.size() == 0;
+  }
+  return false;
+}
+
+void AutoInlineInjective(Schedule sch) {
+  for (Stage s : sch->stages) {
+    if (!s.is_scheduled() && IsInjective(s->op) && !s->is_output) {
+      s.compute_inline();
+    }
+  }
+}
+
+}  // namespace top
+}  // namespace tvm
diff --git a/src/top/schedule/bound.cc b/src/top/schedule/bound.cc
new file mode 100644 (file)
index 0000000..8fffc53
--- /dev/null
@@ -0,0 +1,263 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file bound.cc
+ * \brief The bound inference logic.
+ */
+#include <tvm/top/schedule_pass.h>
+#include <tvm/top/operation.h>
+#include <tvm/ir_pass.h>
+#include <unordered_map>
+#include <unordered_set>
+#include "graph.h"
+#include "message_passing.h"
+#include "../../runtime/thread_storage_scope.h"
+
+namespace tvm {
+namespace top {
+
+using runtime::StorageRank;
+using runtime::StorageScope;
+using runtime::ThreadScope;
+
+/*! \brief The graph context used during bound inference. */
+struct GraphContext {
+  /*! \brief The feed graph */
+  FeedGraph feed_graph;
+  /*! \brief Attachment path */
+  AttachPath attach_path;
+  /*! \brief The bind map */
+  std::unordered_map<IterVar, IterVar> bind_map;
+  /*! \brief map from op to stage */
+  std::unordered_map<const Object*, Stage> op2stage_;
+};
+
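+// Whether the range of an iter var must be relaxed when propagating bounds
+// to a producer: a plain loop var is relaxed only when it lies inside the
+// producer's attach point (the producer does not share those loops), while
+// a thread-bound var is relaxed according to the storage scope's rank.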
+bool NeedRelax(const IterVar& iv,
+               bool found_attach,
+               const std::unordered_map<IterVar, IterVar>& bind_map,
+               const runtime::StorageScope& scope) {
+  auto it = bind_map.find(iv);
+  const std::string& tag = (
+      it != bind_map.end() ? it->second->thread_tag : iv->thread_tag);
+  if (tag.length() == 0 || tag == "pipeline") {
+    return !found_attach;
+  }
+  ThreadScope ts = ThreadScope::make(tag);
+
+  // When there is warp memory
+  // threadIdx.x must be set to be warp index.
+  if (scope.rank == StorageRank::kWarp &&
+      ts.rank == 1 &&
+      ts.dim_index == 0) {
+    return true;
+  }
+  return static_cast<int>(scope.rank) <= ts.rank;
+}
+
+// infer storage scope, if not given
+StorageScope InferStorageScope(
+    const Stage& stage, const GraphContext& ctx) {
+  if (stage->scope.length() != 0) {
+    return StorageScope::make(stage->scope);
+  }
+  int max_rank = -1;
+  for (IterVar iv : ctx.attach_path.at(stage->op)) {
+    auto it = ctx.bind_map.find(iv);
+    const std::string& tag = (
+        it != ctx.bind_map.end() ? it->second->thread_tag : iv->thread_tag);
+    if (tag != "pipeline" && tag.length() != 0) {
+      max_rank = std::max(max_rank, ThreadScope::make(tag).rank);
+    }
+  }
+  StorageScope s;
+  s.rank = runtime::DefaultStorageRank(max_rank);
+  return s;
+}
+
+
+void InferRootBound(const Stage& stage,
+                    const GraphContext& ctx,
+                    std::unordered_map<IterVar, Range>* rmap) {
+  CHECK_NE(stage->attach_type, kInline)
+      << "call schedule.normalize before scheduleops";
+  if (stage->attach_type == kInlinedAlready) return;
+  if (stage->is_output) {
+    // verify correctness.
+    CHECK_EQ(stage.GetAttachSpec()->attach_type, kGroupRoot)
+          << "Output must be attached at root";
+  }
+  if (stage->is_output || stage->op.as<PlaceholderOpNode>()) {
+    for (auto iv :  stage->op->root_iter_vars()) {
+      CHECK(iv->dom.defined());
+      CHECK(!rmap->count(iv));
+      (*rmap)[iv] = iv->dom;
+    }
+    return;
+  }
+  // The tensor domain.
+  std::unordered_map<Tensor, TensorDom> tmap;
+  // The consumers of the op.
+  std::unordered_set<Operation> consumers;
+  for (int i = 0; i < stage->op->num_outputs(); ++i) {
+    Tensor t = stage->op.output(i);
+    tmap.emplace(t, TensorDom(static_cast<int>(t.ndim())));
+    auto it = ctx.feed_graph.find(t);
+    if (it != ctx.feed_graph.end()) {
+      for (const Operation& op : it->second) {
+        consumers.insert(op);
+      }
+    } else {
+      LOG(INFO) << "not in feed graph consumer = " << stage->op;
+    }
+  }
+  // storage scope.
+  runtime::StorageScope scope = InferStorageScope(stage, ctx);
+  // Bound propagation from the consumers.
+  // - Compute the bound by relaxation rules: NeedRelax
+  //   - For a normal index, use the relative location in the loop nest.
+  //   - For a thread index, use the thread scope.
+  //
+  Array<IterVar> stage_attach = ctx.attach_path.at(stage->op);
+  // The parent set.
+  for (const Operation& op : consumers) {
+    std::unordered_map<const VarNode*, IntSet> relax_set;
+    std::unordered_map<IterVar, IntSet> up_state;
+    bool found_attach = false;
+    CHECK(ctx.op2stage_.count(op.get()));
+    const Stage& op_stage = ctx.op2stage_.at(op.get());
+    // Consumer nest
+    for (size_t i = op_stage->leaf_iter_vars.size(); i != 0; --i) {
+      IterVar iv = op_stage->leaf_iter_vars[i - 1];
+      if (stage_attach.size() != 0 && iv == stage_attach[0]) {
+        found_attach = true;
+      }
+      auto it = rmap->find(iv);
+      CHECK(it != rmap->end());
+      const Range& vrange = it->second;
+      if (is_one(vrange->extent)) {
+        up_state[iv] = IntSet::single_point(vrange->min);
+      } else if (!NeedRelax(iv, found_attach, ctx.bind_map, scope)) {
+        CHECK(is_zero(vrange->min))
+            << "InferBound requires every leaf iter var's min equals 0, "
+            << " call schedule.normalize to achieve this. ";
+        if (ctx.bind_map.count(iv)) {
+          up_state[iv] = IntSet::single_point(ctx.bind_map.at(iv)->var);
+        } else {
+          up_state[iv] = IntSet::single_point(iv->var);
+        }
+      } else {
+        up_state[iv] = IntSet::range(vrange);
+      }
+    }
+    // Consumer's attach nest
+    for (IterVar iv : ctx.attach_path.at(op)) {
+      if (stage_attach.size() != 0 && iv == stage_attach[0]) {
+        found_attach = true;
+      }
+      Range vrange = rmap->at(iv);
+      CHECK(is_zero(vrange->min))
+          << "InferBound requires every leaf iter var's min equals 0, "
+          << "call schedule.normalize to achieve this.";
+      if (NeedRelax(iv, found_attach, ctx.bind_map, scope)) {
+        relax_set[iv->var.get()] = IntSet::range(vrange);
+        if (ctx.bind_map.count(iv)) {
+          relax_set[ctx.bind_map.at(iv)->var.get()] = IntSet::range(vrange);
+        }
+      }
+    }
+    CHECK(found_attach || stage_attach.size() == 0)
+        << "Invalid Schedule, cannot find the producer " << stage->op
+        << " along the loop nest specified by compute_at of consumer " << op;
+    // Get the domain of the consumer
+    PassUpDomain(op_stage, *rmap, &up_state);
+    // Relax if needed.
+    std::unordered_map<const VarNode*, IntSet> dom_map;
+    arith::Analyzer analyzer;
+    for (auto iv : op->root_iter_vars()) {
+      Range r;
+      if (up_state.count(iv)) {
+        r = up_state.at(iv).cover_range(iv->dom);
+      } else {
+        r = iv->dom;
+      }
+      if (relax_set.size() != 0) {
+        dom_map[iv->var.get()] = EvalSet(r, relax_set);
+      } else {
+        dom_map[iv->var.get()] = IntSet::range(r);
+      }
+      analyzer.Bind(iv->var, r);
+    }
+    op->PropBoundToInputs(op, &analyzer, dom_map, &tmap);
+  }
+  stage->op->GatherBound(stage->op, tmap, rmap);
+}
+
+Map<IterVar, Range> InferBound(const Schedule& sch) {
+  // Prepare context
+  GraphContext ctx;
+  Array<Operation> roots;
+  arith::Analyzer analyzer;
+
+  for (Operation op : sch->outputs) {
+    roots.push_back(sch->stage_map[op]->op);
+  }
+  ctx.feed_graph = CreateFeedGraph(CreateReadGraph(roots));
+
+  for (Stage stage : sch->stages) {
+    for (auto kv : stage->iter_var_attrs) {
+      if (kv.second->bind_thread.defined()) {
+        CHECK(!ctx.bind_map.count(kv.first));
+        ctx.bind_map[kv.first] = kv.second->bind_thread;
+      }
+    }
+    ctx.op2stage_[stage->op.get()] = stage;
+  }
+  ctx.attach_path = CreateAttachPath(sch);
+  // Run inference.
+  std::unordered_map<IterVar, Range> ret;
+  for (size_t i = sch->stages.size(); i != 0; --i) {
+    const Stage& stage = sch->stages[i - 1];
+    InferRootBound(stage, ctx, &ret);
+
+    // bind bound of root iter vars.
+    for (auto iv : stage->op->root_iter_vars()) {
+      auto it = ret.find(iv);
+      if (it != ret.end()) {
+        analyzer.Bind(iv->var, it->second);
+      }
+    }
+
+    // pass down to get bound of all iter vars.
+    PassDownDomain(stage, &ret, &analyzer);
+    for (IterVar iv : stage->env_threads) {
+      CHECK(iv->dom.defined());
+      ret[iv] = iv->dom;
+    }
+  }
+  for (auto& p : ret) {
+    ret[p.first] = Range::make_by_min_extent(
+        analyzer.Simplify(p.second->min),
+        analyzer.Simplify(p.second->extent));
+  }
+  return Map<IterVar, Range>(ret.begin(), ret.end());
+}
+
+}  // namespace top
+}  // namespace tvm
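
A minimal, self-contained model (plain C++, not the TVM API) of the
consumer-to-producer bound flow implemented above: each consumer reports the
index range it reads, and the producer's required range is the union over all
consumers, mirroring PropBoundToInputs followed by GatherBound.

    #include <algorithm>
    #include <cstdio>

    struct Interval { int min, max; };  // closed interval, a stand-in for IntSet

    // A consumer reading producer[i + offset] over loop domain loop.
    Interval PropBoundToInput(Interval loop, int offset) {
      return {loop.min + offset, loop.max + offset};
    }

    int main() {
      Interval loop = {0, 15};                 // consumer loop i in [0, 15]
      Interval a = PropBoundToInput(loop, 1);  // one consumer reads t[i + 1]
      Interval b = PropBoundToInput(loop, 0);  // another reads t[i]
      Interval need = {std::min(a.min, b.min), std::max(a.max, b.max)};  // union
      std::printf("producer must cover [%d, %d]\n", need.min, need.max);  // [0, 16]
      return 0;
    }
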
diff --git a/src/top/schedule/graph.cc b/src/top/schedule/graph.cc
new file mode 100644 (file)
index 0000000..62df842
--- /dev/null
@@ -0,0 +1,433 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file graph.cc
+ * \brief Utilities to get information about schedule graph.
+ */
+#include <tvm/ir.h>
+#include <tvm/ir_functor_ext.h>
+#include <tvm/top/operation.h>
+#include <utility>
+#include <unordered_set>
+#include <unordered_map>
+#include "graph.h"
+
+namespace tvm {
+namespace top {
+// key to specific tensor dimension.
+struct TensorDimKey {
+  ir::FunctionRef f;
+  int value_index;
+  int dim;
+  TensorDimKey() {}
+  TensorDimKey(const ir::CallNode* op, int dim)
+      : f(op->func), value_index(op->value_index), dim(dim) {
+  }
+  TensorDimKey(const Tensor& t, int dim)
+      : f(t->op), value_index(t->value_index), dim(dim) {
+  }
+  TensorDimKey(const Tensor& t, size_t dim)
+      : f(t->op), value_index(t->value_index), dim(static_cast<int>(dim)) {
+  }
+  inline bool operator==(const TensorDimKey& other) const {
+    return f == other.f &&
+        value_index == other.value_index &&
+        dim == other.dim;
+  }
+  inline bool operator!=(const TensorDimKey& other) const {
+    return !operator==(other);
+  }
+};
+}  // namespace top
+}  // namespace tvm
+
+namespace std {
+template <>
+struct hash<::tvm::top::TensorDimKey> {
+  std::size_t operator()(const ::tvm::top::TensorDimKey& k) const {
+    size_t lhs = ::tvm::ObjectHash()(k.f);
+    size_t rhs = static_cast<size_t>(k.value_index) << 16UL |
+        static_cast<size_t>(k.dim);
+    lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2);
+    return lhs;
+  }
+};
+}  // namespace std
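
A hedged sketch of the hash-combine pattern used in the specialization above:
the constant 0x9e3779b9 plus the shifts spread the bits of the packed
(value_index, dim) word into the object hash so that nearby keys do not
collide.

    #include <cstddef>
    #include <cstdio>
    #include <functional>

    std::size_t HashCombine(std::size_t lhs, std::size_t rhs) {
      lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2);  // same mixing step
      return lhs;
    }

    int main() {
      std::size_t f = std::hash<int>()(42);             // stands in for ObjectHash()(k.f)
      std::size_t packed = (std::size_t(1) << 16) | 3;  // value_index=1, dim=3
      std::printf("%zu\n", HashCombine(f, packed));
      return 0;
    }
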
+
+
+namespace tvm {
+namespace top {
+
+// Construct a read graph that records, for each operation the roots
+// depend on, the tensors it reads.
+ReadGraph CreateReadGraph(const Array<Operation>& roots) {
+  ReadGraph rmap;
+  std::vector<Operation> stack;
+  std::unordered_set<const Object*> visited;
+  // initialize the roots
+  for (Operation op : roots) {
+    stack.push_back(op);
+    visited.insert(op.get());
+  }
+
+  while (!stack.empty()) {
+    Operation op = stack.back();
+    stack.pop_back();
+    Array<Tensor> deps = op->InputTensors();
+    rmap.Set(op, deps);
+    for (Tensor t : deps) {
+      if (t->op.defined() && visited.count(t->op.get()) == 0) {
+        visited.insert(t->op.get());
+        stack.push_back(t->op);
+      }
+    }
+  }
+  return rmap;
+}
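
For intuition, a toy model of the same worklist traversal over std containers
(string "operations"; inputs plays the role of InputTensors()):

    #include <map>
    #include <set>
    #include <string>
    #include <vector>

    using Graph = std::map<std::string, std::vector<std::string>>;

    Graph BuildReadGraph(const std::vector<std::string>& roots, const Graph& inputs) {
      Graph rmap;
      std::vector<std::string> stack(roots.begin(), roots.end());
      std::set<std::string> visited(roots.begin(), roots.end());
      while (!stack.empty()) {
        std::string op = stack.back();
        stack.pop_back();
        std::vector<std::string> deps =
            inputs.count(op) ? inputs.at(op) : std::vector<std::string>();
        rmap[op] = deps;  // record what op reads
        for (const std::string& d : deps)
          if (visited.insert(d).second) stack.push_back(d);  // first visit only
      }
      return rmap;
    }

    int main() {
      Graph inputs = {{"C", {"A", "B"}}, {"B", {"A"}}, {"A", {}}};
      Graph rmap = BuildReadGraph({"C"}, inputs);
      return rmap.size() == 3 ? 0 : 1;  // A, B and C were all reached from C
    }
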
+
+// Do DFS visit to get the subgraph.
+// Return if op is inside the subgraph.
+bool GetSubGraphByPostDFS_(
+    const Operation& op,
+    const std::unordered_set<const Object*>& boundary,
+    bool include_boundary,
+    std::unordered_map<const Object*, bool>* visited,
+    Array<Operation>* result) {
+  if (visited->count(op.get())) {
+    return visited->at(op.get());
+  }
+  if (boundary.count(op.get())) {
+    (*visited)[op.get()] = true;
+    if (include_boundary) {
+      result->push_back(op);
+    }
+    return true;
+  }
+  // mark to avoid loop
+  // Not necessary for DAG.
+  (*visited)[op.get()] = false;
+  // check if we can reach boundary.
+  bool reach_boundary = false;
+  for (Tensor t : op->InputTensors()) {
+    if (GetSubGraphByPostDFS_(t->op, boundary,
+                              include_boundary,
+                              visited, result)) {
+      reach_boundary = true;
+    }
+  }
+  (*visited)[op.get()] = reach_boundary;
+  if (reach_boundary) {
+    result->push_back(op);
+  }
+  return reach_boundary;
+}
+
+Array<Operation> GetSubGraph(const Array<Tensor>& outputs,
+                             const Array<Tensor>& inputs,
+                             bool include_inputs) {
+  Array<Operation> result;
+  std::unordered_set<const Object*> boundary;
+  for (Tensor t : inputs) {
+    boundary.insert(t->op.get());
+  }
+  std::unordered_map<const Object*, bool> visited;
+  for (Tensor t : outputs) {
+    GetSubGraphByPostDFS_(t->op, boundary, include_inputs,
+                          &visited, &result);
+  }
+  return result;
+}
+
+
+void PostDFSOrder(const Operation& op,
+                  const ReadGraph& g,
+                  std::unordered_set<Operation>* visited,
+                  Array<Operation>* post_order) {
+  if (visited->count(op)) return;
+  visited->insert(op);
+  for (const auto& t : g.at(op)) {
+    PostDFSOrder(t->op, g, visited, post_order);
+  }
+  post_order->push_back(op);
+}
+
+Array<Operation> PostDFSOrder(
+    const Array<Operation>& roots,
+    const ReadGraph& g) {
+  std::unordered_set<Operation> visited;
+  Array<Operation> post_order;
+  for (Operation op : roots) {
+    PostDFSOrder(op, g, &visited, &post_order);
+  }
+  return post_order;
+}
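
The recursion above is the classic post-order emission; here is a standalone
model: an op is pushed only after all of its inputs, so producers always
precede consumers in the returned order.

    #include <cstdio>
    #include <map>
    #include <set>
    #include <string>
    #include <vector>

    using Graph = std::map<std::string, std::vector<std::string>>;

    void PostDFS(const std::string& op, const Graph& g,
                 std::set<std::string>* visited, std::vector<std::string>* order) {
      if (!visited->insert(op).second) return;  // already visited
      for (const std::string& dep : g.at(op)) PostDFS(dep, g, visited, order);
      order->push_back(op);  // emitted after all inputs: a topological order
    }

    int main() {
      Graph g = {{"C", {"A", "B"}}, {"B", {"A"}}, {"A", {}}};
      std::set<std::string> visited;
      std::vector<std::string> order;
      PostDFS("C", g, &visited, &order);
      for (const std::string& op : order) std::printf("%s ", op.c_str());  // A B C
      return 0;
    }
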
+
+FeedGraph CreateFeedGraph(const ReadGraph& g) {
+  FeedGraph fg;
+  for (auto kv : g) {
+    for (Tensor t : kv.second) {
+      fg[t].push_back(kv.first);
+    }
+  }
+  return fg;
+}
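
The feed graph is simply the read graph inverted; with the toy Graph type from
the sketches above:

    #include <map>
    #include <string>
    #include <vector>

    using Graph = std::map<std::string, std::vector<std::string>>;

    Graph InvertToFeedGraph(const Graph& read) {
      Graph feed;
      for (const auto& kv : read)
        for (const std::string& t : kv.second)
          feed[t].push_back(kv.first);  // tensor -> the ops that read it
      return feed;
    }

    int main() {
      Graph read = {{"C", {"A"}}, {"B", {"A"}}};
      Graph feed = InvertToFeedGraph(read);
      return feed["A"].size() == 2 ? 0 : 1;  // A feeds both B and C
    }
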
+
+AttachPath CreateAttachPath(Schedule sch) {
+  AttachPath ret;
+  for (Stage stage : sch->stages) {
+    std::unordered_set<const Object*> visited;
+    Array<IterVar> path;
+    for (Stage s = stage; s.defined();) {
+      CHECK(!visited.count(s.get()))
+          << "Found a loop in the compute_at attach group";
+      visited.insert(s.get());
+      Stage spec = s.GetAttachSpec();
+      bool start_attach;
+      IterVar attach_ivar;
+      if (spec->attach_type == kScope) {
+        attach_ivar = spec->attach_ivar;
+        s = spec->attach_stage;
+        start_attach = false;
+        CHECK(attach_ivar.defined());
+      } else if (spec->attach_type == kScanUpdate) {
+        s = spec->attach_stage;
+        start_attach = true;
+      } else {
+        break;
+      }
+      CHECK(s.defined());
+      for (size_t i = s->leaf_iter_vars.size(); i != 0; --i) {
+        IterVar iv = s->leaf_iter_vars[i - 1];
+        if (!start_attach && iv.same_as(attach_ivar)) {
+          start_attach = true;
+        }
+        if (start_attach) path.push_back(iv);
+      }
+      CHECK(start_attach)
+          << "Invalid Schedule: cannot find attach point " << attach_ivar
+          << " in the schedule of " << s->op;
+    }
+    if (!ret.count(stage->op)) {
+      ret.Set(stage->op, path);
+    }
+  }
+  return ret;
+}
+
+// Graph of the push/reach relation between tensor dimensions.
+using ReachGraph = std::unordered_map<TensorDimKey, std::vector<TensorDimKey> >;
+
+ReachGraph GetReachGraph(const Array<Operation>& ops) {
+  ReachGraph reach;
+  std::unordered_set<const Object*> bset;
+  for (size_t i = 0; i < ops.size(); ++i) {
+    bset.insert(ops[i].get());
+  }
+
+  for (Operation op : ops) {
+    if (const auto* scan_op = op.as<ScanOpNode>()) {
+      const auto& update = scan_op->update;
+      const auto& init = scan_op->init;
+      for (size_t i = 0; i < update.size(); ++i) {
+        Tensor t = op.output(i);
+        for (int k = 1; k < static_cast<int>(update[i]->shape.size()); ++k) {
+          reach[TensorDimKey(t, k)].emplace_back(
+              TensorDimKey(update[i], k));
+          reach[TensorDimKey(t, k)].emplace_back(
+              TensorDimKey(init[i], k));
+        }
+      }
+    } else if (const auto* compute_op = op.as<ComputeOpNode>()) {
+      std::unordered_map<const Object*, TensorDimKey> vmap;
+      const auto& axis = compute_op->axis;
+      Tensor t = op.output(0);
+      for (size_t i = 0; i < axis.size(); ++i) {
+        vmap[axis[i]->var.get()] = TensorDimKey(t, i);
+        reach[TensorDimKey(t, i)] = {};
+      }
+      auto fvisit = [&vmap, &reach, &bset](const ObjectRef& n) {
+        const ir::CallNode *call = n.as<ir::CallNode>();
+        if (call != nullptr && call->func.defined()) {
+          if (!bset.count(call->func.get())) return;
+          for (size_t i = 0; i < call->args.size(); ++i) {
+            TensorDimKey dkey(call, static_cast<int>(i));
+            auto fpush = [&dkey, &vmap, &reach](const ObjectRef& node) {
+              const VarNode *v = node.as<VarNode>();
+              auto it = vmap.find(v);
+              if (it != vmap.end()) {
+                reach[it->second].push_back(dkey);
+              }
+            };
+            ir::PostOrderVisit(call->args[i], fpush);
+          }
+        }
+      };
+      for (auto& e : compute_op->body) {
+        ir::PostOrderVisit(e, fvisit);
+      }
+    }
+  }
+  return reach;
+}
+
+Array<Operation> ScanGetBody(const Operation& scan_op) {
+  const ScanOpNode* scan = scan_op.as<ScanOpNode>();
+  // Get the body.
+  Array<Tensor> inputs;
+  for (Tensor t : scan->state_placeholder) {
+    inputs.push_back(t);
+  }
+  for (Tensor t : scan->inputs) {
+    inputs.push_back(t);
+  }
+  return GetSubGraph(scan->update, inputs, false);
+}
+
+Map<IterVar, PrimExpr> ScanFixPointAnalysis(const Operation& scan_op) {
+  const ScanOpNode* scan = scan_op.as<ScanOpNode>();
+  Array<Operation> body = ScanGetBody(scan_op);
+
+  std::unordered_map<TensorDimKey, const Object*> exact_reach;
+  std::unordered_set<const Object*> fail_set;
+
+  for (size_t i = 0, sp_idx = 0; i < scan->update.size(); ++i) {
+    for (size_t k = 1; k < scan->update[i]->shape.size(); ++k, ++sp_idx) {
+      TensorDimKey key(scan->state_placeholder[i], k);
+      exact_reach[key] = scan->spatial_axis_[sp_idx].get();
+    }
+  }
+  // merge exact reach
+  auto f_merge_key = [&exact_reach, &fail_set](
+      const TensorDimKey& dst, const TensorDimKey& src) {
+    auto sit = exact_reach.find(src);
+    if (sit == exact_reach.end()) return;
+    auto dit = exact_reach.find(dst);
+    if (dit == exact_reach.end()) {
+      exact_reach[dst] = sit->second;
+    } else {
+      if (dit->second != sit->second) {
+        fail_set.insert(dit->second);
+        fail_set.insert(sit->second);
+      }
+    }
+  };
+  // prop exact reach back.
+  for (size_t i = 0; i < body.size(); ++i) {
+    const Operation& op = body[i];
+    if (const auto* scan_op = op.as<ScanOpNode>()) {
+      const auto& update = scan_op->update;
+      const auto& init = scan_op->init;
+      for (size_t i = 0; i < update.size(); ++i) {
+        Tensor t = op.output(i);
+        for (size_t k = 1; k < update[i]->shape.size(); ++k) {
+          f_merge_key(TensorDimKey(t, k), TensorDimKey(update[i], k));
+          f_merge_key(TensorDimKey(t, k), TensorDimKey(init[i], k));
+        }
+      }
+    } else if (const auto* compute_op = op.as<ComputeOpNode>()) {
+      std::unordered_map<const Object*, std::vector<TensorDimKey> > vmap;
+      const auto& axis = compute_op->axis;
+      for (size_t i = 0; i < axis.size(); ++i) {
+        std::vector<TensorDimKey> keys;
+        for (int j = 0; j < op->num_outputs(); ++j) {
+          keys.emplace_back(op.output(j), i);
+        }
+        vmap[axis[i]->var.get()] = std::move(keys);
+      }
+      auto fvisit = [&vmap, &f_merge_key, &exact_reach, &fail_set](
+          const ObjectRef& n) {
+        const ir::CallNode *call = n.as<ir::CallNode>();
+        if (call != nullptr && call->func.defined()) {
+          for (size_t i = 0; i < call->args.size(); ++i) {
+            auto it = vmap.find(call->args[i].get());
+            TensorDimKey src(call, static_cast<int>(i));
+            if (it != vmap.end()) {
+              const std::vector<TensorDimKey>& keys = it->second;
+              for (const auto& key : keys) {
+                f_merge_key(key, src);
+              }
+            } else {
+              if (exact_reach.count(src)) {
+                fail_set.insert(exact_reach.at(src));
+              }
+            }
+          }
+        }
+      };
+      for (auto& e : compute_op->body) {
+        ir::PostOrderVisit(e, fvisit);
+      }
+    }
+  }
+  ReachGraph reach;
+  Map<IterVar, PrimExpr> ret;
+  std::unordered_set<TensorDimKey> place_holder_ref;
+  for (size_t i = 0; i < scan->state_placeholder.size(); ++i) {
+    for (size_t k = 0; k < scan->state_placeholder[i]->shape.size(); ++k) {
+      place_holder_ref.insert(TensorDimKey(scan->state_placeholder[i], k));
+    }
+  }
+
+  for (size_t i = 0, sp_idx = 0; i < scan->update.size(); ++i) {
+    for (size_t k = 1; k < scan->update[i]->shape.size(); ++k, ++sp_idx) {
+      TensorDimKey key(scan->update[i], k);
+      TensorDimKey target(scan->state_placeholder[i], k);
+      IterVar sp_iv = scan->spatial_axis_[sp_idx];
+      if (fail_set.count(sp_iv.get()) ||
+          !exact_reach.count(key) ||
+          exact_reach.at(key) != sp_iv.get()) {
+        ret.Set(sp_iv, make_const(DataType::Int(32), 0));
+      } else {
+        // Exact match proven; still need to show no interference from the rest of the graph.
+        if (reach.size() == 0) reach = GetReachGraph(body);
+        // do a DFS
+        std::unordered_set<TensorDimKey> visited;
+        std::vector<TensorDimKey> stack{key};
+        visited.insert(key);
+        while (!stack.empty()) {
+          TensorDimKey k = stack.back();
+          if (k != target && place_holder_ref.count(k)) break;
+          stack.pop_back();
+          if (!reach.count(k)) {
+            LOG(FATAL) << "cannot find reach of " << k.f << "-" << k.dim;
+          }
+
+          for (TensorDimKey kk : reach.at(k)) {
+            if (visited.count(kk)) {
+              continue;
+            }
+            visited.insert(kk);
+            stack.push_back(kk);
+          }
+        }
+        if (!stack.empty()) {
+          // the proof failed.
+          ret.Set(sp_iv, make_const(DataType::Int(32), 0));
+        } else {
+          ret.Set(sp_iv, make_const(DataType::Int(32), 1));
+        }
+      }
+    }
+  }
+  return ret;
+}
+
+}  // namespace top
+}  // namespace tvm
diff --git a/src/top/schedule/graph.h b/src/top/schedule/graph.h
new file mode 100644 (file)
index 0000000..f379f98
--- /dev/null
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file graph.h
+ * \brief Utilities to get information about schedule graph.
+ */
+#ifndef TVM_TOP_SCHEDULE_GRAPH_H_
+#define TVM_TOP_SCHEDULE_GRAPH_H_
+
+#include <tvm/expr.h>
+#include <tvm/top/schedule.h>
+#include <tvm/top/operation.h>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+namespace tvm {
+namespace top {
+
+/*!
+ * \brief Data structure mapping each Operation to the Tensors it reads.
+ */
+using ReadGraph = Map<Operation, Array<Tensor> >;
+
+/*!
+ * \brief AttachPath maps each op to the list of IterVars forming its attach path.
+ */
+using AttachPath = Map<Operation, Array<IterVar> >;
+
+/*!
+ * \brief The map from each tensor to the operations it feeds.
+ */
+using FeedGraph = std::unordered_map<Tensor, std::vector<Operation> >;
+
+/*!
+ * \brief Get read graph of each operation to all the
+ *  Tensors that it directly depends on.
+ *
+ *  The result map contains all Operations needed to compute the roots.
+ * \param roots The root operations.
+ * \return The result map.
+ */
+ReadGraph CreateReadGraph(const Array<Operation>& roots);
+
+/*!
+ * \brief Get the minimum subgraph between outputs and inputs.
+ *  The result contains the operations that are reachable from the inputs
+ *  and can reach the outputs.
+ *
+ *  The inputs are excluded unless include_inputs is set; the outputs are
+ *  always included.
+ *
+ * \param outputs The outputs of the subgraph
+ * \param inputs The inputs to the subgraph.
+ * \param include_inputs Whether to include inputs
+ *
+ * \return The subgraph.
+ */
+Array<Operation> GetSubGraph(const Array<Tensor>& outputs,
+                             const Array<Tensor>& inputs,
+                             bool include_inputs);
+
+/*!
+ * \brief Get a post-DFS ordering of operations in the graph.
+ * \param roots The root of the graph.
+ * \param g The read graph.
+ * \return vector order of Operations in PostDFS order.
+ *
+ * \note PostDFSOrder is a special case of topological order,
+ *   and can be used when a topological order is needed.
+ */
+Array<Operation> PostDFSOrder(
+    const Array<Operation>& roots, const ReadGraph& g);
+
+/*!
+ * \brief Create the feed graph for a given read graph.
+ * \param g The read graph.
+ * \return The created feedgraph.
+ */
+FeedGraph CreateFeedGraph(const ReadGraph& g);
+
+/*!
+ * \brief Create an AttachPath that maps each op to the list of IterVars
+ *  representing the loop nest the op sits in, from innermost to outermost.
+ *  Also inserts the attach stage for scan updates when needed.
+ *
+ * \param sch The schedule.
+ * \return The attach path.
+ */
+AttachPath CreateAttachPath(Schedule sch);
+
+/*!
+ * \brief Get all operations inside the recursion of scan.
+ * \param scan_op The scan operation.
+ * \return The body operations, in read dependency order.
+ */
+Array<Operation> ScanGetBody(const Operation& scan_op);
+
+/*!
+ * \brief Analyze each spatial dimension of scan's result.
+ *  Checks whether each dimension is a fixed point.
+ *  An axis is a fixed point if it only refers back to itself in the recursion
+ *  and is not used in the axes of other recursion fields:
+ *
+ *  next_state[t, ..., axis, ...] = f(prev_state[t-1, ..., axis, ...])
+ *
+ * \param scan The scan node.
+ * \return Map of spatial_axis -> IntImm
+ */
+Map<IterVar, PrimExpr> ScanFixPointAnalysis(const Operation& scan);
+
+}  // namespace top
+}  // namespace tvm
+
+#endif  // TVM_TOP_SCHEDULE_GRAPH_H_
diff --git a/src/top/schedule/message_passing.cc b/src/top/schedule/message_passing.cc
new file mode 100644 (file)
index 0000000..4732681
--- /dev/null
@@ -0,0 +1,543 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file message_passing.cc
+ * \brief Message passing of domains and indices across iter var relations.
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/ir.h>
+#include <tvm/ir_pass.h>
+#include "message_passing.h"
+#include "../../arith/compute_expr.h"
+
+namespace tvm {
+namespace top {
+
+using namespace ir;
+
+void Update(std::unordered_map<IterVar, Range>* p_state,
+            const IterVar& iv,
+            Range r,
+            arith::Analyzer* analyzer) {
+  auto it = p_state->find(iv);
+  if (it == p_state->end()) {
+    (*p_state)[iv] = r;
+    analyzer->Bind(iv->var, r);
+  } else {
+    bool match = is_zero(it->second->min) &&
+        analyzer->CanProve(r->extent - it->second->extent == 0);
+    CHECK(match)
+        << iv
+        << " domain already inferred,"
+        << " cannot prove their extents are the same "
+        << it->second->extent << " vs " << r->extent;
+  }
+}
+
+void PassDownDomain(const Stage& stage,
+                    std::unordered_map<IterVar, Range>* p_state,
+                    arith::Analyzer* actx,
+                    bool allow_missing) {
+  auto ceil_div = [actx](PrimExpr a, PrimExpr b) {
+    if (actx->CanProve(indexmod(a, b) == 0)) {
+      return actx->Simplify(indexdiv(a, b));
+    }
+    return actx->Simplify(indexdiv(a + (b - 1), b));
+  };
+
+  auto& state = *p_state;
+  // forward iteration on relations
+  for (IterVarRelation rel : stage->relations) {
+    if (const SplitNode* r = rel.as<SplitNode>()) {
+      if (!state.count(r->parent)) {
+        CHECK(allow_missing);
+        continue;
+      }
+      CHECK(!state.count(r->inner));
+      const Range& range_parent = state.at(r->parent);
+      if (r->factor.defined()) {
+        Update(p_state, r->inner,
+               Range::make_by_min_extent(0, r->factor), actx);
+        Update(p_state, r->outer,
+               Range::make_by_min_extent(
+                   0, ceil_div(range_parent->extent, r->factor)), actx);
+      } else {
+        Update(p_state, r->outer, Range::make_by_min_extent(0, r->nparts), actx);
+        Update(p_state, r->inner,
+               Range::make_by_min_extent(
+                   0, ceil_div(range_parent->extent, r->nparts)), actx);
+      }
+    } else if (const FuseNode* r = rel.as<FuseNode>()) {
+      if (!state.count(r->outer) || !state.count(r->inner)) {
+        CHECK(allow_missing);
+        continue;
+      }
+      const Range& range_outer = state.at(r->outer);
+      const Range& range_inner = state.at(r->inner);
+      state[r->fused] = Range::make_by_min_extent(
+          0, range_outer->extent * range_inner->extent);
+    } else if (const RebaseNode* r = rel.as<RebaseNode>()) {
+      if (!state.count(r->parent)) {
+        CHECK(allow_missing);
+        continue;
+      }
+      Update(p_state, r->rebased,
+             Range::make_by_min_extent(
+                 0, state.at(r->parent)->extent), actx);
+    } else if (const SingletonNode* s = rel.as<SingletonNode>()) {
+      Update(p_state, s->iter, Range::make_by_min_extent(0, 1), actx);
+    } else {
+      LOG(FATAL) << "unknown relation type";
+    }
+  }
+  // update the extents of bound threads.
+  for (auto kv : stage->iter_var_attrs) {
+    if (kv.second->bind_thread.defined()) {
+      CHECK(state.count(kv.first));
+      Update(p_state, kv.second->bind_thread, state.at(kv.first), actx);
+    }
+  }
+}
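
A standalone model of the split rule above: a parent extent N split by factor
F yields inner extent F and outer extent ceil(N / F), so the tiles always
cover the parent domain (the last tile may be partial).

    #include <cstdio>

    int CeilDiv(int a, int b) { return (a + b - 1) / b; }  // models ceil_div

    int main() {
      int parent = 100, factor = 8;
      int inner = factor;
      int outer = CeilDiv(parent, factor);  // 13, since 12 * 8 = 96 < 100
      std::printf("outer=%d inner=%d, %d points cover extent %d\n",
                  outer, inner, outer * inner, parent);
      return 0;
    }
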
+
+void PassUpIndex(const Stage& stage,
+                 const Map<IterVar, Range>& dom_map,
+                 std::unordered_map<IterVar, PrimExpr>* p_state,
+                 bool allow_missing) {
+  auto& state = *p_state;
+  for (size_t i = stage->relations.size(); i != 0; --i) {
+    IterVarRelation rel = stage->relations[i - 1];
+    if (const SplitNode* s = rel.as<SplitNode>()) {
+      if (!state.count(s->outer) || !state.count(s->inner)) {
+        CHECK(allow_missing);
+        continue;
+      }
+      PrimExpr outer = state.at(s->outer);
+      PrimExpr inner = state.at(s->inner);
+      PrimExpr factor = dom_map.at(s->inner)->extent;
+      PrimExpr parent_min = dom_map.at(s->parent)->min;
+      state[s->parent] = inner + outer * factor;
+      // add min if they exist
+      if (!is_zero(parent_min)) {
+        state[s->parent] = state[s->parent] + parent_min;
+      }
+    } else if (const FuseNode* s = rel.as<FuseNode>()) {
+      if (!state.count(s->fused)) {
+        CHECK(allow_missing);
+        continue;
+      }
+      PrimExpr value = state.at(s->fused);
+      PrimExpr factor = dom_map.at(s->inner)->extent;
+      PrimExpr outer_min = dom_map.at(s->outer)->min;
+      PrimExpr inner_min = dom_map.at(s->inner)->min;
+      state[s->outer] = indexdiv(value, factor);
+      state[s->inner] = indexmod(value, factor);
+      // add min if they exist
+      if (!is_zero(outer_min)) {
+        state[s->outer] = state[s->outer] + outer_min;
+      }
+      if (!is_zero(inner_min)) {
+        state[s->inner] = state[s->inner] + inner_min;
+      }
+    } else if (const RebaseNode* s = rel.as<RebaseNode>()) {
+      if (!state.count(s->rebased)) {
+        CHECK(allow_missing);
+        continue;
+      }
+      PrimExpr value = state.at(s->rebased);
+      PrimExpr parent_min = dom_map.at(s->parent)->min;
+      // add min if they exist
+      if (!is_zero(parent_min)) {
+        state[s->parent] = value + parent_min;
+      } else {
+        state[s->parent] = value;
+      }
+    } else if (rel.as<SingletonNode>()) {
+    } else {
+      LOG(FATAL) << "unknown relation type";
+    }
+  }
+}
+
+void PassDownIndex(const Stage& stage,
+                   const Map<IterVar, Range>& dom_map,
+                   std::unordered_map<IterVar, PrimExpr>* p_state,
+                   bool allow_missing) {
+  auto& state = *p_state;
+  for (IterVarRelation rel : stage->relations) {
+    if (const SplitNode* s = rel.as<SplitNode>()) {
+      if (!state.count(s->parent)) {
+        CHECK(allow_missing);
+        continue;
+      }
+      Range r = dom_map.at(s->inner);
+      CHECK(is_zero(r->min));
+      PrimExpr parent = state.at(s->parent);
+      PrimExpr factor = r->extent;
+      state[s->outer] = indexdiv(parent, factor);
+      state[s->inner] = indexmod(parent, factor);
+    } else if (const FuseNode* s = rel.as<FuseNode>()) {
+      if (!state.count(s->inner) && !state.count(s->outer)) {
+        CHECK(allow_missing);
+        continue;
+      }
+      PrimExpr factor = dom_map.at(s->inner)->extent;
+      PrimExpr outer_min = dom_map.at(s->outer)->min;
+      PrimExpr inner_min = dom_map.at(s->inner)->min;
+      PrimExpr inner = state.at(s->inner);
+      PrimExpr outer = state.at(s->outer);
+      CHECK(is_zero(outer_min));
+      CHECK(is_zero(inner_min));
+      state[s->fused] = outer * factor + inner;
+    } else if (const RebaseNode* s = rel.as<RebaseNode>()) {
+      if (!state.count(s->rebased)) {
+        CHECK(allow_missing);
+        continue;
+      }
+      PrimExpr value = state.at(s->parent);
+      PrimExpr parent_min = dom_map.at(s->parent)->min;
+      CHECK(is_zero(parent_min));
+      state[s->rebased] = value;
+    } else if (const SingletonNode* s = rel.as<SingletonNode>()) {
+      state[s->iter] = make_zero(s->iter->var.dtype());
+    } else {
+      LOG(FATAL) << "unknown relation type";
+    }
+  }
+}
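
The index algebra of PassUpIndex and PassDownIndex, modeled on plain ints with
zero mins: downward, a split maps p to (p / F, p % F); upward, the parent is
reconstructed as outer * F + inner, so the two passes are inverses.

    #include <cassert>

    int main() {
      const int factor = 8;
      for (int parent = 0; parent < 100; ++parent) {
        int outer = parent / factor;  // indexdiv
        int inner = parent % factor;  // indexmod
        assert(outer * factor + inner == parent);  // round-trips exactly
      }
      return 0;
    }
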
+
+// Domain message passing.
+void PassUpDomain(const SplitNode* s,
+                  const std::unordered_map<IterVar, Range>& dom_map,
+                  const IntSet& outer,
+                  const IntSet& inner,
+                  IntSet* parent) {
+  if (dom_map.count(s->outer) &&
+      dom_map.count(s->inner) &&
+      dom_map.count(s->parent) &&
+      outer.match_range(dom_map.at(s->outer)) &&
+      inner.match_range(dom_map.at(s->inner))) {
+    *parent = IntSet::range(dom_map.at(s->parent));
+    return;
+  }
+  PrimExpr factor = dom_map.at(s->inner)->extent;
+  PrimExpr parent_min = dom_map.at(s->parent)->min;
+  CHECK(outer.defined());
+  CHECK(inner.defined());
+  CHECK(factor.defined());
+  *parent = arith::EvalSet(
+      s->outer->var * factor + s->inner->var + parent_min,
+      {{s->outer, outer}, {s->inner, inner}});
+}
+
+void PassUpDomain(const FuseNode* s,
+                  const std::unordered_map<IterVar, Range>& dom_map,
+                  const IntSet& fused,
+                  IntSet* outer,
+                  IntSet* inner) {
+  CHECK(dom_map.count(s->outer));
+  CHECK(dom_map.count(s->inner));
+  CHECK(dom_map.count(s->fused));
+
+  if (fused.match_range(dom_map.at(s->fused))) {
+    *outer = IntSet::range(dom_map.at(s->outer));
+    *inner = IntSet::range(dom_map.at(s->inner));
+    return;
+  }
+  PrimExpr outer_min = dom_map.at(s->outer)->min;
+  PrimExpr inner_min = dom_map.at(s->inner)->min;
+
+  if (fused.is_single_point()) {
+    PrimExpr value = fused.point_value();
+    PrimExpr factor = dom_map.at(s->inner)->extent;
+    PrimExpr v_outer  = indexdiv(value, factor);
+    PrimExpr v_inner  = indexmod(value, factor);
+    if (!is_zero(outer_min)) v_outer = v_outer + outer_min;
+    if (!is_zero(inner_min)) v_inner = v_inner + inner_min;
+    *outer = IntSet::single_point(v_outer);
+    *inner = IntSet::single_point(v_inner);
+  } else {
+    PrimExpr fused_extent = (fused.max() - fused.min() + 1);
+    PrimExpr inner_extent = dom_map.at(s->inner)->extent;
+    *outer = IntSet::interval(
+        outer_min + indexdiv(fused.min(), inner_extent),
+        outer_min + indexdiv(fused.max(), inner_extent));
+    if (is_zero(Simplify(indexmod(inner_extent, fused_extent))) &&
+        is_zero(Simplify(indexmod(fused.min(), fused_extent)))) {
+      // fused never spans multiple rows, make a tight bounding box
+      // there may be other cases when bounding box could be tightened
+      *inner = IntSet::interval(inner_min + indexmod(fused.min(), inner_extent),
+                                inner_min + indexmod(fused.max(), inner_extent));
+    } else {  // fused may span multiple rows, use full row widths
+      if (!is_zero(Simplify(indexmod(fused_extent, inner_extent))) ||
+          !is_zero(Simplify(indexmod(fused.min(), inner_extent)))) {
+        LOG(WARNING) <<
+          "fused and original axes are not aligned, this may cause redundant computations";
+      }
+      *inner = IntSet::range(dom_map.at(s->inner));
+    }
+    return;
+  }
+}
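
A numeric model of the fuse case just above: bound the outer axis by dividing
the endpoints of the fused interval; when the interval crosses a row boundary,
the inner axis must be relaxed to the whole row, as in the fallback branch.

    #include <cstdio>

    int main() {
      int inner_extent = 8;
      int fmin = 10, fmax = 25;             // the fused axis covers [10, 25]
      int outer_min = fmin / inner_extent;  // 1
      int outer_max = fmax / inner_extent;  // 3
      bool crosses_rows = outer_min != outer_max;
      std::printf("outer in [%d, %d]; inner %s\n", outer_min, outer_max,
                  crosses_rows ? "relaxed to [0, 8)" : "stays tight");
      return 0;
    }
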
+
+void PassUpDomain(const RebaseNode* s,
+                  const std::unordered_map<IterVar, Range>& dom_map,
+                  const IntSet& rebased,
+                  IntSet* parent) {
+  CHECK(dom_map.count(s->parent));
+  if (rebased.match_range(dom_map.at(s->rebased))) {
+    *parent = IntSet::range(dom_map.at(s->parent));
+    return;
+  }
+  PrimExpr parent_min = dom_map.at(s->parent)->min;
+  *parent = arith::EvalSet(s->rebased->var + parent_min,
+                           {{s->rebased, rebased}});
+}
+
+void PassUpDomain(const Stage& stage,
+                  const std::unordered_map<IterVar, Range>& dom_map,
+                  std::unordered_map<IterVar, IntSet>* p_state) {
+  auto& state = *p_state;
+  for (size_t i = stage->relations.size(); i != 0; --i) {
+    IterVarRelation rel = stage->relations[i - 1];
+    if (const SplitNode* r = rel.as<SplitNode>()) {
+      IntSet parent;
+      PassUpDomain(r, dom_map,
+                   state.at(r->outer), state.at(r->inner),
+                   &parent);
+      state[r->parent] = parent;
+    } else if (const FuseNode* r = rel.as<FuseNode>()) {
+      IntSet outer, inner;
+      PassUpDomain(r, dom_map,
+                   state.at(r->fused),
+                   &outer, &inner);
+      state[r->outer] = outer;
+      state[r->inner] = inner;
+    } else if (const RebaseNode* r = rel.as<RebaseNode>()) {
+      IntSet parent;
+      PassUpDomain(r, dom_map,
+                   state.at(r->rebased),
+                   &parent);
+      state[r->parent] = parent;
+    } else if (rel.as<SingletonNode>()) {
+    } else {
+      LOG(FATAL) << "unknown relation type";
+    }
+  }
+}
+
+// Pass up bit mask with or relation.
+void PassUpBitMaskOr(const Stage& stage,
+                     std::unordered_map<IterVar, int>* p_state,
+                     bool allow_missing) {
+  auto& state = *p_state;
+  for (size_t i = stage->relations.size(); i != 0; --i) {
+    IterVarRelation rel = stage->relations[i - 1];
+    if (const SplitNode* s = rel.as<SplitNode>()) {
+      if (!state.count(s->inner) && !state.count(s->outer)) {
+        CHECK(allow_missing);
+        continue;
+      }
+      int res = 0;
+      if (state.count(s->parent)) res |= state[s->parent];
+      if (state.count(s->inner)) res |= state[s->inner];
+      if (state.count(s->outer)) res |= state[s->outer];
+      state[s->parent] = res;
+    } else if (const FuseNode* s = rel.as<FuseNode>()) {
+      if (!state.count(s->fused)) {
+        CHECK(allow_missing);
+        continue;
+      }
+      if (!state.count(s->outer)) {
+        state[s->outer] = state[s->fused];
+      } else {
+        state[s->outer] |= state[s->fused];
+      }
+      if (!state.count(s->inner)) {
+        state[s->inner] = state[s->fused];
+      } else {
+        state[s->inner] |= state[s->fused];
+      }
+    } else if (const RebaseNode* s = rel.as<RebaseNode>()) {
+      if (!state.count(s->rebased)) {
+        CHECK(allow_missing);
+        continue;
+      }
+      if (!state.count(s->parent)) {
+        state[s->parent] = state[s->rebased];
+      } else {
+        state[s->parent] |= state[s->rebased];
+      }
+    } else if (rel.as<SingletonNode>()) {
+    } else {
+      LOG(FATAL) << "unknown relation type";
+    }
+  }
+}
+
+void PassDownBitMaskOr(const Stage& stage,
+                       std::unordered_map<IterVar, int>* p_state,
+                       bool allow_missing) {
+  auto& state = *p_state;
+  for (IterVarRelation rel : stage->relations) {
+    if (const SplitNode* s = rel.as<SplitNode>()) {
+      if (!state.count(s->parent)) {
+        CHECK(allow_missing);
+        continue;
+      }
+      if (!state.count(s->outer)) {
+        state[s->outer] = state.at(s->parent);
+      } else {
+        state[s->outer] |= state.at(s->parent);
+      }
+      if (!state.count(s->inner)) {
+        state[s->inner] = state.at(s->parent);
+      } else {
+        state[s->inner] |= state.at(s->parent);
+      }
+    } else if (const FuseNode* s = rel.as<FuseNode>()) {
+      if (!state.count(s->outer) && !state.count(s->inner)) {
+        CHECK(allow_missing);
+        continue;
+      }
+      int res = 0;
+      if (state.count(s->outer)) res |= state.at(s->outer);
+      if (state.count(s->inner)) res |= state.at(s->inner);
+      if (state.count(s->fused)) res |= state.at(s->fused);
+      state[s->fused] = res;
+    } else if (const RebaseNode* s = rel.as<RebaseNode>()) {
+      if (!state.count(s->parent)) {
+        CHECK(allow_missing);
+        continue;
+      }
+      if (!state.count(s->rebased)) {
+        state[s->rebased] = state.at(s->parent);
+      } else {
+        state[s->rebased] |= state.at(s->parent);
+      }
+    } else if (const SingletonNode* s = rel.as<SingletonNode>()) {
+      state[s->iter] = 0;
+    } else {
+      LOG(FATAL) << "unknown relation type";
+    }
+  }
+}
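
The int states carried by these two passes are small bit masks of properties
attached to each IterVar; OR-merging means a variable inherits every flag of
any variable it was derived from. A toy model of the split case:

    #include <cstdio>

    int main() {
      int outer = 0b01, inner = 0b10;  // flags on the two split children
      int parent = 0;
      parent |= outer;  // PassUpBitMaskOr folds both children into the parent
      parent |= inner;
      std::printf("parent mask = %d\n", parent);  // 3: the union of the flags
      return 0;
    }
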
+
+
+/*!
+ * \brief Message passing to find if boundary checking on an IterVar is needed.
+ * \param s The stage to be used.
+ * \param dom_map The domain map of each iteration variable.
+ * \param p_state The message passing state: IterVar -> need-bound-check flag.
+ * \param analyzer Analyzer context used to prove exact coverage.
+ */
+void PassUpBoundCheck(const Stage& s,
+                      const Map<IterVar, Range>& dom_map,
+                      std::unordered_map<IterVar, bool>* p_state,
+                      arith::Analyzer* analyzer) {
+  auto& state = *p_state;
+  for (size_t i = s->relations.size(); i != 0; --i) {
+    IterVarRelation rel = s->relations[i - 1];
+    if (const SplitNode* s = rel.as<SplitNode>()) {
+      bool outer = state.at(s->outer);
+      bool inner = state.at(s->inner);
+
+      if (dom_map.count(s->inner) && dom_map.count(s->outer)) {
+        PrimExpr factor = dom_map.at(s->inner)->extent;
+        PrimExpr step = dom_map.at(s->outer)->extent;
+        if (outer || inner) {
+          state[s->parent] = true;
+        } else {
+          if (analyzer->CanProve(dom_map.at(s->parent)->extent == factor * step)) {
+            state[s->parent] = false;
+          } else {
+            state[s->parent] = true;
+          }
+        }
+      } else {
+        state[s->parent] = true;
+      }
+    } else if (const FuseNode* s = rel.as<FuseNode>()) {
+      bool fused = state.at(s->fused);
+      state[s->outer] = fused;
+      state[s->inner] = fused;
+    } else if (const RebaseNode* s = rel.as<RebaseNode>()) {
+      state[s->parent] = state.at(s->rebased);
+    } else if (rel.as<SingletonNode>()) {
+      // nop
+    } else {
+      LOG(FATAL) << "unknown relation type";
+    }
+  }
+}
+
+std::vector<PrimExpr> MakeBoundCheck(
+    const Stage& stage,
+    const Map<IterVar, Range>& dom_map,
+    const std::unordered_map<IterVar, PrimExpr>& value_map,
+    bool skip_ivar_domain,
+    const std::unordered_set<IterVar>& skip_iter) {
+  arith::Analyzer analyzer;
+
+  std::unordered_map<IterVar, bool> bound_state;
+  for (IterVar iv : stage->leaf_iter_vars) {
+    bound_state[iv] = false;
+  }
+  PassUpBoundCheck(stage, dom_map, &bound_state, &analyzer);
+
+  std::vector<PrimExpr> preds;
+  std::unordered_map<const VarNode*, IntSet> iset_dmap;
+
+  // setup domain map for set analysis
+  for (const auto& kv : dom_map) {
+    iset_dmap[kv.first->var.get()] = IntSet::range(kv.second);
+  }
+
+  for (const IterVar& iv : stage->all_iter_vars) {
+    if (skip_iter.count(iv) || iv->iter_type == kOpaque) continue;
+    if (bound_state.at(iv)) {
+      Range dom = dom_map.at(iv);
+      PrimExpr value = value_map.at(iv) - dom->min;
+      PrimExpr vmax = EvalSet(value, iset_dmap).max();
+      if (vmax.dtype() != value.dtype() || !analyzer.CanProve(vmax < dom->extent)) {
+        preds.emplace_back(value < dom->extent);
+      }
+    }
+  }
+  for (const IterVar& iv : stage->op->root_iter_vars()) {
+    if (skip_iter.count(iv) || iv->iter_type == kOpaque) continue;
+    Range dom = dom_map.at(iv);
+    CHECK(iv->dom.defined());
+    if (!skip_ivar_domain && !iv->dom.same_as(dom)) {
+      PrimExpr value = value_map.at(iv) - iv->dom->min;
+      IntSet s = EvalSet(value, iset_dmap);
+      PrimExpr vmin = s.min();
+      PrimExpr vmax = s.max();
+      // The range of `value` resides in [vmin, vmax]
+      if (vmin.dtype() != value.dtype() || !analyzer.CanProve(vmin >= 0)) {
+        preds.emplace_back(value >= 0);
+      }
+      if (vmax.dtype() != value.dtype() || !analyzer.CanProve(vmax < iv->dom->extent)) {
+        preds.emplace_back(value < iv->dom->extent);
+      }
+    }
+  }
+  return preds;
+}
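
Why the emitted predicates matter, as a standalone model: splitting N = 100 by
factor 8 makes the loop nest enumerate 13 * 8 = 104 points, so the
reconstructed parent index must be guarded by value < extent.

    #include <cstdio>

    int main() {
      int N = 100, factor = 8, executed = 0;
      int outer_extent = (N + factor - 1) / factor;  // 13, from PassDownDomain
      for (int outer = 0; outer < outer_extent; ++outer)
        for (int inner = 0; inner < factor; ++inner) {
          int value = outer * factor + inner;  // what PassUpIndex reconstructs
          if (value < N) ++executed;           // the emitted bound check
        }
      std::printf("%d of %d iterations pass the check\n",
                  executed, outer_extent * factor);
      return 0;
    }
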
+}  // namespace top
+}  // namespace tvm
diff --git a/src/top/schedule/message_passing.h b/src/top/schedule/message_passing.h
new file mode 100644 (file)
index 0000000..42b72a7
--- /dev/null
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file message_passing.h
+ * \brief Common utilities to do message passing
+ *  on the schedule hyper graph.
+ */
+#ifndef TVM_TOP_SCHEDULE_MESSAGE_PASSING_H_
+#define TVM_TOP_SCHEDULE_MESSAGE_PASSING_H_
+
+#include <tvm/expr.h>
+#include <tvm/top/schedule.h>
+#include <tvm/top/operation.h>
+#include <tvm/arith/analyzer.h>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+namespace tvm {
+namespace top {
+/*!
+ * \brief Downward inference of domain of each IterVar.
+ *  The caller sets the range of the root, then the function
+ *  propagates it towards the leaves.
+ *
+ * \param stage The stage to operate on.
+ * \param p_state The state of the message passing.
+ * \param analyzer Analyzer context, storing information about bounds in p_state.
+ * \param allow_missing Whether allow missing value.
+ */
+void PassDownDomain(
+    const Stage& stage,
+    std::unordered_map<IterVar, Range>* p_state,
+    arith::Analyzer* analyzer,
+    bool allow_missing = false);
+
+/*!
+ * \brief Upward inference of the index of each IterVar,
+ *  given the index assignment of the leaves.
+ *
+ * \param stage The stage to operate on.
+ * \param dom_map The domain map of each iteration variable's domain.
+ * \param p_state The index state of each IterVar.
+ * \param allow_missing Whether allow missing value.
+ */
+void PassUpIndex(const Stage& stage,
+                 const Map<IterVar, Range>& dom_map,
+                 std::unordered_map<IterVar, PrimExpr>* p_state,
+                 bool allow_missing = false);
+
+/*!
+ * \brief Downward inference of the index of each IterVar,
+ *  given the index assignment of the roots.
+ *
+ * \param stage The stage to operate on.
+ * \param dom_map The domain map of each iteration variable's domain.
+ * \param p_state The index state of each IterVar.
+ * \param allow_missing Whether allow missing value.
+ */
+void PassDownIndex(const Stage& stage,
+                   const Map<IterVar, Range>& dom_map,
+                   std::unordered_map<IterVar, PrimExpr>* p_state,
+                   bool allow_missing = false);
+
+/*!
+ * \brief Upward inference of the domain set of each IterVar,
+ *  given the domain assignment of the leaves.
+ *
+ * \param stage The stage to operate on.
+ * \param dom_map The domain map of each iteration variable's maximum domain.
+ * \param p_state The index state of each IterVar.
+ */
+void PassUpDomain(const Stage& stage,
+                  const std::unordered_map<IterVar, Range>& dom_map,
+                  std::unordered_map<IterVar, IntSet>* p_state);
+
+/*!
+ * \brief Upward message passing of bitmask with or relation.
+ * \param stage The stage to operate on.
+ * \param p_state The index state of each IterVar.
+ * \param allow_missing Whether allow missing value.
+ */
+void PassUpBitMaskOr(const Stage& stage,
+                     std::unordered_map<IterVar, int>* p_state,
+                     bool allow_missing = false);
+
+/*!
+ * \brief Downward message passing of bitmask with or relation.
+ * \param stage The stage to operate on.
+ * \param p_state The index state of each IterVar.
+ * \param allow_missing Whether allow missing value.
+ */
+void PassDownBitMaskOr(const Stage& stage,
+                       std::unordered_map<IterVar, int>* p_state,
+                       bool allow_missing = false);
+
+/*!
+ * \brief Create boundary check predicates given the remapped values of the
+ *  root iteration variables.
+ * \param stage The stage we operate on.
+ * \param dom_map The domain map of each iteration variable.
+ * \param value_map The value map of the root iter vars.
+ * \param skip_ivar_domain Whether to skip the check against each IterVar's original domain.
+ * \param skip_iter The set of variables whose bound condition should be skipped.
+ * \return List of predicates that we need to check.
+ */
+std::vector<PrimExpr>
+MakeBoundCheck(
+    const Stage& stage,
+    const Map<IterVar, Range>& dom_map,
+    const std::unordered_map<IterVar, PrimExpr>& value_map,
+    bool skip_ivar_domain,
+    const std::unordered_set<IterVar>& skip_iter);
+
+}  // namespace top
+}  // namespace tvm
+#endif  // TVM_TOP_SCHEDULE_MESSAGE_PASSING_H_
diff --git a/src/top/schedule/schedule_dataflow_rewrite.cc b/src/top/schedule/schedule_dataflow_rewrite.cc
new file mode 100644 (file)
index 0000000..5f9ba39
--- /dev/null
@@ -0,0 +1,885 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file schedule_dataflow_rewrite.cc
+ */
+#include <tvm/top/schedule.h>
+#include <tvm/top/operation.h>
+#include <tvm/ir_functor_ext.h>
+#include <tvm/ir_pass.h>
+#include <unordered_set>
+#include "message_passing.h"
+#include "../../pass/ir_util.h"
+#include "../../arith/compute_expr.h"
+
+namespace tvm {
+namespace top {
+// find the first occurrence location in the leaf array
+template<typename T>
+size_t FindNodeRef(ArrayNode* array_node, const T& v) {
+  const Object* n = v.get();
+  for (size_t i = 0; i < array_node->data.size(); ++i) {
+    if (array_node->data[i].get() == n) return i;
+  }
+  return array_node->data.size();
+}
+
+// The replacer of cache.
+class VarReplacer : public ir::StmtExprMutator {
+ public:
+  explicit VarReplacer(
+      const std::unordered_map<const VarNode*, PrimExpr>& vsub)
+      : vsub_(vsub) {}
+  PrimExpr VisitExpr_(const VarNode* op) final {
+    auto it = vsub_.find(op);
+    if (it != vsub_.end()) return it->second;
+    return GetRef<PrimExpr>(op);
+  }
+
+  ir::CommReducer MutateCommReducer(ir::CommReducer combiner) {
+    // Replace free variables in combiner
+    auto new_identity = ir::UpdateArray(combiner->identity_element, [this] (const PrimExpr& e) {
+      return this->VisitExpr(e);
+      });
+    auto new_result = ir::UpdateArray(combiner->result, [this] (const PrimExpr& e) {
+      return this->VisitExpr(e);
+      });
+
+    if (combiner->identity_element.same_as(new_identity) &&
+        combiner->result.same_as(new_result)) {
+      return combiner;
+    } else {
+      return ir::CommReducerNode::make(
+        combiner->lhs, combiner->rhs, new_result, new_identity);
+    }
+  }
+
+  PrimExpr VisitExpr_(const ir::ReduceNode* op) final {
+    PrimExpr new_e = StmtExprMutator::VisitExpr_(op);
+    const ir::ReduceNode* new_reduce = new_e.as<ir::ReduceNode>();
+    ir::CommReducer new_combiner = MutateCommReducer(op->combiner);
+    if (op->combiner.same_as(new_combiner)) {
+      return new_e;
+    } else {
+      return ir::ReduceNode::make(
+        new_combiner,
+        new_reduce->source,
+        new_reduce->axis,
+        new_reduce->condition,
+        new_reduce->value_index);
+    }
+  }
+
+ private:
+  const std::unordered_map<const VarNode*, PrimExpr>& vsub_;
+};
+
+PrimExpr InjectPredicate(const Array<PrimExpr>& predicates,
+                         PrimExpr body) {
+  using ir::ReduceNode;
+  using ir::SelectNode;
+  if (predicates.size() == 0) return body;
+  const ReduceNode* reduce = body.as<ReduceNode>();
+  if (reduce) {
+    auto n = make_object<ReduceNode>(*reduce);
+    n->condition = n->condition && arith::ComputeReduce<ir::AndNode>(predicates, PrimExpr());
+    return PrimExpr(n);
+  }
+  return SelectNode::make(arith::ComputeReduce<ir::AndNode>(predicates, PrimExpr()),
+                          body,
+                          make_zero(body.dtype()));
+}
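
A scalar model of the non-reduce branch: the conjunction of predicates guards
the body through a select, zero-filling the failing lanes.

    #include <cstdio>

    int main() {
      const int extent = 100;
      for (int value : {42, 105}) {
        bool pred = value < extent;      // the AND of all predicates
        int out = pred ? value * 2 : 0;  // select(pred, body, zero)
        std::printf("value=%d -> %d\n", value, out);
      }
      return 0;
    }
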
+
+// Replace the dataflow that appears in all stages, given the tensor change.
+// Also update vmap if subsequent dataflow needs to be replaced.
+// A reverse map keeps the transitive closure property of vmap up to date.
+void ReplaceDataFlow(const Array<Stage>& stages,
+                     std::unordered_map<Tensor, Tensor>* vmap,
+                     std::unordered_map<Tensor, Tensor>* rvmap) {
+  for (Stage s : stages) {
+    Operation op = s->op->ReplaceInputs(s->op, *vmap);
+    if (!op.same_as(s->op)) {
+      for (int i = 0; i < op->num_outputs(); ++i) {
+        auto it = rvmap->find(s->op.output(i));
+        if (it != rvmap->end()) {
+          (*vmap)[it->second] = op.output(i);
+        } else {
+          (*vmap)[s->op.output(i)] = op.output(i);
+          (*rvmap)[op.output(i)] = s->op.output(i);
+        }
+      }
+      s->op = op;
+    }
+  }
+}
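
Why the reverse map is needed, as a toy model with strings: if A was already
rewritten to B and B is now rewritten to C, the forward map must be redirected
so that A maps straight to C, keeping vmap transitively closed.

    #include <cstdio>
    #include <map>
    #include <string>

    int main() {
      std::map<std::string, std::string> vmap, rvmap;
      auto replace = [&](const std::string& old_t, const std::string& new_t) {
        auto it = rvmap.find(old_t);
        if (it != rvmap.end()) {
          vmap[it->second] = new_t;  // redirect the original tensor
        } else {
          vmap[old_t] = new_t;
          rvmap[new_t] = old_t;
        }
      };
      replace("A", "B");
      replace("B", "C");
      std::printf("A -> %s\n", vmap["A"].c_str());  // A -> C, not the stale B
      return 0;
    }
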
+
+inline bool ReduceEqual(const ir::ReduceNode* a, const ir::ReduceNode* b) {
+  return (a->combiner.same_as(b->combiner)) &&
+         (a->source.same_as(b->source)) &&
+         (a->axis.same_as(b->axis)) &&
+         (a->condition.same_as(b->condition));
+}
+
+Tensor Schedule::cache_read(const Tensor& tensor,
+                            const std::string& scope,
+                            const Array<Operation>& readers) {
+  (*this)->InvalidateCache();
+  // create identity mapping.
+  std::ostringstream os;
+  os << tensor->op->name;
+  if (tensor->op->num_outputs() != 1) {
+    os << ".v" << tensor->value_index;
+  }
+  os << "." << scope;
+
+  std::unordered_map<Tensor, Tensor> vsub;
+  Stage s = operator[](tensor->op);
+  Tensor sugar_tensor = s->op.output(tensor->value_index);
+  Tensor cache = compute(sugar_tensor->shape, [&sugar_tensor](const Array<Var>& i) {
+      return sugar_tensor(Array<PrimExpr>(i.begin(), i.end()));
+    }, os.str());
+  vsub[sugar_tensor] = cache;
+
+  std::unordered_map<Tensor, Tensor> vmap;
+  std::unordered_map<Tensor, Tensor> rvmap;
+  for (Operation op : readers) {
+    Stage s = operator[](op);
+    Operation repl_op = s->op->ReplaceInputs(s->op, vsub);
+    CHECK(!repl_op.same_as(s->op))
+        << "Cannot find " << tensor
+        << " in the inputs of " << s->op;
+    vmap[s->op.output(0)] = repl_op.output(0);
+    rvmap[repl_op.output(0)] = s->op.output(0);
+    s->op = repl_op;
+  }
+  ReplaceDataFlow((*this)->stages, &vmap, &rvmap);
+  ArrayNode* stages = (*this)->stages.CopyOnWrite();
+  Stage op_stage = operator[](tensor->op);
+  size_t pos = FindNodeRef(stages, op_stage);
+  Stage cache_stage = Stage(cache->op);
+  cache_stage.set_scope(scope);
+  CHECK_LT(pos, stages->data.size());
+  stages->data.insert(stages->data.begin() + pos + 1,
+                      cache_stage);
+  (*this)->stage_map.Set(cache->op, cache_stage);
+  // Update group
+  cache_stage->group = op_stage->group;
+  if (cache_stage->group.defined()) {
+    ++cache_stage->group->num_child_stages;
+  }
+  return cache;
+}
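
For context, a hedged usage sketch of cache_read via the C++ API declared in
include/tvm/top (it assumes a full TVM build; the names A, B, s and the shapes
are illustrative, not part of this patch):

    #include <tvm/top/operation.h>
    #include <tvm/top/schedule.h>

    using namespace tvm;
    using namespace tvm::top;

    void CacheReadExample() {
      Tensor A = placeholder({1024}, DataType::Float(32), "A");
      Tensor B = compute({1024}, [&](Var i) { return A(i) + 1; }, "B");
      Schedule s = create_schedule({B->op});
      // Insert a stage that copies A into "shared" scope; B now reads the copy.
      Tensor AA = s.cache_read(A, "shared", {B->op});
      (void)AA;
    }
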
+
+template<typename OpType>
+void PrepareAxisMapping(Stage orig_stage,
+                        OpType* op,
+                        std::unordered_set<IterVar>* p_red_axis,
+                        Array<IterVar>* p_new_axis,
+                        std::unordered_map<IterVar, Range>* p_dom_map,
+                        std::unordered_map<const VarNode*, PrimExpr>* p_vsub,
+                        std::unordered_map<const VarNode*, PrimExpr>* p_vsub2newvar,
+                        std::vector<PrimExpr>* p_predicates) {
+  auto& red_axis = *p_red_axis;
+  auto& new_axis = *p_new_axis;
+  auto& dom_map = *p_dom_map;
+  auto& vsub = *p_vsub;
+  auto& vsub2newvar = *p_vsub2newvar;
+  auto& predicates = *p_predicates;
+  arith::Analyzer analyzer;
+
+  for (IterVar iv : op->reduce_axis) {
+    red_axis.insert(iv);
+  }
+  for (IterVar iv : op->axis) {
+    dom_map[iv] = iv->dom;
+    analyzer.Bind(iv->var, iv->dom);
+  }
+  top::PassDownDomain(orig_stage, &dom_map, &analyzer, true);
+  {
+    // The source->cache
+    std::unordered_map<IterVar, PrimExpr> value_map;
+    for (IterVar iv : orig_stage->leaf_iter_vars) {
+      if (red_axis.count(iv)) continue;
+      CHECK_EQ(iv->iter_type, kDataPar)
+          << "Can only relayout within data parallel dimensions";
+      Range dom = dom_map.at(iv);
+      IterVar new_iv = IterVarNode::make(
+          dom, iv->var.copy_with_suffix(".c"), iv->iter_type);
+      new_axis.push_back(new_iv);
+      if (is_one(dom->extent)) {
+        value_map[iv] = dom->min;
+      } else {
+        value_map[iv] = iv->var;
+        vsub2newvar[iv->var.get()] = new_iv->var;
+      }
+    }
+    // skip reduction iteration.
+    std::unordered_set<IterVar> skip_bound_check;
+    for (IterVar iv : op->reduce_axis) {
+      skip_bound_check.insert(iv);
+    }
+    PassUpIndex(orig_stage, dom_map, &value_map, true);
+    predicates = MakeBoundCheck(
+        orig_stage, dom_map, value_map, true, skip_bound_check);
+    // The root axis
+    for (IterVar iv : op->axis) {
+      if (value_map.count(iv)) {
+        vsub[iv->var.get()] = value_map.at(iv);
+      }  // to handle tensor axis
+    }
+  }
+}
+
+Array<Tensor> ReplaceOriginalOp(Schedule sch,
+                                Stage orig_stage,
+                                const std::string& scope,
+                                Operation cache_op,
+                                Operation orig_new_op,
+                                size_t tensor_size) {
+  Array<Tensor> cache_tensor_list;
+  for (size_t i = 0; i < tensor_size; i++) {
+    Tensor cache_tensor = cache_op.output(i);
+    cache_tensor_list.push_back(cache_tensor);
+  }
+  // Replace the dataflow: map every output of the original op to the
+  // corresponding output of the new op.
+  std::unordered_map<Tensor, Tensor> vmap;
+  std::unordered_map<Tensor, Tensor> rvmap;
+  for (size_t i = 0; i < tensor_size; i++) {
+    vmap[orig_stage->op.output(i)] = orig_new_op.output(i);
+    rvmap[orig_new_op.output(i)] = orig_stage->op.output(i);
+  }
+  ReplaceDataFlow(sch->stages, &vmap, &rvmap);
+  // mutate orig stage
+  orig_stage->op = orig_new_op;
+  orig_stage->all_iter_vars = orig_stage->op->root_iter_vars();
+  orig_stage->leaf_iter_vars = orig_stage->all_iter_vars;
+  orig_stage->relations = Array<IterVarRelation>();
+  // create schedule for new cached stage.
+  ArrayNode* stages = sch->stages.CopyOnWrite();
+  size_t pos = FindNodeRef(stages, orig_stage);
+  Stage cache_stage = Stage(cache_op);
+  cache_stage.set_scope(scope);
+  CHECK_LT(pos, stages->data.size());
+  stages->data.insert(stages->data.begin() + pos,
+                      cache_stage);
+  sch->stage_map.Set(cache_op, cache_stage);
+  // Update group
+  cache_stage->group = orig_stage->group;
+  if (cache_stage->group.defined()) {
+    ++cache_stage->group->num_child_stages;
+  }
+  return cache_tensor_list;
+}
+
+
+// Cache write and relayout the data according to loop pattern
+Array<Tensor> CacheWriteWithReLayout(Schedule sch,
+                                     const Array<Tensor>& tensor_array,
+                                     const std::string& scope) {
+  size_t tensor_size = tensor_array.size();
+  sch->InvalidateCache();
+  Tensor tensor = tensor_array[0];
+  Stage orig_stage = sch[tensor->op];
+  const ComputeOpNode* compute = orig_stage->op.as<ComputeOpNode>();
+
+  std::unordered_set<IterVar> red_axis;
+  Array<IterVar> new_axis;
+  std::unordered_map<IterVar, Range> dom_map;
+
+  std::unordered_map<const VarNode*, PrimExpr> vsub;
+  std::unordered_map<const VarNode*, PrimExpr> vsub2newvar;
+  std::vector<PrimExpr> predicates;
+
+  PrepareAxisMapping(orig_stage, compute,
+    &red_axis, &new_axis, &dom_map, &vsub, &vsub2newvar, &predicates);
+
+  PrimExpr body;
+  Array<PrimExpr> body_list;
+  const ir::ReduceNode* first_reduce = nullptr;
+  for (auto cbody : compute->body) {
+    body = VarReplacer(vsub)(cbody);
+    body = InjectPredicate(predicates, body);
+    body = VarReplacer(vsub2newvar)(body);
+    // All Reduce nodes in one ComputeOp must be identical except value_index.
+    // This holds only if the original body already guarantees it.
+    if (body->IsInstance<ir::ReduceNode>()) {
+      const ir::ReduceNode* reduce_body = body.as<ir::ReduceNode>();
+      if (first_reduce != nullptr) {
+        CHECK(ReduceEqual(reduce_body, first_reduce));
+        body = ir::ReduceNode::make(first_reduce->combiner,
+                                first_reduce->source,
+                                first_reduce->axis,
+                                first_reduce->condition,
+                                reduce_body->value_index);
+      } else {
+        first_reduce = reduce_body;
+      }
+    } else {
+      CHECK(first_reduce == nullptr)
+        << "cannot mix reduce and other nodes in one compute body";
+    }
+    body_list.push_back(body);
+  }
+  // The reader args
+  Array<PrimExpr> args;
+  {
+    // cache->compute
+    std::unordered_map<IterVar, PrimExpr> value_map;
+    for (IterVar iv : compute->axis) {
+      value_map[iv] = iv->var;
+    }
+    top::PassDownIndex(orig_stage, dom_map, &value_map, true);
+    for (IterVar iv : orig_stage->leaf_iter_vars) {
+      if (red_axis.count(iv)) continue;
+      args.push_back(value_map.at(iv));
+    }
+  }
+  Operation cache_op = ComputeOpNode::make(
+      compute->name + "." + scope, compute->tag, compute->attrs,
+      new_axis, body_list);
+
+  Array<PrimExpr> cache_expr_list;
+  for (size_t i = 0; i < tensor_size; i++) {
+    Tensor cache_tensor = cache_op.output(i);
+    cache_expr_list.push_back(cache_tensor(args));
+  }
+  Operation orig_new_op = ComputeOpNode::make(
+      compute->name, compute->tag, compute->attrs,
+      compute->axis, cache_expr_list);
+  return ReplaceOriginalOp(sch, orig_stage, scope,
+    cache_op, orig_new_op, tensor_size);
+}
+
+
+// Cache write for TensorComputeOp: relayout the schedulable dimensions
+// while preserving the tensorized region.
+Array<Tensor> CacheWriteWithReLayoutTensor(Schedule sch,
+                                           const Array<Tensor>& tensor_array,
+                                           const std::string& scope) {
+  size_t tensor_size = tensor_array.size();
+  sch->InvalidateCache();
+  Tensor tensor = tensor_array[0];
+  Stage orig_stage = sch[tensor->op];
+  const TensorComputeOpNode* tensor_op = orig_stage->op.as<TensorComputeOpNode>();
+  CHECK_EQ(tensor_op->num_outputs(), 1)
+      << "cache_write only supports single-output TensorComputeOp";
+
+  std::unordered_set<IterVar> red_axis;
+  Array<IterVar> new_axis;
+  std::unordered_map<IterVar, Range> dom_map;
+
+  std::unordered_map<const VarNode*, PrimExpr> vsub;
+  std::unordered_map<const VarNode*, PrimExpr> vsub2newvar;
+  std::vector<PrimExpr> predicates;
+
+  PrepareAxisMapping(orig_stage, tensor_op,
+    &red_axis, &new_axis, &dom_map, &vsub, &vsub2newvar, &predicates);
+
+
+  for (int i = tensor_op->schedulable_ndim; i < static_cast<int>(tensor_op->axis.size()); ++i) {
+    IterVar iv = tensor_op->axis[i];
+    IterVar new_iv = IterVarNode::make(
+      iv->dom, iv->var.copy_with_suffix(".c"), iv->iter_type);
+    new_axis.push_back(new_iv);
+  }
+  Array<Region> new_regions;
+  for (Region old_region : tensor_op->input_regions) {
+    Region region;
+    for (Range r : old_region) {
+      PrimExpr min = VarReplacer(vsub2newvar)(r->min);
+      PrimExpr extent = VarReplacer(vsub2newvar)(r->extent);
+      region.push_back(Range::make_by_min_extent(min, extent));
+    }
+    new_regions.push_back(region);
+  }
+
+  Array<PrimExpr> new_scalar_inputs;
+  for (PrimExpr old_input : tensor_op->scalar_inputs) {
+    new_scalar_inputs.push_back(VarReplacer(vsub2newvar)(old_input));
+  }
+
+  Operation cache_op = TensorComputeOpNode::make(
+      tensor_op->name + "." + scope, tensor_op->tag, new_axis,
+      tensor_op->reduce_axis, tensor_op->schedulable_ndim,
+      tensor_op->intrin, tensor_op->inputs, new_regions, new_scalar_inputs);
+
+  // axis will be used in generating compute op
+  Array<IterVar> compute_axis = tensor_op->axis;
+  for (size_t i = tensor_op->schedulable_ndim; i < tensor_op->axis.size(); ++i) {
+    IterVar iv = tensor_op->axis[i];
+    IterVar aiv = IterVarNode::make(iv->dom, iv->var, kDataPar);
+    compute_axis.Set(i, aiv);
+  }
+
+  // The reader args
+  Array<PrimExpr> args;
+  {
+    // cache->compute
+    std::unordered_map<IterVar, PrimExpr> value_map;
+    for (IterVar iv : compute_axis) {
+      value_map[iv] = iv->var;
+    }
+    PassDownIndex(orig_stage, dom_map, &value_map, true);
+    for (IterVar iv : orig_stage->leaf_iter_vars) {
+      if (red_axis.count(iv)) continue;
+      args.push_back(value_map.at(iv));
+    }
+    // tensorized region axis
+    for (size_t i = tensor_op->schedulable_ndim; i < tensor_op->axis.size(); ++i) {
+      IterVar iv = compute_axis[i];
+      args.push_back(value_map.at(iv));
+    }
+  }
+
+  Array<PrimExpr> cache_expr_list;
+  for (size_t i = 0; i < tensor_size; i++) {
+    Tensor cache_tensor = cache_op.output(i);
+    cache_expr_list.push_back(cache_tensor(args));
+  }
+  Operation orig_new_op = ComputeOpNode::make(
+      tensor_op->name, tensor_op->tag, {},
+      compute_axis, cache_expr_list);
+  return ReplaceOriginalOp(sch, orig_stage, scope,
+    cache_op, orig_new_op, tensor_size);
+}
+
+
+Array<Tensor> Schedule::cache_write(const Array<Tensor>& tensor_array,
+                             const std::string& scope) {
+  (*this)->InvalidateCache();
+  CHECK(tensor_array.size() > 0)
+      << "size of tensor_array must be greater than 0";
+  Tensor tensor = tensor_array[0];
+  Stage orig_stage = operator[](tensor->op);
+  const ComputeOpNode* compute = tensor->op.as<ComputeOpNode>();
+  CHECK(static_cast<size_t>(compute->num_outputs()) == tensor_array.size())
+      << "size of input tensor list must be same as number of stage outputs";
+  for (size_t i = 1; i < tensor_array.size(); i++) {
+    Stage tmp_stage = operator[](tensor_array[i]->op);
+    CHECK(orig_stage.same_as(tmp_stage))
+        << "Input tensor list must be generated by ONE computeOp";
+  }
+  return CacheWriteWithReLayout(*this, tensor_array, scope);
+}
+
+
+Tensor Schedule::cache_write(const Tensor& tensor,
+                             const std::string& scope) {
+  // support original compute and tensor compute both
+  (*this)->InvalidateCache();
+  if (tensor->op.as<ComputeOpNode>()) {
+    return (CacheWriteWithReLayout(*this, {tensor}, scope))[0];
+  } else if (tensor->op.as<TensorComputeOpNode>()) {
+    return (CacheWriteWithReLayoutTensor(*this, {tensor}, scope))[0];
+  } else {
+    LOG(FATAL) << "cache_write only takes ComputeOp or TensorComputeOp as writers";
+    return Tensor();
+  }
+}
+
+
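+// Rebase every unbound leaf iteration variable so its loop starts at 0.
+// For example, an axis ranging over [2, 10) is replaced by a rebased axis
+// over [0, 8), with a Rebase relation recording the offset.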
+void RebaseNonZeroMinLoop(const Schedule& sch) {
+  std::unordered_map<IterVar, IterVar> rebase_map;
+  for (Stage s : sch->stages) {
+    if (s->attach_type == kInlinedAlready) continue;
+
+    auto root_iter_vars = s->op->root_iter_vars();
+    ArrayNode* leaf_vars = s->leaf_iter_vars.CopyOnWrite();
+    for (IterVar iv : root_iter_vars) {
+      size_t idx = FindNodeRef(leaf_vars, iv);
+      auto it = s->iter_var_attrs.find(iv);
+      // No need to rebase axes that are already bound to threads.
+      if (it != s->iter_var_attrs.end() &&
+          (*it).second->bind_thread.defined()) {
+        continue;
+      }
+      if (idx < leaf_vars->data.size()) {
+        // insert rebase
+        IterVar rebased = IterVarNode::make(
+            Range(), iv->var.copy_with_suffix(""), iv->iter_type);
+        s->relations.push_back(RebaseNode::make(iv, rebased));
+        if (s->iter_var_attrs.count(iv)) {
+          s->iter_var_attrs.Set(rebased, s->iter_var_attrs.at(iv));
+        }
+        leaf_vars->data[idx] = rebased;
+        rebase_map[iv] = rebased;
+      }
+    }
+  }
+  // remap the parent relation
+  for (Stage s : sch->stages) {
+    if (s->attach_type != kScope) continue;
+    if (rebase_map.count(s->attach_ivar)) {
+      s->attach_ivar = rebase_map.at(s->attach_ivar);
+    }
+  }
+  for (Stage s : sch->groups) {
+    if (s->attach_type != kScope) continue;
+    if (rebase_map.count(s->attach_ivar)) {
+      s->attach_ivar = rebase_map.at(s->attach_ivar);
+    }
+  }
+}
+
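+// Inline every stage marked kInline: substitute its (single-output) body
+// into all downstream compute/hybrid stages, then rewrite the dataflow so
+// consumers point at the updated ops. Reductions need special care because
+// all Reduce nodes of one ComputeOp must match except for value_index.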
+void InjectInline(ScheduleNode* sch) {
+  sch->InvalidateCache();
+
+  std::vector<Array<PrimExpr> > new_body(sch->stages.size());
+  std::vector<bool> changed(sch->stages.size(), false);
+  std::vector<Stmt> new_hybrid_body(sch->stages.size());
+  std::vector<bool> hybrid_changed(sch->stages.size(), false);
+  // inline all the ops
+  for (size_t i = sch->stages.size(); i != 0; --i) {
+    Stage stage = sch->stages[i - 1];
+    if (stage->attach_type == kInline) {
+      stage->attach_type = kInlinedAlready;
+      Array<Var> args;
+      PrimExpr body;
+      {
+        // setup args
+        const ComputeOpNode* compute = stage->op.as<ComputeOpNode>();
+        CHECK(compute)
+            << "can only inline compute op";
+        for (auto iv : compute->axis) {
+          args.push_back(iv->var);
+        }
+        CHECK_EQ(compute->body.size(), 1U)
+            << "can only inline compute op with 1 output";
+        body = compute->body[0];
+      }
+      for (size_t j = i; j < sch->stages.size(); ++j) {
+        Stage s = sch->stages[j];
+        const ComputeOpNode* compute = s->op.as<ComputeOpNode>();
+        const HybridOpNode* hybrid = s->op.as<HybridOpNode>();
+        if (compute) {
+          if (!new_body[j].size()) {
+            new_body[j] = compute->body;
+          }
+          if (new_body[j][0]->IsInstance<ir::ReduceNode>()) {
+            // Specially handle reduction inlining for multiple reductions.
+            const ir::ReduceNode* reduce = new_body[j][0].as<ir::ReduceNode>();
+            for (size_t k = 1; k < new_body[j].size(); ++k) {
+              const ir::ReduceNode* reduce_ = new_body[j][k].as<ir::ReduceNode>();
+              CHECK(reduce_);
+              CHECK(ReduceEqual(reduce_, reduce))
+                  << "The Reduce inputs of ComputeOp should "
+                  << "have the same attribute except value_index";
+            }
+            PrimExpr new_value = ir::Inline(ir::EvaluateNode::make(new_body[j][0]),
+                                        stage->op, args, body).as<ir::EvaluateNode>()->value;
+            if (!new_value.same_as(new_body[j][0])) {
+              changed[j] = true;
+              const ir::ReduceNode* r = new_value.as<ir::ReduceNode>();
+              CHECK(r != nullptr);
+              CHECK_EQ(new_body[j].size(), r->source.size());
+              for (size_t k = 0; k < new_body[j].size(); ++k) {
+                auto n = make_object<ir::ReduceNode>(*r);
+                n->value_index = static_cast<int>(k);
+                n->dtype = r->source[k].dtype();
+                new_body[j].Set(k, PrimExpr(n));
+              }
+            }
+          } else {
+            for (size_t k = 0; k < new_body[j].size(); ++k) {
+              PrimExpr new_value = ir::Inline(ir::EvaluateNode::make(new_body[j][k]),
+                                          stage->op, args, body).as<ir::EvaluateNode>()->value;
+              if (!new_value.same_as(new_body[j][k])) {
+                new_body[j].Set(k, new_value);
+                changed[j] = true;
+              }
+            }
+          }
+        } else if (hybrid) {
+          if (!new_hybrid_body[j].defined()) {
+            new_hybrid_body[j] = hybrid->body;
+          }
+          Stmt new_stmt = ir::Inline(new_hybrid_body[j], stage->op, args, body);
+          if (!new_stmt.same_as(new_hybrid_body[j])) {
+            new_hybrid_body[j] = new_stmt;
+            hybrid_changed[j] = true;
+          }
+        }
+      }
+    }
+  }
+  std::unordered_map<Tensor, Tensor> repl;
+  // rewrite dataflow
+  for (size_t i = 0; i < sch->stages.size(); ++i) {
+    Stage s = sch->stages[i];
+    if (s->attach_type == kInlinedAlready) continue;
+    if (new_body[i].size()) {
+      // Logics from ReplaceDataFlow
+      const ComputeOpNode* compute = sch->stages[i]->op.as<ComputeOpNode>();
+      CHECK(compute);
+      Operation op = s->op;
+      if (changed[i]) {
+        op = ComputeOpNode::make(
+            compute->name, compute->tag, compute->attrs,
+            compute->axis, new_body[i]);
+      }
+      op = op->ReplaceInputs(op, repl);
+      if (!op.same_as(s->op)) {
+        for (int idx = 0; idx < s->op->num_outputs(); ++idx) {
+          repl[s->op.output(idx)] = op.output(idx);
+        }
+        s->op = op;
+      }
+    } else if (hybrid_changed[i]) {
+      const HybridOpNode* hybrid = sch->stages[i]->op.as<HybridOpNode>();
+      CHECK(hybrid);
+      Operation op = HybridOpNode::make(
+              hybrid->name, hybrid->tag, hybrid->attrs, hybrid->inputs,
+              hybrid->outputs, new_hybrid_body[i]);
+      op = op->ReplaceInputs(op, repl);
+      for (int idx = 0; idx < s->op->num_outputs(); ++idx) {
+        repl[s->op.output(idx)] = op.output(idx);
+      }
+      s->op = op;
+    } else {
+      Operation op = s->op->ReplaceInputs(s->op, repl);
+      if (!op.same_as(s->op)) {
+        for (int j = 0; j < op->num_outputs(); ++j) {
+          repl[s->op.output(j)] = op.output(j);
+        }
+        s->op = op;
+      }
+    }
+  }
+}
+
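+// normalize() works on a copy of the schedule: inject the pending inlines,
+// then rebase loops so every leaf iteration variable starts from 0.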
+Schedule Schedule::normalize() {
+  Schedule sn = copy();
+  InjectInline(sn.operator->());
+  RebaseNonZeroMinLoop(sn);
+  return sn;
+}
+
+// Handle reduction factor.
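+// An illustrative sketch (hypothetical names, not code from this diff):
+// given B = sum(A[k]) with k split into (ko, ki), rfactor(B, ki) builds
+// B.rf, which keeps ki as a data-parallel axis and reduces only over ko;
+// B is then rewritten to reduce B.rf over a new axis covering ki's range.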
+Array<Tensor> Schedule::rfactor(const Tensor& tensor,
+                                const IterVar& axis,
+                                int factor_axis) {
+  (*this)->InvalidateCache();
+  using ir::ReduceNode;
+  CHECK_EQ(axis->iter_type, kCommReduce)
+      << "Can only factor reduction axis";
+  Stage reduce_stage = operator[](tensor->op);
+  const ComputeOpNode* compute_op = reduce_stage->op.as<ComputeOpNode>();
+  CHECK(compute_op) << "Can only factor ComputeOp";
+  ArrayNode* leaf_vars = reduce_stage->leaf_iter_vars.CopyOnWrite();
+  {
+    size_t axis_pos = FindNodeRef(leaf_vars, axis);
+    CHECK_NE(axis_pos, leaf_vars->data.size())
+        << "Cannot find IterVar " << axis << " in leaf iter vars";
+  }
+  // Find touched reduction axis.
+  std::unordered_map<IterVar, int> touch_map;
+  touch_map[axis] = 1;
+  top::PassUpBitMaskOr(reduce_stage, &touch_map, true);
+  top::PassDownBitMaskOr(reduce_stage, &touch_map, true);
+  // skip reduction iteration.
+  std::unordered_set<IterVar> skip_bound_check;
+  // Verify normal axis are not touched.
+  for (IterVar iv : compute_op->axis) {
+    CHECK(!touch_map.count(iv))
+        << "Factor axis touches normal axis.";
+    skip_bound_check.insert(iv);
+  }
+  // get analyzer.
+  arith::Analyzer analyzer;
+  // Get the replace index
+  std::unordered_map<IterVar, Range> dom_map;
+  std::unordered_map<IterVar, PrimExpr> value_map;
+  for (IterVar iv : compute_op->reduce_axis) {
+    if (touch_map.count(iv)) {
+      dom_map[iv] = iv->dom;
+    } else {
+      skip_bound_check.insert(iv);
+    }
+    analyzer.Bind(iv->var, iv->dom);
+  }
+  top::PassDownDomain(reduce_stage, &dom_map, &analyzer, true);
+  for (IterVar iv : reduce_stage->leaf_iter_vars) {
+    if (touch_map.count(iv)) {
+      Range dom = dom_map.at(iv);
+      if (is_one(dom->extent)) {
+        value_map[iv] = dom->min;
+      } else {
+        value_map[iv] = iv->var;
+      }
+    }
+  }
+  top::PassUpIndex(reduce_stage, dom_map, &value_map, true);
+  std::vector<PrimExpr> predicates = MakeBoundCheck(
+      reduce_stage, dom_map, value_map, true, skip_bound_check);
+
+  // Get the factored op node.
+  const int factor_axis_pos =
+      factor_axis >= 0 ? factor_axis : static_cast<int>(compute_op->axis.size() + 1) + factor_axis;
+  CHECK_LE(factor_axis_pos, compute_op->axis.size());
+  auto n = make_object<ComputeOpNode>();
+  n->name = compute_op->name + ".rf";
+  {
+    // axis replacement.
+    auto iv_node = make_object<IterVarNode>();
+    iv_node->dom = dom_map.at(axis);
+    CHECK(is_zero(iv_node->dom->min))
+        << "Can only factor reduction domain starting from 0";
+    iv_node->var = axis->var;
+    iv_node->iter_type = kDataPar;
+
+    const int size = compute_op->axis.size();
+    for (int idx = 0; idx < size; ++idx) {
+      if (factor_axis_pos == idx) {
+        n->axis.push_back(IterVar(iv_node));
+      }
+      n->axis.push_back(compute_op->axis[idx]);
+    }
+    if (factor_axis_pos == size) {
+      n->axis.push_back(IterVar(iv_node));
+    }
+  }
+  // predicate generation, copy not touched axis.
+  int idx = tensor->value_index;
+  const ReduceNode* reduce = compute_op->body[idx].as<ReduceNode>();
+  CHECK(reduce) << "Can only rfactor non-inline reductions";
+  predicates.push_back(reduce->condition);
+  PrimExpr predicate = likely(arith::ComputeReduce<ir::AndNode>(predicates, PrimExpr()));
+
+  std::unordered_map<const VarNode*, PrimExpr> vsub;
+
+  for (IterVar iv : compute_op->reduce_axis) {
+    if (!touch_map.count(iv)) {
+      n->reduce_axis.push_back(iv);
+    } else {
+      CHECK(value_map.count(iv));
+      PrimExpr index = value_map.at(iv);
+      vsub[iv->var.get()] = index;
+    }
+  }
+
+  // Copy touched axis.
+  for (IterVar iv : reduce_stage->leaf_iter_vars) {
+    if (touch_map.count(iv) && !iv.same_as(axis)) {
+      CHECK_EQ(iv->iter_type, kCommReduce);
+      auto ncpy = make_object<IterVarNode>(*iv.operator->());
+      ncpy->dom = dom_map.at(iv);
+      n->reduce_axis.push_back(IterVar(ncpy));
+    }
+  }
+  VarReplacer replacer(vsub);
+  Array<PrimExpr> new_source = ir::UpdateArray(reduce->source,
+    [&replacer] (const PrimExpr& e) { return replacer(e); });
+
+  PrimExpr new_pred = replacer(predicate);
+
+  std::vector<PrimExpr> body;
+  for (size_t idx = 0; idx < reduce->source.size(); ++idx) {
+    body.emplace_back(ReduceNode::make(reduce->combiner,
+                                   new_source,
+                                   n->reduce_axis,
+                                   new_pred,
+                                   idx));
+  }
+  n->body = Array<PrimExpr>(body);
+  // refresh relations, keep the un-touched relations.
+  Array<IterVarRelation> rels;
+  for (IterVarRelation rel : reduce_stage->relations) {
+    bool touched = false;
+    if (const SplitNode* r = rel.as<SplitNode>()) {
+      if (touch_map.count(r->parent)) touched = true;
+    } else if (const FuseNode* r = rel.as<FuseNode>()) {
+      if (touch_map.count(r->fused)) touched = true;
+    } else if (const RebaseNode* r = rel.as<RebaseNode>()) {
+      if (touch_map.count(r->parent)) touched = true;
+    } else {
+      LOG(FATAL) << "unknown relation type";
+    }
+    if (!touched) {
+      rels.push_back(rel);
+    }
+  }
+  // initialize the factored stage.
+  Operation factor_op(n);
+  ArrayNode* stages = (*this)->stages.CopyOnWrite();
+  size_t stage_pos = FindNodeRef(stages, reduce_stage);
+  Stage factor_stage = Stage(factor_op);
+  factor_stage->relations = rels;
+  CHECK_LT(stage_pos, stages->data.size());
+  stages->data.insert(stages->data.begin() + stage_pos,
+                      factor_stage);
+  (*this)->stage_map.Set(factor_op, factor_stage);
+  factor_stage->group = reduce_stage->group;
+  if (factor_stage->group.defined()) {
+    ++factor_stage->group->num_child_stages;
+  }
+  // Replace the old reduction.
+  IterVar repl_red_axis = reduce_axis(
+      dom_map.at(axis), axis->var->name_hint + ".v");
+  Array<Tensor> factor_tensors;
+  Array<Tensor> old_tensors;
+  int size = factor_op->num_outputs();
+  for (int idx = 0; idx < size; ++idx) {
+    factor_tensors.push_back(factor_op.output(idx));
+    old_tensors.push_back(reduce_stage->op.output(idx));
+  }
+  Array<Tensor> repl_tensors = compute(old_tensors[0]->shape,
+    [&](const Array<Var>& i) {
+      Array<PrimExpr> indices;
+      const int idx_size = static_cast<int>(i.size());
+      for (int idx = 0; idx < idx_size; ++idx) {
+        if (factor_axis_pos == idx) {
+          indices.push_back(repl_red_axis->var);
+        }
+        indices.push_back(i[idx]);
+      }
+      if (factor_axis_pos == idx_size) {
+          indices.push_back(repl_red_axis->var);
+      }
+      Array<PrimExpr> factor_exprs;
+      for (int idx = 0; idx < size; ++idx) {
+        factor_exprs.push_back(factor_tensors[idx](indices));
+      }
+      Array<PrimExpr> reductions;
+      Array<IterVar> axis = {repl_red_axis};
+      PrimExpr cond = const_true();
+      for (int idx = 0; idx < size; ++idx) {
+        reductions.push_back(ReduceNode::make(reduce->combiner,
+          factor_exprs, axis, cond, idx));
+      }
+      return reductions;
+    }, reduce_stage->op->name + ".repl");
+
+  std::unordered_map<Tensor, Tensor> vmap;
+  std::unordered_map<Tensor, Tensor> rvmap;
+  for (int idx = 0; idx < size; ++idx) {
+    vmap[old_tensors[idx]] = repl_tensors[idx];
+    rvmap[repl_tensors[idx]] = old_tensors[idx];
+  }
+  ReplaceDataFlow((*this)->stages, &vmap, &rvmap);
+  // revamp the reduction stage.
+  reduce_stage->op = repl_tensors[0]->op;
+  reduce_stage->all_iter_vars = repl_tensors[0]->op->root_iter_vars();
+  reduce_stage->leaf_iter_vars = reduce_stage->all_iter_vars;
+  reduce_stage->relations = Array<IterVarRelation>();
+  return factor_tensors;
+}
+}  // namespace top
+}  // namespace tvm
diff --git a/src/top/schedule/schedule_lang.cc b/src/top/schedule/schedule_lang.cc
new file mode 100644 (file)
index 0000000..5523530
--- /dev/null
@@ -0,0 +1,852 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file schedule_lang.cc
+ */
+#include <tvm/top/schedule.h>
+#include <tvm/top/operation.h>
+#include <unordered_set>
+#include "graph.h"
+
+namespace tvm {
+namespace top {
+
+// Find the first occurrence of v in the leaf variable array.
+template<typename T>
+size_t FindNodeRef(ArrayNode* array_node, const T& v) {
+  const Object* n = v.get();
+  for (size_t i = 0; i < array_node->data.size(); ++i) {
+    if (array_node->data[i].get() == n) return i;
+  }
+  return array_node->data.size();
+}
+
+size_t FindLeafVar(ArrayNode* all_vars, ArrayNode* leaf_vars, const IterVar& v) {
+  size_t pos = FindNodeRef(leaf_vars, v);
+  if (pos < leaf_vars->data.size()) return pos;
+
+  if (FindNodeRef(all_vars, v) < all_vars->data.size()) {
+    LOG(FATAL) << "Operate on iter var " << v
+               << " that has already been split";
+  } else {
+    LOG(FATAL) << "Operate on iter var " << v
+               << " that is not part of the schedule";
+  }
+  return 0;
+}
+
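+// Core split primitive shared by split() and split_by_nparts(): replace
+// `parent` in the leaf list with (outer, inner), conceptually
+// parent = outer * factor + inner, where factor is given directly or
+// derived from nparts, and record a SplitNode relation for bound inference.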
+void Split(StageNode* self,
+           IterVar parent,
+           PrimExpr factor,
+           PrimExpr nparts,
+           IterVar* p_outer,
+           IterVar* p_inner) {
+  // Check if split is valid.
+  CHECK(parent->iter_type == kDataPar ||
+        parent->iter_type == kCommReduce ||
+        parent->iter_type == kOrdered)
+      << "Cannot split on " << IterVarType2String(parent->iter_type);
+  IterVar outer = IterVarNode::make(
+      Range(), parent->var.copy_with_suffix(".outer"), parent->iter_type);
+  IterVar inner = IterVarNode::make(
+      Range(), parent->var.copy_with_suffix(".inner"), parent->iter_type);
+  *p_outer = outer;
+  *p_inner = inner;
+  // The splits
+  ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
+  ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
+  size_t pos = FindLeafVar(all_vars, leaf_vars, parent);
+  self->relations.push_back(SplitNode::make(parent, outer, inner, factor, nparts));
+  // add vars to all vars
+  all_vars->data.push_back(outer);
+  all_vars->data.push_back(inner);
+  // replace the position.
+  leaf_vars->data.erase(leaf_vars->data.begin() + pos);
+  leaf_vars->data.insert(leaf_vars->data.begin() + pos, inner);
+  leaf_vars->data.insert(leaf_vars->data.begin() + pos, outer);
+}
+
+Stage::Stage(Operation op) {
+  auto n = make_object<StageNode>();
+  n->op = op;
+  n->origin_op = op;
+  n->all_iter_vars = op->root_iter_vars();
+  // remove opaque var from leaf.
+  Array<IterVar> clean;
+  for (IterVar iv : n->all_iter_vars) {
+    if (iv->iter_type != kOpaque) clean.push_back(iv);
+  }
+  if (clean.size() == n->all_iter_vars.size()) {
+    n->leaf_iter_vars = n->all_iter_vars;
+  } else {
+    n->leaf_iter_vars = clean;
+  }
+  data_ = std::move(n);
+}
+
+bool Stage::is_scheduled() const {
+  const StageNode* n = operator->();
+  return !(n->relations.empty() && n->attach_type == kGroupRoot &&
+           n->all_iter_vars.same_as(n->leaf_iter_vars));
+}
+
+Stage Stage::GetAttachSpec() const {
+  Stage attach_spec = *this;
+  while (attach_spec->attach_type == kGroupRoot &&
+         attach_spec->group.defined()) {
+    attach_spec = attach_spec->group;
+  }
+  return attach_spec;
+}
+
+Stage& Stage::set_scope(std::string scope) {  // NOLINT(*)
+  (*this)->scope = scope;
+  return *this;
+}
+
+Stage& Stage::compute_at(Stage parent, IterVar scope) {   // NOLINT(*)
+  CHECK_NE((*this)->attach_type, kScanUpdate)
+      << "Cannot specify compute_at for scan updates";
+  // Group constraint checking.
+  Stage group = (*this)->group;
+  if (group.defined()) {
+    Stage pg = parent->group;
+    while (pg.defined() && !pg.same_as(group)) {
+      pg = pg->group;
+    }
+    CHECK(pg.same_as(group))
+        << "Can only assign compute_at to stages within the same group";
+  }
+
+  (*this)->attach_type = kScope;
+  (*this)->attach_ivar = scope;
+  (*this)->attach_stage = parent;
+  bool found = false;
+  for (size_t i = 0; i < parent->leaf_iter_vars.size(); ++i) {
+    if (scope == parent->leaf_iter_vars[i]) {
+      found = true; break;
+    }
+  }
+  CHECK(found)
+      << "Cannot find the axis " << scope
+      << " in parent's leaf_iter_vars"
+      << " parent=" << parent;
+  return *this;
+}
+
+Stage& Stage::compute_inline() {   // NOLINT(*)
+  CHECK_NE((*this)->attach_type, kScanUpdate)
+      << "Cannot specify compute_at for scan updates";
+  (*this)->attach_type = kInline;
+  return *this;
+}
+
+Stage& Stage::compute_root() {   // NOLINT(*)
+  CHECK_NE((*this)->attach_type, kScanUpdate)
+      << "Cannot specify compute_at for scan updates";
+  (*this)->attach_type = kGroupRoot;
+  return *this;
+}
+
+Stage& Stage::bind(IterVar ivar, IterVar thread_ivar) {   // NOLINT(*)
+  StageNode* self = operator->();
+  CHECK(ivar->iter_type == kDataPar ||
+        ivar->iter_type == kCommReduce)
+      << "Cannot bind " << IterVarType2String(ivar->iter_type) << " to thread";
+  CHECK(thread_ivar->iter_type == kThreadIndex)
+      << "Cannot bind to " << IterVarType2String(thread_ivar->iter_type)
+      << ", only thread axis is allowed so far";
+  ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
+  ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
+  FindLeafVar(all_vars, leaf_vars, ivar);
+
+  auto it = self->iter_var_attrs.find(ivar);
+  ObjectPtr<IterVarAttrNode> n;
+  if (it != self->iter_var_attrs.end()) {
+    n = make_object<IterVarAttrNode>(*(*it).second.operator->());
+    if (n->bind_thread.defined() &&
+        !n->bind_thread.same_as(thread_ivar)) {
+      LOG(WARNING) << "Axis " << ivar
+                   << " is already bound to another thread " << n->bind_thread;
+    }
+  } else {
+    n = make_object<IterVarAttrNode>();
+  }
+  n->bind_thread = thread_ivar;
+  self->iter_var_attrs.Set(ivar, IterVarAttr(n));
+  return *this;
+}
+
+Stage& Stage::env_threads(Array<IterVar> threads) {
+  StageNode* self = operator->();
+  CHECK(self->op.defined() && self->op.as<ScanOpNode>())
+      << "env_threads is only valid for composite ops such as ScanOp";
+  CHECK_EQ(self->env_threads.size(), 0U)
+      << "Already set env_threads";
+  ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
+  ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
+  std::vector<ObjectRef> temp;
+  for (IterVar iv : threads) {
+    temp.push_back(iv);
+  }
+  leaf_vars->data.insert(
+      leaf_vars->data.begin(), temp.begin(), temp.end());
+  all_vars->data.insert(
+      all_vars->data.end(), temp.begin(), temp.end());
+  self->env_threads = threads;
+  return *this;
+}
+
+Stage& Stage::set_store_predicate(PrimExpr predicate) {
+  StageNode* self = operator->();
+  self->store_predicate = predicate;
+  return *this;
+}
+
+Stage& Stage::split(
+    IterVar parent, PrimExpr factor, IterVar* p_outer, IterVar* p_inner) {  // NOLINT(*)
+  Split(operator->(), parent, factor, PrimExpr(), p_outer, p_inner);
+  return *this;
+}
+
+Stage& Stage::split_by_nparts(
+    IterVar parent, PrimExpr nparts, IterVar* p_outer, IterVar* p_inner) { // NOLINT(*)
+  Split(operator->(), parent, PrimExpr(), nparts, p_outer, p_inner);
+  return *this;
+}
+
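+// Fuse two leaf axes into one that iterates as
+// fused = outer * extent(inner) + inner; the two axes must be adjacent in
+// the current leaf order (the arguments are swapped if they appear in the
+// opposite order).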
+Stage& Stage::fuse(IterVar outer, IterVar inner, IterVar* p_target) {  // NOLINT(*)
+  StageNode* self = operator->();
+  CHECK(outer->iter_type == kDataPar ||
+        outer->iter_type == kCommReduce ||
+        outer->iter_type == kOrdered)
+      << "Cannot fuse " << IterVarType2String(outer->iter_type);
+  CHECK(inner->iter_type == kDataPar ||
+        inner->iter_type == kCommReduce ||
+        inner->iter_type == kOrdered)
+      << "Cannot fuse " << IterVarType2String(inner->iter_type);
+
+  IterVarType iter_type = outer->iter_type;
+  if (inner->iter_type > iter_type) iter_type = inner->iter_type;
+  std::string fused_name =
+      outer->var->name_hint + "." + inner->var->name_hint + ".fused";
+
+  IterVar fused = IterVarNode::make(
+      Range(), Var(fused_name, outer->var.dtype()), iter_type);
+
+  ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
+  ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
+
+  size_t pos_inner = FindLeafVar(all_vars, leaf_vars, inner);
+  size_t pos_outer = FindLeafVar(all_vars, leaf_vars, outer);
+  if (pos_inner + 1 == pos_outer) {
+    std::swap(outer, inner);
+    std::swap(pos_inner, pos_outer);
+  }
+  self->relations.push_back(FuseNode::make(outer, inner, fused));
+  all_vars->data.push_back(fused);
+  CHECK_EQ(pos_inner, pos_outer + 1)
+      << "Can only fuse iterations that are consecutive in the leaf iteration order";
+  leaf_vars->data.erase(leaf_vars->data.begin() + pos_outer,
+                        leaf_vars->data.begin() + pos_inner + 1);
+  leaf_vars->data.insert(leaf_vars->data.begin() + pos_outer,
+                         fused);
+  *p_target = fused;
+  return *this;
+}
+
+Stage& Stage::fuse(const Array<IterVar>& axes, IterVar* p_target) {  // NOLINT(*)
+  if (axes.size() != 0) {
+    IterVar fused = axes[0];
+    for (size_t i = 1; i < axes.size(); ++i) {
+      this->fuse(fused, axes[i], &fused);
+    }
+    *p_target = std::move(fused);
+  } else {
+    StageNode* self = operator->();
+    // Special case: fusing an empty array creates a singleton axis
+    // inserted at the outermost loop position.
+    IterVar singleton = IterVarNode::make(
+        Range::make_by_min_extent(0, 1),
+        Var("singleton", DataType::Int(32)), kDataPar);
+    self->relations.push_back(SingletonNode::make(singleton));
+    ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
+    ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
+    all_vars->data.push_back(singleton);
+    leaf_vars->data.insert(leaf_vars->data.begin(), singleton);
+    *p_target = singleton;
+  }
+  return *this;
+}
+
+Stage& Stage::reorder(const Array<IterVar>& order) {  // NOLINT(*)
+  std::unordered_set<IterVar> seen_var;
+  StageNode* self = operator->();
+  for (IterVar iv : order) {
+    CHECK(iv->iter_type == kDataPar ||
+          iv->iter_type == kCommReduce ||
+          iv->iter_type == kThreadIndex)
+        << "Cannot reorder IterVar("
+        << IterVarType2String(iv->iter_type) << ")";
+
+    CHECK_EQ(seen_var.count(iv), 0)
+        << "Same axis can not appear more than once " << iv;
+    seen_var.insert(iv);
+  }
+  ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
+  ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
+  std::vector<size_t> pos;
+
+  for (size_t i = 0; i < order.size(); ++i) {
+    pos.push_back(FindLeafVar(all_vars, leaf_vars, order[i]));
+  }
+  std::vector<ObjectRef> temp;
+  for (size_t i = 0; i < pos.size(); ++i) {
+    temp.emplace_back(leaf_vars->data[pos[i]]);
+  }
+  std::sort(pos.begin(), pos.end());
+  for (size_t i = 0; i < pos.size(); ++i) {
+    leaf_vars->data[pos[i]] = temp[i];
+  }
+  return *this;
+}
+
+Stage& Stage::tile(IterVar x_parent, IterVar y_parent,
+                   PrimExpr x_factor, PrimExpr y_factor,
+                   IterVar* p_x_outer, IterVar* p_y_outer,
+                   IterVar* p_x_inner, IterVar* p_y_inner) {
+  split(x_parent, x_factor, p_x_outer, p_x_inner);
+  split(y_parent, y_factor, p_y_outer, p_y_inner);
+  reorder(Array<IterVar>({*p_x_outer, *p_y_outer, *p_x_inner, *p_y_inner}));
+  return *this;
+}
+
+template<typename FUpdate>
+inline void UpdateIterVarAttr(StageNode* self,
+                              IterVar var,
+                              FUpdate fupdate,
+                              bool need_leaf = true) {
+  if (need_leaf) {
+    ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
+    ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
+    FindLeafVar(all_vars, leaf_vars, var);
+  }
+  auto it = self->iter_var_attrs.find(var);
+  ObjectPtr<IterVarAttrNode> n;
+  if (it != self->iter_var_attrs.end()) {
+    n = make_object<IterVarAttrNode>(*(*it).second.operator->());
+  } else {
+    n = make_object<IterVarAttrNode>();
+  }
+  fupdate(n.get());
+  self->iter_var_attrs.Set(var, IterVarAttr(n));
+}
+
+inline void SetAttrIterType(StageNode* self, IterVar var, IterVarType iter_type) {
+  UpdateIterVarAttr(self, var, [iter_type](IterVarAttrNode* n) {
+      n->iter_type = iter_type;
+    });
+}
+
+Stage& Stage::vectorize(IterVar var) {   // NOLINT(*)
+  CHECK(var->iter_type == kDataPar ||
+        var->iter_type == kOpaque ||
+        var->iter_type == kUnrolled ||
+        var->iter_type == kVectorized ||
+        var->iter_type == kTensorized ||
+        var->iter_type == kParallelized)
+      << "Cannot vectorize on " << IterVarType2String(var->iter_type);
+  SetAttrIterType(operator->(), var, kVectorized);
+  return *this;
+}
+
+Stage& Stage::tensorize(IterVar var, TensorIntrin f) {   // NOLINT(*)
+  UpdateIterVarAttr(operator->(), var, [f](IterVarAttrNode* n) {
+      n->iter_type = kTensorized;
+      n->tensor_intrin = f;
+    });
+  return *this;
+}
+
+Stage& Stage::unroll(IterVar var) {   // NOLINT(*)
+  SetAttrIterType(operator->(), var, kUnrolled);
+  return *this;
+}
+
+Stage& Stage::parallel(IterVar var) {   // NOLINT(*)
+  SetAttrIterType(operator->(), var, kParallelized);
+  return *this;
+}
+
+Stage& Stage::pragma(IterVar var,
+                     const std::string& pragma_type,
+                     const PrimExpr& pragma_value) {   // NOLINT(*)
+  if (pragma_type == "unroll") {
+    this->unroll(var);
+  } else if (pragma_type == "vectorize") {
+    this->vectorize(var);
+  } else {
+    UpdateIterVarAttr(
+        operator->(), var, [pragma_type, pragma_value](IterVarAttrNode* n) {
+          n->pragma_keys.push_back(ir::StringImmNode::make(pragma_type));
+          n->pragma_values.push_back(pragma_value);
+        });
+  }
+  return *this;
+}
+
+Stage& Stage::prefetch(const Tensor &tensor, IterVar var, PrimExpr offset) {
+  StageNode *self = operator->();
+  ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
+  ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
+  FindLeafVar(all_vars, leaf_vars, var);
+  auto it = self->iter_var_attrs.find(var);
+  ObjectPtr<IterVarAttrNode> n;
+  if (it != self->iter_var_attrs.end()) {
+    n = make_object<IterVarAttrNode>(*(*it).second.operator->());
+  } else {
+    n = make_object<IterVarAttrNode>();
+  }
+  n->prefetch_data.push_back(tensor);
+  n->prefetch_offset.push_back(offset);
+  self->iter_var_attrs.Set(var, IterVarAttr(n));
+  return *this;
+}
+
+Stage& Stage::storage_align(IterVar axis, int factor, int offset) {
+  StageNode *self = operator->();
+  UpdateIterVarAttr(self, axis, [factor, offset](IterVarAttrNode* n) {
+      n->dim_align_factor = factor;
+      n->dim_align_offset = offset;
+    }, false);
+  return *this;
+}
+
+Stage& Stage::double_buffer() {
+  StageNode *self = operator->();
+  CHECK(!self->is_output) << "Cannot apply double buffer on output";
+  self->double_buffer = true;
+  return *this;
+}
+
+Stage& Stage::opengl() {
+  CHECK(!is_scheduled()) << "Must be a fresh schedule";
+  StageNode *self = operator->();
+
+  auto all_iter_vars = self->all_iter_vars;  // snapshot of all_iter_vars
+  CHECK(!all_iter_vars.empty()) << "At least one iter var";
+
+  // Fuse all data parallel dimensions to 1.
+  IterVar fused = all_iter_vars[0];
+  for (size_t i = 1; i != all_iter_vars.size(); ++i) {
+    auto iter_var = all_iter_vars[i];
+    switch (iter_var->iter_type) {
+      case IterVarType::kDataPar: {
+        fuse(fused, all_iter_vars[i], &fused);
+        break;
+      }
+      case IterVarType::kThreadIndex: {
+        LOG(ERROR) << "A fresh schedule shouldn't have thread index iter var";
+        break;
+      }
+      case IterVarType::kCommReduce:
+      case IterVarType::kOrdered:
+      case IterVarType::kOpaque: {
+        break;
+      }
+      default: {
+        LOG(ERROR) << "Invalid iter var type "
+                   << IterVarType2String(iter_var->iter_type);
+        break;
+      }
+    }
+  }
+
+  // Bind the only dimension to threadIdx.x.
+  bind(fused, thread_axis(Range(nullptr), "threadIdx.x"));
+
+  // Mark this stage as OpenGL.
+  (*this)->is_opengl = true;
+
+  return *this;
+}
+
+Stage CopyStage(const Stage& s) {
+  ObjectPtr<StageNode> n =
+      make_object<StageNode>(*s.operator->());
+  return Stage(n);
+}
+
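+// Deep-copy the schedule: clone every stage and group, then remap the
+// stage_map, attach_stage and group references onto the fresh copies.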
+Schedule Schedule::copy() const {
+  // map of stages.
+  const ScheduleNode* self = operator->();
+  std::unordered_map<Stage, Stage, ObjectHash, ObjectEqual> smap;
+  ObjectPtr<ScheduleNode> n = make_object<ScheduleNode>();
+  n->outputs = self->outputs;
+  // Copy the stages.
+  for (Stage s : self->stages) {
+    Stage scopy = CopyStage(s);
+    smap[s] = scopy;
+    n->stages.push_back(scopy);
+  }
+  for (Stage g : self->groups) {
+    Stage gcopy = CopyStage(g);
+    smap[g] = gcopy;
+    n->groups.push_back(gcopy);
+  }
+  // Remaps the reference relations.
+  for (auto kv : self->stage_map) {
+    n->stage_map.Set(kv.first, smap.at(kv.second));
+  }
+  for (Stage s : n->stages) {
+    if (s->attach_stage.defined()) {
+      CHECK(smap.find(s->attach_stage) != smap.end())
+        << s->attach_stage << " not found in " << (*this);
+      s->attach_stage = smap.at(s->attach_stage);
+    }
+    if (s->group.defined()) {
+      CHECK(smap.find(s->group) != smap.end())
+        << s->group << " not found in " << (*this);
+      s->group = smap.at(s->group);
+    }
+  }
+  for (Stage s : n->groups) {
+    if (s->attach_stage.defined()) {
+      CHECK(smap.find(s->attach_stage) != smap.end())
+        << s->attach_stage << " not found in " << (*this);
+      s->attach_stage = smap.at(s->attach_stage);
+    }
+    if (s->group.defined()) {
+      CHECK(smap.find(s->group) != smap.end())
+        << s->group << " not found in " << (*this);
+      s->group = smap.at(s->group);
+    }
+  }
+  return Schedule(n);
+}
+
+Stage Schedule::operator[](const Operation& op) {
+  auto it = (*this)->stage_map.find(op);
+  CHECK(it != (*this)->stage_map.end())
+      << "Cannot find Stage for operator " << op
+      << " in the schedule";
+  return (*it).second;
+}
+
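+// Returns whichever of g1/g2 is an ancestor of (or equal to) the other in
+// the group tree; otherwise an undefined Stage, i.e. the implicit root.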
+Stage LeastCommonAncestor(Stage g1, Stage g2) {
+  if (!g1.defined()) return g1;
+  if (!g2.defined()) return g2;
+  if (g1.same_as(g2)) return g1;
+  Stage g = g1;
+  while (g.defined()) {
+    if (g.same_as(g2)) return g2;
+    g = g->group;
+  }
+  g = g2;
+  while (g.defined()) {
+    if (g.same_as(g1)) return g1;
+    g = g->group;
+  }
+  return g;
+}
+
+Array<Tensor> RemapTensor(ScheduleNode* self,
+                          const Array<Tensor>& arr) {
+  self->InitCache();
+  const auto& op2stage_cache = self->op2stage_cache_;
+  Array<Tensor> ret;
+  for (Tensor t : arr) {
+    if (!op2stage_cache.count(t->op.get())) {
+      CHECK(self->stage_map.count(t->op))
+          << "Given tensor is not in the schedule plan";
+      t = self->stage_map[t->op]->op.output(t->value_index);
+    }
+    ret.push_back(t);
+  }
+  return ret;
+}
+
+// Group the schedule stages.
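+// create_group builds a new group stage over the sub-graph between inputs
+// and outputs, reparents the member stages under it, and forces any
+// compute_at relation that would escape the group back to compute_root
+// (with a warning).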
+Stage Schedule::create_group(const Array<Tensor>& outputs,
+                             const Array<Tensor>& inputs,
+                             bool include_inputs) {
+  ScheduleNode* self = operator->();
+  self->InitCache();
+  const auto& op2stage_cache = self->op2stage_cache_;
+  // Get the ops.
+  Array<Operation> ops = top::GetSubGraph(
+      RemapTensor(self, outputs),
+      RemapTensor(self, inputs),
+      include_inputs);
+  // local counter entry
+  // Automatically initialize to 0 during creation.
+  struct Entry {
+    int count{0};
+  };
+  // Map of group->touched counter
+  std::unordered_map<Stage, Entry, ObjectHash, ObjectEqual> counter;
+  // The parent group.
+  Stage parent_group;
+  // Detect common parent and child.
+  for (size_t i = 0; i < ops.size(); ++i) {
+    Operation op = ops[i];
+    auto it = op2stage_cache.find(op.get());
+    CHECK(it != op2stage_cache.end());
+    Stage op_group = it->second->group;
+    if (i == 0) {
+      parent_group = op_group;
+    } else {
+      parent_group = LeastCommonAncestor(parent_group, op_group);
+    }
+    if (op_group.defined()) {
+      ++counter[op_group].count;
+    }
+  }
+  // Create the new group stage.
+  Stage gstage(make_object<StageNode>());
+  gstage->group = parent_group;
+  if (parent_group.defined()) {
+    ++parent_group->num_child_stages;
+  }
+  // Propagate the counter statistics upward by checking
+  // whether each subgroup is full.
+  std::vector<Stage> stack;
+  for (auto &kv : counter) {
+    if (!kv.first.same_as(parent_group)) {
+      if (kv.first->num_child_stages == kv.second.count) {
+        stack.push_back(kv.first);
+      }
+    }
+  }
+  while (!stack.empty()) {
+    Stage g = stack.back();
+    stack.pop_back();
+    if (g->group.defined() && !g->group.same_as(parent_group)) {
+      Entry& e = counter[g->group];
+      ++e.count;
+      if (e.count == g->group->num_child_stages) {
+        stack.push_back(g->group);
+      }
+    }
+  }
+  // Verify and remap the subgroups.
+  for (auto &kv : counter) {
+    if (kv.first.same_as(parent_group)) continue;
+    CHECK_EQ(kv.first->num_child_stages, kv.second.count)
+        << "Trying to group a region that intersects with an existing group";
+    if (kv.first->group.same_as(parent_group)) {
+      Stage s = kv.first;
+      s->group = gstage;
+      ++gstage->num_child_stages;
+      if (parent_group.defined()) {
+        --parent_group->num_child_stages;
+      }
+    }
+  }
+  // Remap the group of op stages.
+  for (Operation op : ops) {
+    auto it = op2stage_cache.find(op.get());
+    CHECK(it != op2stage_cache.end());
+    Stage s = it->second;
+    if (s->group.same_as(parent_group)) {
+      s->group = gstage;
+      ++gstage->num_child_stages;
+      if (parent_group.defined()) {
+        --parent_group->num_child_stages;
+      }
+    }
+  }
+  // Correct the attach to keep everything in group.
+  for (Operation op : ops) {
+    auto it = op2stage_cache.find(op.get());
+    CHECK(it != op2stage_cache.end());
+    Stage s = it->second;
+    if (s->attach_type == kScope) {
+      Stage cg = LeastCommonAncestor(s->attach_stage->group, gstage);
+      if (!cg.same_as(gstage)) {
+        LOG(WARNING) << "group invalidates some previous compute_at relations "
+                     << "and forces the stage to be computed inside the group";
+        s.compute_root();
+      }
+    }
+  }
+
+  self->groups.push_back(gstage);
+  return gstage;
+}
+
+void ScheduleNode::InvalidateCache() {
+  op2stage_cache_.clear();
+}
+
+void ScheduleNode::InitCache() {
+  if (op2stage_cache_.size() == stages.size()) return;
+  InvalidateCache();
+  for (Stage s : stages) {
+    if (s->op.defined()) {
+      op2stage_cache_[s->op.get()] = s;
+    }
+  }
+  CHECK_EQ(op2stage_cache_.size(), stages.size());
+}
+
+bool ScheduleNode::Contain(const Operation& op) const {
+  return stage_map.find(op) != stage_map.end();
+}
+
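+// Build a schedule: walk ops in post-DFS order from the outputs, create one
+// stage per op, mark output stages, and wrap each scan's updates into a
+// dedicated group attached as kScanUpdate.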
+Schedule ScheduleNode::make(Array<Operation> ops) {
+  auto n = make_object<ScheduleNode>();
+  Schedule sch(n);
+  n->outputs = ops;
+  auto g = top::CreateReadGraph(n->outputs);
+  Array<Operation> post_order = top::PostDFSOrder(n->outputs, g);
+  // output set.
+  std::unordered_set<Operation> output_set;
+  for (Operation x : ops) {
+    output_set.insert(x);
+  }
+  for (Operation op : post_order) {
+    Stage stage(op);
+    stage->is_output = output_set.count(op) != 0;
+    n->stages.push_back(stage);
+    n->stage_map.Set(op, stage);
+    // mark scan updates.
+    if (const ScanOpNode* scan = op.as<ScanOpNode>()) {
+      Array<Tensor> inputs;
+      for (Tensor t : scan->state_placeholder) {
+        inputs.push_back(t);
+      }
+      for (Tensor t : scan->inputs) {
+        inputs.push_back(t);
+      }
+      // Create the scan group.
+      Stage scan_group = sch.create_group(scan->update, inputs, false);
+      scan_group->attach_type = kScanUpdate;
+      scan_group->attach_stage = stage;
+
+      for (size_t i = 0; i < scan->update.size(); ++i) {
+        Stage s = n->stage_map[scan->update[i]->op];
+        CHECK(scan_group.same_as(s->group));
+      }
+    }
+  }
+  return sch;
+}
+
+IterVarRelation SplitNode::make(IterVar parent,
+                                IterVar outer,
+                                IterVar inner,
+                                PrimExpr factor,
+                                PrimExpr nparts) {
+  auto n = make_object<SplitNode>();
+  n->parent = parent;
+  n->outer = outer;
+  n->inner = inner;
+  n->factor = factor;
+  n->nparts = nparts;
+  return IterVarRelation(n);
+}
+
+IterVarRelation FuseNode::make(
+    IterVar outer, IterVar inner, IterVar fused) {
+  auto n = make_object<FuseNode>();
+  n->outer = outer;
+  n->inner = inner;
+  n->fused = fused;
+  return IterVarRelation(n);
+}
+
+IterVarRelation RebaseNode::make(IterVar parent, IterVar rebased) {
+  auto n = make_object<RebaseNode>();
+  n->parent = parent;
+  n->rebased = rebased;
+  return IterVarRelation(n);
+}
+
+IterVarRelation SingletonNode::make(IterVar iter) {
+  auto n = make_object<SingletonNode>();
+  n->iter = iter;
+  return IterVarRelation(n);
+}
+
+TVM_REGISTER_NODE_TYPE(StageNode);
+TVM_REGISTER_NODE_TYPE(IterVarAttrNode);
+TVM_REGISTER_NODE_TYPE(SplitNode);
+TVM_REGISTER_NODE_TYPE(FuseNode);
+TVM_REGISTER_NODE_TYPE(RebaseNode);
+TVM_REGISTER_NODE_TYPE(SingletonNode);
+TVM_REGISTER_NODE_TYPE(ScheduleNode);
+
+// Printer
+TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
+.set_dispatch<StageNode>([](const ObjectRef& node, NodePrinter* p) {
+    auto* op = static_cast<const StageNode*>(node.get());
+    if (op->op.defined()) {
+      p->stream << "stage(" << op->origin_op->name << ", " << op << ")";
+    } else {
+      p->stream << "group-stage(" << op << ")";
+    }
+})
+.set_dispatch<IterVarAttrNode>([](const ObjectRef& node, NodePrinter* p) {
+    auto* op = static_cast<const IterVarAttrNode*>(node.get());
+    p->stream << IterVarType2String(op->iter_type);
+})
+.set_dispatch<SplitNode>([](const ObjectRef& node, NodePrinter* p) {
+    auto* op = static_cast<const SplitNode*>(node.get());
+    p->stream << "split(parent=";
+    p->Print(op->parent);
+    p->stream << ", outer=";
+    p->Print(op->outer);
+    p->stream << ", inner=";
+    p->Print(op->inner);
+    p->stream << ')';
+})
+.set_dispatch<FuseNode>([](const ObjectRef& node, NodePrinter* p) {
+    auto* op = static_cast<const FuseNode*>(node.get());
+    p->stream << "fuse(";
+    p->stream << "outer=";
+    p->Print(op->outer);
+    p->stream << ", inner=";
+    p->Print(op->inner);
+    p->stream << ", fused=";
+    p->Print(op->fused);
+    p->stream << ')';
+})
+.set_dispatch<RebaseNode>([](const ObjectRef& node, NodePrinter* p) {
+    auto* op = static_cast<const RebaseNode*>(node.get());
+    p->stream << "rebase(";
+    p->stream << "parent=";
+    p->Print(op->parent);
+    p->stream << ", rebased=";
+    p->Print(op->rebased);
+    p->stream << ')';
+})
+.set_dispatch<SingletonNode>([](const ObjectRef& node, NodePrinter* p) {
+    auto* op = static_cast<const SingletonNode*>(node.get());
+    p->stream << "singleton(";
+    p->Print(op->iter);
+    p->stream << ')';
+})
+.set_dispatch<ScheduleNode>([](const ObjectRef& node, NodePrinter* p) {
+    auto* op = static_cast<const ScheduleNode*>(node.get());
+    p->stream << "schedule(" << op << ")";
+  });
+}  // namespace top
+}  // namespace tvm
diff --git a/src/top/schedule/schedule_ops.cc b/src/top/schedule/schedule_ops.cc
new file mode 100644 (file)
index 0000000..1176d82
--- /dev/null
@@ -0,0 +1,427 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file schedule_ops.cc
+ */
+#include <tvm/ir.h>
+#include <tvm/ir_pass.h>
+#include <tvm/ir_functor_ext.h>
+#include <tvm/top/operation.h>
+#include <tvm/top/schedule_pass.h>
+#include <utility>
+#include <unordered_map>
+#include <unordered_set>
+#include "graph.h"
+#include "../operation/op_util.h"
+#include "../../pass/ir_util.h"
+
+namespace tvm {
+namespace top {
+
+using namespace ir;
+
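+// Build the producer/consumer pipeline for one stage: emit the provide
+// body, wrap double-buffer/realize scopes around it, and sequence it
+// before the consumer statement when one exists.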
+Stmt MakePipeline(const Stage& s,
+                  const std::unordered_map<IterVar, Range>& dom_map,
+                  Stmt consumer,
+                  bool debug_keep_trivial_loop) {
+  Stmt producer = s->op->BuildProvide(s, dom_map, debug_keep_trivial_loop);
+  if (producer.defined()) {
+    producer = ProducerConsumerNode::make(s->op, true, producer);
+  }
+  if (s->double_buffer) {
+    producer = AttrStmtNode::make(
+        s->op, ir::attr::double_buffer_scope, 1, producer);
+  }
+  Stmt pipeline = producer;
+
+  if (consumer.defined() && !is_no_op(consumer)) {
+    consumer = ProducerConsumerNode::make(s->op, false, consumer);
+    pipeline = SeqStmt({producer, consumer});
+  }
+  pipeline = s->op->BuildRealize(s, dom_map, pipeline);
+  // use attribute to mark scope of the operation.
+  pipeline = AttrStmtNode::make(
+      s->op, ir::attr::realize_scope,
+      StringImmNode::make(s->scope),
+      pipeline);
+
+  if (s->is_opengl) {
+    pipeline = AttrStmtNode::make(
+        s->op, ir::attr::opengl_stage_scope, StringImmNode::make(""), pipeline);
+  }
+  return pipeline;
+}
+
+// inject the operator's realization on the stmt.
+class InjectAttach : public StmtMutator {
+ public:
+  InjectAttach(const Stage& stage,
+               const Stage& attach_spec,
+               const std::unordered_map<IterVar, Range>& dom_map,
+               bool debug_keep_trivial_loop)
+      : stage_(stage), attach_spec_(attach_spec), dom_map_(dom_map),
+        debug_keep_trivial_loop_(debug_keep_trivial_loop) {}
+
+  Stmt VisitStmt(const Stmt& input_stmt) final {
+    CHECK(input_stmt.defined());
+    auto stmt = StmtMutator::VisitStmt(input_stmt);
+    const AttrStmtNode* op = stmt.as<AttrStmtNode>();
+    if (op != nullptr &&
+        op->attr_key == attr::loop_scope) {
+      if (attach_spec_->attach_type == kScope &&
+          op->node == attach_spec_->attach_ivar) {
+        CHECK(!found_attach)
+            << "Found IterVar " << attach_spec_->attach_ivar
+            << " in multiple places in the IR";
+        found_attach = true;
+        stmt = AttrStmtNode::make(
+            op->node, op->attr_key, op->value,
+            MakePipeline(stage_, dom_map_, op->body, debug_keep_trivial_loop_));
+      }
+    }
+    return stmt;
+  }
+  // whether attach point is found
+  bool found_attach{false};
+
+ private:
+  // The stage.
+  const Stage& stage_;
+  // The attach spec, may not contain op.
+  const Stage& attach_spec_;
+  // domain map
+  const std::unordered_map<IterVar, Range>& dom_map_;
+  // Whether to keep trivial loops with extent 1 during lowering.
+  // This is a debug feature for dataflow/axis analysis.
+  bool debug_keep_trivial_loop_;
+};
+
+// inject the operator's realization on the stmt.
+class InjectScanStep : public StmtMutator {
+ public:
+  InjectScanStep(const Stage& stage,
+                 const Operation& scan_op,
+                 const std::unordered_map<IterVar, Range>& dom_map,
+                 bool is_init,
+                 bool debug_keep_trivial_loop)
+      : stage_(stage), scan_op_(scan_op),
+        dom_map_(dom_map), is_init_(is_init), debug_keep_trivial_loop_(debug_keep_trivial_loop) {}
+
+  Stmt VisitStmt(const Stmt& input_stmt) final {
+    CHECK(input_stmt.defined());
+    auto stmt = StmtMutator::VisitStmt(input_stmt);
+    // update
+    const AttrStmtNode* op = stmt.as<AttrStmtNode>();
+    if (op != nullptr &&
+        ((op->attr_key == attr::scan_update_scope && !is_init_) ||
+         (op->attr_key == attr::scan_init_scope && is_init_))) {
+      if (op->node.same_as(scan_op_)) {
+        found_attach = true;
+        stmt = AttrStmtNode::make(
+            op->node, op->attr_key, op->value,
+            MakePipeline(stage_, dom_map_, op->body, debug_keep_trivial_loop_));
+      }
+    }
+    return stmt;
+  }
+
+  // whether attach point is found
+  bool found_attach{false};
+
+ private:
+  // The stage to inject and the scan op it attaches to.
+  const Stage& stage_;
+  const Operation& scan_op_;
+  // domain map
+  const std::unordered_map<IterVar, Range>& dom_map_;
+  // whether it is init.
+  bool is_init_;
+  // Whether to keep trivial loops with extent 1 during lowering.
+  // This is a debug feature for dataflow/axis analysis.
+  bool debug_keep_trivial_loop_;
+};
+
+// Post-processing of the scheduled ops:
+// replace the scan init/update expressions by the scan's buffer,
+// forward let-bound values, and deduplicate thread-extent attributes.
+class SchedulePostProc : public StmtExprMutator {
+ public:
+  Stmt VisitStmt_(const ProducerConsumerNode* op) final {
+    auto it = replace_op_.find(op->func.get());
+    if (it != replace_op_.end()) {
+      Stmt body = this->VisitStmt(op->body);
+      if (it->second.defined()) {
+        return ProducerConsumerNode::make(
+            it->second, op->is_producer, body);
+      } else {
+        return body;
+      }
+    } else {
+      return StmtExprMutator::VisitStmt_(op);
+    }
+  }
+  Stmt VisitStmt_(const LetStmtNode* op) final {
+    if (!HasSideEffect(op->value)) {
+      var_value_[op->var.get()] = this->VisitExpr(op->value);
+      return this->VisitStmt(op->body);
+    } else {
+      return StmtExprMutator::VisitStmt_(op);
+    }
+  }
+
+  Stmt VisitStmt_(const AttrStmtNode* op) final {
+    if (op->attr_key == attr::loop_scope ||
+        op->attr_key == attr::scan_init_scope) {
+      return this->VisitStmt(op->body);
+    } else if (op->attr_key == attr::scan_update_scope) {
+      const ScanOpNode* scan = op->node.as<ScanOpNode>();
+      CHECK(scan);
+      var_value_[scan->scan_axis->var.get()] = op->value;
+      return this->VisitStmt(op->body);
+    } else if (op->attr_key == attr::thread_extent) {
+      // delete duplicated thread extent attr
+      auto it = thread_extent_scope_.find(op->node.get());
+      if (it != thread_extent_scope_.end()) {
+        CHECK(is_zero(ir::Simplify(it->second - op->value)));
+        return this->VisitStmt(op->body);
+      } else {
+        thread_extent_scope_[op->node.get()] = op->value;
+        Stmt ret = StmtExprMutator::VisitStmt_(op);
+        thread_extent_scope_.erase(op->node.get());
+        return ret;
+      }
+    } else if (op->attr_key == ir::attr::realize_scope ||
+               op->attr_key == ir::attr::double_buffer_scope) {
+      auto it = replace_op_.find(op->node.get());
+      if (it != replace_op_.end()) {
+        if (it->second.defined()) {
+          Stmt ret = AttrStmtNode::make(
+              it->second, op->attr_key, op->value, op->body);
+          return this->VisitStmt(ret);
+        } else {
+          return this->VisitStmt(op->body);
+        }
+      }
+    } else if (op->attr_key == ir::attr::buffer_bind_scope) {
+      Array<ObjectRef> tuple = Downcast<Array<ObjectRef> >(op->node);
+      Tensor tensor = Downcast<Tensor>(tuple[1]);
+      auto it = replace_op_.find(tensor->op.get());
+      if (it != replace_op_.end()) {
+        if (it->second.defined()) {
+          return AttrStmtNode::make(
+              Array<ObjectRef>{tuple[0], it->second.output(tensor->value_index)},
+              op->attr_key, op->value, this->VisitStmt(op->body));
+        } else {
+          return this->VisitStmt(op->body);
+        }
+      }
+    } else if (op->attr_key == ir::attr::buffer_dim_align) {
+      Tensor tensor = Downcast<Tensor>(op->node);
+      auto it = replace_op_.find(tensor->op.get());
+      if (it != replace_op_.end()) {
+        if (it->second.defined()) {
+          return AttrStmtNode::make(
+              it->second.output(tensor->value_index),
+              op->attr_key, op->value, this->VisitStmt(op->body));
+        } else {
+          return this->VisitStmt(op->body);
+        }
+      }
+    }
+    return StmtExprMutator::VisitStmt_(op);
+  }
+
+  Stmt VisitStmt_(const RealizeNode* op) final {
+    TensorKey key{op->func, op->value_index};
+    auto it = replace_realize_.find(key);
+    if (it != replace_realize_.end()) {
+      if (it->second.defined()) {
+        Stmt ret = RealizeNode::make(
+            it->second->op, it->second->value_index,
+            op->dtype, op->bounds, op->condition, op->body);
+        return this->VisitStmt(ret);
+      } else {
+        return this->VisitStmt(op->body);
+      }
+    } else {
+      return StmtExprMutator::VisitStmt_(op);
+    }
+  }
+
+  Stmt VisitStmt_(const ProvideNode* op) final {
+    TensorKey key{op->func, op->value_index};
+    auto it = replace_buffer_.find(key);
+    if (it != replace_buffer_.end()) {
+      const Tensor& dst = it->second;
+      Stmt ret = ProvideNode::make(
+          dst->op, dst->value_index, op->value, op->args);
+      return this->VisitStmt(ret);
+    } else {
+      return StmtExprMutator::VisitStmt_(op);
+    }
+  }
+
+  PrimExpr VisitExpr_(const CallNode* op) final {
+    if (op->call_type == CallNode::Halide) {
+      TensorKey key{op->func, op->value_index};
+      auto it = replace_buffer_.find(key);
+      if (it != replace_buffer_.end()) {
+        const Tensor& dst = it->second;
+        PrimExpr ret = CallNode::make(
+            op->dtype, dst->op->name, op->args,
+            op->call_type, dst->op, dst->value_index);
+        return this->VisitExpr(ret);
+      }
+    }
+    return StmtExprMutator::VisitExpr_(op);
+  }
+
+  PrimExpr VisitExpr_(const VarNode* op) final {
+    auto it = var_value_.find(op);
+    if (it != var_value_.end()) {
+      return it->second;
+    } else {
+      return GetRef<PrimExpr>(op);
+    }
+  }
+
+  void Init(const Schedule& sch) {
+    for (Stage s : sch->stages) {
+      for (auto kv : s->iter_var_attrs) {
+        // Update bind thread information.
+        if (kv.second->bind_thread.defined()) {
+          const Var& from = kv.first->var;
+          const Var& to = kv.second->bind_thread->var;
+          CHECK(!var_value_.count(from.get()));
+          var_value_[from.get()] = to;
+        }
+      }
+      // This must be checked for all ops, including scan.
+      if (!s->op.same_as(s->origin_op)) {
+        for (int i = 0; i < s->op->num_outputs(); ++i) {
+          Tensor target = s->origin_op.output(i);
+          AddReplace(s->op.output(i), target,
+                     target, s->origin_op);
+        }
+      }
+      // Specially add replacements for scan op.
+      if (const ScanOpNode* scan = s->op.as<ScanOpNode>()) {
+        for (size_t i = 0; i < scan->update.size(); ++i) {
+          Tensor t = s->origin_op.output(i);
+          AddReplace(scan->init[i], t);
+          AddReplace(scan->update[i], t);
+          AddReplace(scan->state_placeholder[i], t);
+        }
+      }
+    }
+  }
+
+ private:
+  void AddReplace(Tensor src,
+                  Tensor dst,
+                  Tensor repl_realize = Tensor(),
+                  Operation repl_op = Operation()) {
+    TensorKey key{src->op, src->value_index};
+    replace_buffer_[key] = dst;
+    replace_realize_[key] = repl_realize;
+    replace_op_[src->op.get()] = repl_op;
+  }
+  // The thread extent scope.
+  std::unordered_map<const Object*, PrimExpr> thread_extent_scope_;
+  // Variable replacement values (e.g. bound threads, scan axes).
+  std::unordered_map<const VarNode*, PrimExpr> var_value_;
+  // Buffer replacement map.
+  std::unordered_map<TensorKey, Tensor> replace_buffer_;
+  // Buffer realizations to be replaced.
+  std::unordered_map<TensorKey, Tensor> replace_realize_;
+  // Producer/consumer op replacements.
+  std::unordered_map<const Object*, Operation> replace_op_;
+};
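The post-processor follows a two-phase protocol: Init walks the schedule to
collect variable, buffer, and op replacements, and the subsequent call applies
them to the statement, exactly as done at the end of ScheduleOps:

    SchedulePostProc post_proc;
    post_proc.Init(sch);               // gather replacement tables
    body = post_proc(std::move(body)); // rewrite the lowered stmt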
+
+Stmt ScheduleOps(
+    Schedule sch, Map<IterVar, Range> dom_map_, bool debug_keep_trivial_loop) {
+  Stmt body = Stmt();
+  std::unordered_map<IterVar, Range> dom_map = as_unordered_map(dom_map_);
+  // Map each scan init tensor's op to its owning scan op.
+  std::unordered_map<Operation, Operation> scan_init;
+  for (Stage s : sch->stages) {
+    const ScanOpNode* scan = s->op.as<ScanOpNode>();
+    if (!scan) continue;
+    for (Tensor t : scan->init) {
+      if (scan_init.count(t->op)) {
+        CHECK(scan_init.at(t->op).same_as(s->op))
+            << "Scan init tensor can only belong to one scan";
+      } else {
+        scan_init[t->op] = s->op;
+      }
+    }
+  }
+  // verify correctness of group.
+  for (Stage g : sch->groups) {
+    CHECK(!g->op.defined());
+    CHECK_EQ(g->leaf_iter_vars.size(), 0U);
+  }
+  // Visit stages in reverse post-DFS order.
+  for (size_t i = sch->stages.size(); i != 0; --i) {
+    Stage s = sch->stages[i - 1];
+    CHECK_NE(s->attach_type, kInline)
+        << "call schedule.normalize before scheduleops";
+    CHECK(s->op.defined());
+    // No need to schedule placeholder ops.
+    if (s->op.as<PlaceholderOpNode>()) continue;
+    // Remove grouping sugar, get the real attach spec.
+    Stage attach_spec = s.GetAttachSpec();
+
+    if (scan_init.count(s->op)) {
+      CHECK(body.defined());
+      InjectScanStep mu(s, scan_init.at(s->op), dom_map, true, debug_keep_trivial_loop);
+      body = mu(std::move(body));
+      CHECK(mu.found_attach)
+          << "did not find attachment point for scan.init";
+    } else if (attach_spec->attach_type == kScanUpdate) {
+      // Handle scan update
+      CHECK(body.defined());
+      InjectScanStep mu(s, attach_spec->attach_stage->op, dom_map, false, debug_keep_trivial_loop);
+      body = mu(std::move(body));
+      CHECK(mu.found_attach)
+          << "did not find attachment point for scan.update";
+    } else if (attach_spec->attach_type == kInlinedAlready) {
+      // do nothing
+    } else if (attach_spec->attach_type == kGroupRoot) {
+      CHECK(!s->group.defined());
+      body = MakePipeline(s, dom_map, body, debug_keep_trivial_loop);
+    } else {
+      CHECK_EQ(attach_spec->attach_type, kScope);
+      CHECK(body.defined());
+      InjectAttach mutator(s, attach_spec, dom_map, debug_keep_trivial_loop);
+      body = mutator(std::move(body));
+      CHECK(mutator.found_attach)
+          << "did not find attachment point for " << s << " in "
+          << attach_spec->attach_stage->op  << " x " << attach_spec->attach_ivar
+          << ", body:\n"
+          << body;
+    }
+  }
+  SchedulePostProc post_proc;
+  post_proc.Init(sch);
+  return post_proc(std::move(body));
+}
+
+}  // namespace top
+}  // namespace tvm
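A sketch of how a lowering pipeline would drive this entry point; it assumes
InferBound from tvm/top/schedule_pass.h (moved there by this PR) to produce the
iteration-domain map:

    sch = sch.normalize();                        // required before ScheduleOps
    Map<IterVar, Range> bounds = InferBound(sch);
    Stmt stmt = ScheduleOps(sch, bounds, /*debug_keep_trivial_loop=*/false);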
diff --git a/src/top/tensor.cc b/src/top/tensor.cc
new file mode 100644 (file)
index 0000000..c8e3aea
--- /dev/null
@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tensor.cc
+ */
+#include <tvm/top/tensor.h>
+#include <tvm/top/operation.h>
+#include <tvm/top/tensor_intrin.h>
+#include <memory>
+
+namespace tvm {
+namespace top {
+// Tensor
+PrimExpr Tensor::operator()(Array<Var> indices) const {
+  Array<PrimExpr> arr(indices.begin(), indices.end());
+  return operator()(arr);
+}
+
+PrimExpr Tensor::operator()(Array<PrimExpr> indices) const {
+  using ir::CallNode;
+  if (ndim() != 0) {
+    CHECK_EQ(ndim(), indices.size())
+        << "Tensor dimension mismatch in read"
+        << "ndim = " << ndim() << ", indices.size=" << indices.size();
+  }
+  auto n = CallNode::make(
+      (*this)->dtype, (*this)->op->name, indices, CallNode::Halide,
+      (*this)->op, (*this)->value_index);
+  return n;
+}
+
+Tensor Operation::output(size_t i) const {
+  auto node = make_object<TensorNode>();
+  node->op = *this;
+  node->value_index = i;
+  node->dtype = (*this)->output_dtype(i);
+  node->shape = (*this)->output_shape(i);
+  return Tensor(node);
+}
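Together with Tensor::operator() above, this provides the usual read syntax; a
small illustrative sketch (op, i, and j are assumed to be in scope):

    top::Tensor t = op.output(0);  // first output tensor of an Operation
    PrimExpr load = t(i, j);       // lowers to a CallNode of type Halide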
+
+Tensor TensorNode::make(Array<PrimExpr> shape,
+                        DataType dtype,
+                        Operation op,
+                        int value_index) {
+  auto n = make_object<TensorNode>();
+  n->shape = std::move(shape);
+  n->dtype = dtype;
+  n->op = op;
+  n->value_index = value_index;
+  return Tensor(n);
+}
+
+TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
+.set_dispatch<TensorNode>([](const ObjectRef& node, NodePrinter* p) {
+    auto* t = static_cast<const TensorNode*>(node.get());
+    p->stream << "Tensor(shape=" << t->shape
+              << ", op.name=" << t->op->name << ')';
+  });
+
+TVM_REGISTER_NODE_TYPE(TensorNode);
+
+
+// TensorIntrin
+
+TensorIntrin TensorIntrinNode::make(std::string name,
+                                    Operation op,
+                                    Array<Tensor> inputs,
+                                    Array<Buffer> buffers,
+                                    Array<Var> scalar_params,
+                                    Stmt body,
+                                    Stmt reduce_init,
+                                    Stmt reduce_update) {
+  auto n = make_object<TensorIntrinNode>();
+  n->name = std::move(name);
+  n->op = std::move(op);
+  n->inputs = std::move(inputs);
+  n->buffers = std::move(buffers);
+  n->scalar_params = std::move(scalar_params);
+  n->body = std::move(body);
+  n->reduce_init = std::move(reduce_init);
+  n->reduce_update = std::move(reduce_update);
+  return TensorIntrin(n);
+}
+
+TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
+.set_dispatch<TensorIntrinNode>([](const ObjectRef& node, NodePrinter* p) {
+    auto* op = static_cast<const TensorIntrinNode*>(node.get());
+    p->stream << "TensorIntrin(name=" << op->name << ", " << op << ")";
+  });
+
+TVM_REGISTER_NODE_TYPE(TensorIntrinNode);
+
+
+// TensorIntrinCall
+
+TensorIntrinCall TensorIntrinCallNode::make(TensorIntrin intrin,
+                                            Array<Tensor> tensors,
+                                            Array<Region> regions,
+                                            Array<IterVar> reduce_axis,
+                                            Array<PrimExpr> scalar_inputs) {
+  auto n = make_object<TensorIntrinCallNode>();
+  n->intrin = std::move(intrin);
+  n->tensors = std::move(tensors);
+  n->regions = std::move(regions);
+  n->reduce_axis = std::move(reduce_axis);
+  n->scalar_inputs = std::move(scalar_inputs);
+  return TensorIntrinCall(n);
+}
+
+TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
+.set_dispatch<TensorIntrinCallNode>([](const ObjectRef& node, NodePrinter* p) {
+    auto* n = static_cast<const TensorIntrinCallNode*>(node.get());
+    p->stream << "TensorIntrinCall(intrin=" << n->intrin << ", " << n << ")";
+  });
+
+TVM_REGISTER_NODE_TYPE(TensorIntrinCallNode);
+
+}  // namespace top
+}  // namespace tvm
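The hunks that follow are mechanical call-site updates for the new namespace;
the recurring pattern is sketched below (representative, not exhaustive):

    #include <tvm/top/operation.h>   // was <tvm/operation.h>
    using namespace tvm::top;        // or qualify explicitly:
    auto A = top::placeholder({m, l}, DataType::Float(32), "A");
    auto C = top::compute({m, n}, [&](Var i, Var j) { return A(i, j); }, "C");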
index 15bdd474f8039f10086a4c8076f9282a206928bf..30834c5af46ce9ff947902fc1b1e0448e80645b5 100644 (file)
@@ -20,7 +20,7 @@
 #include <dmlc/logging.h>
 #include <gtest/gtest.h>
 #include <topi/cuda/injective.h>
-#include <tvm/operation.h>
+#include <tvm/top/operation.h>
 #include <tvm/runtime/registry.h>
 #include <tvm/packed_func_ext.h>
 #include <tvm/build_module.h>
@@ -30,6 +30,7 @@
 
 TEST(BuildModule, Basic) {
   using namespace tvm;
+  using namespace tvm::top;
   auto n = var("n");
   Array<PrimExpr> shape;
   shape.push_back(n);
@@ -75,6 +76,7 @@ TEST(BuildModule, Heterogeneous) {
    */
 
   using namespace tvm;
+  using namespace tvm::top;
   const runtime::PackedFunc* pf = runtime::Registry::Get("module._Enabled");
   bool enabled = (*pf)("cuda");
   if (!enabled) {
index d9b620063f563dc109761973f95a40587f5dc073..af8ede396edcd8d86ac82ecc2f65d7359997c415 100644 (file)
@@ -19,7 +19,7 @@
 
 #include <dmlc/logging.h>
 #include <gtest/gtest.h>
-#include <tvm/operation.h>
+#include <tvm/top/operation.h>
 
 TEST(Expr, Basic) {
   using namespace tvm;
index 6b694eff28f3652731a9ce4ad319a13e58b8d903..e9f0df6493a6d507c26cbd5fc301b2ff9f03dfb4 100644 (file)
@@ -20,7 +20,7 @@
 #include <dmlc/logging.h>
 #include <gtest/gtest.h>
 #include <tvm/ir_pass.h>
-#include <tvm/operation.h>
+#include <tvm/top/operation.h>
 
 TEST(IRSIMPLIFY, MinMax) {
   auto x = tvm::var("x");
index df2a6b6d97176f549e17efffe2e8686f143b573a..462d0fe12eba9875c6aa3ef765a316fd1caa4d4f 100644 (file)
@@ -19,7 +19,7 @@
 
 #include <gtest/gtest.h>
 #include <tvm/build_module.h>
-#include <tvm/operation.h>
+#include <tvm/top/operation.h>
 #include <tvm/relay/expr.h>
 #include <tvm/relay/type.h>
 #include <tvm/relay/analysis.h>
index 7d03d2e0ef87f079af455a306e5bc3640f714c53..68d5d0d67f743d885c2eff4c30fa54d1e21a9db5 100644 (file)
@@ -18,7 +18,7 @@
  */
 
 #include <gtest/gtest.h>
-#include <tvm/operation.h>
+#include <tvm/top/operation.h>
 #include <tvm/relay/expr.h>
 #include <tvm/relay/type.h>
 #include <tvm/relay/analysis.h>
index 8321c580ab0ff646d0717037c562fd757ab0381e..4c383b5af4e1283eb5824591f5f03e4fc8d51f60 100644 (file)
@@ -28,7 +28,7 @@
 #include <tvm/relay/type.h>
 #include <tvm/runtime/packed_func.h>
 #include <tvm/runtime/registry.h>
-#include <tvm/operation.h>
+#include <tvm/top/operation.h>
 
 TVM_REGISTER_GLOBAL("schedule")
     .set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) {
index 6333f1549bec28559d0d8a2171be9ccdb56aaccc..e41b881faa43b07043dd6c733d3747e054e5c26d 100644 (file)
@@ -20,7 +20,7 @@
 #include <dmlc/logging.h>
 #include <gtest/gtest.h>
 #include <tvm/ir_pass.h>
-#include <tvm/operation.h>
+#include <tvm/top/operation.h>
 
 TEST(SimplePasses, HasSideEffect) {
   using namespace tvm;
@@ -28,7 +28,7 @@ TEST(SimplePasses, HasSideEffect) {
   Array<PrimExpr> shape;
   shape.push_back(n);
 
-  auto A = placeholder(shape, DataType::Float(32), "A");
+  auto A = top::placeholder(shape, DataType::Float(32), "A");
 
   CHECK(!tvm::ir::HasSideEffect(A[0]));
 }
index c994c0fbc355a8f38d88f2f7deebf4519e9b0c43..5d6dc239f5c3b42e63d1b545511a837a272848bc 100644 (file)
 
 #include <dmlc/logging.h>
 #include <gtest/gtest.h>
-#include <tvm/operation.h>
+#include <tvm/top/operation.h>
 
 TEST(Tensor, Basic) {
   using namespace tvm;
+  using namespace tvm::top;
+
   Var m("m"), n("n"), l("l");
 
   Tensor A = placeholder({m, l}, DataType::Float(32), "A");
@@ -37,15 +39,17 @@ TEST(Tensor, Basic) {
 
 TEST(Tensor, Reduce) {
   using namespace tvm;
+  using namespace tvm::top;
+
   Var m("m"), n("n"), l("l");
-  Tensor A = placeholder({m, l}, DataType::Float(32), "A");
-  Tensor B = placeholder({n, l}, DataType::Float(32), "B");
+  top::Tensor A = top::placeholder({m, l}, DataType::Float(32), "A");
+  top::Tensor B = top::placeholder({n, l}, DataType::Float(32), "B");
   IterVar rv = reduce_axis(Range{0, l}, "k");
 
-  auto C = compute({m, n}, [&](Var i, Var j) {
+  auto C = top::compute({m, n}, [&](Var i, Var j) {
       return sum(max(1 + A[i][rv] + 1, B[j][rv]), {rv});
       }, "C");
-  LOG(INFO) << C->op.as<ComputeOpNode>()->body;
+  LOG(INFO) << C->op.as<top::ComputeOpNode>()->body;
 }
 
 int main(int argc, char ** argv) {
index 55a91452c18972cdaaf5f50d974d42f03e8d6450..5f89bdf9641ae836128a407a60a811a6e98d3374 100644 (file)
@@ -17,7 +17,7 @@
  * under the License.
  */
 
-#include <tvm/operation.h>
+#include <tvm/top/operation.h>
 #include <topi/elemwise.h>
 #include <gtest/gtest.h>
 
index a3720cd4d1a2d9352cca3a49834fcd2187aab7ed..73a624584749586055b9bcc2cd25fe48a259e323 100644 (file)
@@ -32,7 +32,7 @@
 #include <gtest/gtest.h>
 #include <topi/generic/injective.h>
 #include <tvm/build_module.h>
-#include <tvm/operation.h>
+#include <tvm/top/operation.h>
 #include <tvm/packed_func_ext.h>
 #include <tvm/relay/analysis.h>
 #include <tvm/relay/expr.h>
index ce16e23bf4fa7f1fda0ffafa390673a297114829..47dfc3e77f09de0e627e4dd9d75063b61c08eb5f 100644 (file)
@@ -43,7 +43,7 @@ namespace topi {
  *
  * \return A Tensor whose op member is a broadcast operation
  */
-inline tvm::Tensor broadcast_to(const tvm::Tensor& t,
+inline tvm::top::Tensor broadcast_to(const tvm::top::Tensor& t,
                                 const tvm::Array<tvm::PrimExpr>& output_shape,
                                 std::string name = "T_broadcast_to",
                                 std::string tag = kBroadcast) {
@@ -58,7 +58,7 @@ inline tvm::Tensor broadcast_to(const tvm::Tensor& t,
   auto l = [&](tvm::Array<tvm::Var> ovars) {
     return t(detail::InputIndexFromBroadcast(ovars, t, bh.vars2, bh.all_vars));
   };
-  return tvm::compute(
+  return tvm::top::compute(
       tvm::Array<tvm::PrimExpr>(bh.common_shape.begin(), bh.common_shape.end()),
       l,
       name,
@@ -70,44 +70,44 @@ inline tvm::Tensor broadcast_to(const tvm::Tensor& t,
                             const tvm::PrimExpr& b) {                 \
     ComputeRule;                                                      \
   }                                                                   \
-  inline tvm::Tensor Name(const tvm::Tensor& A,                       \
-                          const tvm::Tensor& B,                       \
-                          std::string name = "T_" #Name,              \
-                          std::string tag = kBroadcast) {             \
+  inline tvm::top::Tensor Name(const tvm::top::Tensor& A,             \
+                               const tvm::top::Tensor& B,             \
+                               std::string name = "T_" #Name,         \
+                               std::string tag = kBroadcast) {        \
     auto l = [](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; };   \
     return detail::WithBroadcast(l, A, B, name, tag);                 \
   }                                                                   \
-  inline tvm::Tensor Name(const tvm::Tensor& A,                       \
+  inline tvm::top::Tensor Name(const tvm::top::Tensor& A,             \
                           const tvm::PrimExpr& B,                     \
                           std::string name = "T_" #Name,              \
                           std::string tag = kElementWise) {           \
-    auto l = [](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; }; \
-    return compute(A->shape, [&](const ::tvm::Array<::tvm::Var>& i) { \
+    auto l = [](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; };   \
+    return tvm::top::compute(A->shape, [&](const ::tvm::Array<::tvm::Var>& i) { \
         return l(A(i), B);                                            \
       }, name, tag);                                                  \
   }                                                                   \
-  inline tvm::Tensor Name(const tvm::PrimExpr& A,                     \
-                          const tvm::Tensor& B,                       \
-                          std::string name = "T_" #Name,              \
-                          std::string tag = kElementWise) {           \
+  inline tvm::top::Tensor Name(const tvm::PrimExpr& A,                \
+                               const tvm::top::Tensor& B,             \
+                               std::string name = "T_" #Name,         \
+                               std::string tag = kElementWise) {      \
     auto l = [&](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; };  \
-    return compute(B->shape, [&](const ::tvm::Array<::tvm::Var>& i) { \
+    return tvm::top::compute(B->shape, [&](const ::tvm::Array<::tvm::Var>& i) { \
         return l(A, B(i));                                            \
       }, name, tag);                                                  \
   }
 
 
 #define TOPI_DEFINE_OP_OVERLOAD(Name, OpName)                       \
-  inline tvm::Tensor Name(const tvm::Tensor& A,                     \
-                          const tvm::Tensor& B) {                   \
+  inline tvm::top::Tensor Name(const tvm::top::Tensor& A,           \
+                              const tvm::top::Tensor& B) {          \
     return topi::OpName(A, B);                                      \
   }                                                                 \
-  inline tvm::Tensor Name(const tvm::PrimExpr& A,                   \
-                          const tvm::Tensor& B) {                   \
+  inline tvm::top::Tensor Name(const tvm::PrimExpr& A,              \
+                               const tvm::top::Tensor& B) {         \
     return topi::OpName(A, B);                                      \
   }                                                                 \
-  inline tvm::Tensor Name(const tvm::Tensor& A,                     \
-                          const tvm::PrimExpr& B) {                 \
+  inline tvm::top::Tensor Name(const tvm::top::Tensor& A,           \
+                               const tvm::PrimExpr& B) {            \
     return topi::OpName(A, B);                                      \
   }
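For reference, these macros are instantiated elsewhere in broadcast.h roughly as
follows (a representative sketch), so every generated overload now returns
tvm::top::Tensor:

    TOPI_DEFINE_BCAST_OP(add, { return a + b; });
    TOPI_DEFINE_OP_OVERLOAD(operator+, add);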
 
index c64490330f428af6d39f5a065080dae274d8bfb4..44685fc165a2fa882eb669373edfc6eb8d939e02 100644 (file)
 #ifndef TOPI_CONTRIB_CUBLAS_H_
 #define TOPI_CONTRIB_CUBLAS_H_
 
-#include "tvm/operation.h"
+#include "tvm/top/operation.h"
 #include "topi/detail/extern.h"
 
 namespace topi {
 namespace contrib {
 using namespace tvm;
+using namespace tvm::top;
 using namespace topi::detail;
 /*!
 * \brief Create an op that multiplies lhs and rhs with cuBLAS
index a61499dc488f0b53fd95bab368b94fe143fb7ad8..062ad405c3d1303aecb56fc778556130fdbca91d 100644 (file)
 #ifndef TOPI_CONTRIB_ROCBLAS_H_
 #define TOPI_CONTRIB_ROCBLAS_H_
 
-#include "tvm/operation.h"
+#include "tvm/top/operation.h"
 #include "topi/detail/extern.h"
 
 namespace topi {
 namespace contrib {
 using namespace tvm;
+using namespace tvm::top;
 /*!
 * \brief Create an op that multiplies lhs and rhs with rocBLAS
 *
index 781258afa982cbc81b362c0406e83d73125f36f3..637a861cd989bb87dfec82bbf208a23df3b398cb 100644 (file)
@@ -24,7 +24,7 @@
 #ifndef TOPI_CUDA_DENSE_H_
 #define TOPI_CUDA_DENSE_H_
 
-#include "tvm/operation.h"
+#include "tvm/top/operation.h"
 #include "tvm/build_module.h"
 #include "topi/tags.h"
 #include "topi/detail/array_utils.h"
@@ -34,6 +34,7 @@
 
 namespace topi {
 using namespace tvm;
+using namespace tvm::top;
 
 namespace cuda {
 /*!
@@ -47,10 +48,10 @@ namespace cuda {
 *
 * \return Tensor with shape [batch, out_dim]
 */
-inline tvm::Tensor dense_cuda(const Target& target,
-                              const tvm::Tensor& data,
-                              const tvm::Tensor& weight,
-                              const tvm::Tensor& bias,
+inline tvm::top::Tensor dense_cuda(const Target& target,
+                              const tvm::top::Tensor& data,
+                              const tvm::top::Tensor& weight,
+                              const tvm::top::Tensor& bias,
                               const DataType& out_dtype) {
   CHECK_EQ(data->shape.size(), 2) << "dense requires 2-D data";
   CHECK_EQ(weight->shape.size(), 2) << "dense requires 2-D weight";
@@ -66,7 +67,7 @@ inline tvm::Tensor dense_cuda(const Target& target,
     CHECK_EQ(data->dtype, out_dtype) << "Mixed precision not supported.";
     auto mm = topi::contrib::cublas_matmul(data, weight, false, true);
     if (bias.defined()) {
-      mm = tvm::compute({ batch, out_dim },
+      mm = tvm::top::compute({ batch, out_dim },
                         [&](Var i, Var j) {
                           return mm(i, j) + bias(j);
                         }, "tensor", kBroadcast);
index 663bc1fcbeb238a67026c9f8af19614c85737bf9..5e42126729cca946de3832535f1a0ad2db0a6d67 100644 (file)
 
 #include "topi/tags.h"
 #include "topi/detail/fuse.h"
-#include "tvm/operation.h"
+#include "tvm/top/operation.h"
 #include "tvm/build_module.h"
 
 namespace topi {
 using namespace tvm;
+using namespace tvm::top;
 
 namespace cuda {
 
@@ -39,7 +40,7 @@ namespace cuda {
  *
  * \param sch The schedule to update.
  * \param out The tensor representing the injective op.
- * 
+ *
  * \return The updated schedule.
  */
 inline Schedule schedule_injective_from_existing(Schedule sch, const Tensor& out) {
@@ -67,7 +68,7 @@ inline Schedule schedule_injective(const Target &target, const Array<Tensor>& ou
     out_ops.push_back(t->op);
   }
   auto s = create_schedule(out_ops);
-  tvm::schedule::AutoInlineInjective(s);
+  tvm::top::AutoInlineInjective(s);
   for (auto out : outs) {
     schedule_injective_from_existing(s, out);
   }
index 302578644555427dbd3b228581ceb2b9e57671a0..708f8d5e7bbff81009341cc6f537003b46d6e044 100644 (file)
 #ifndef TOPI_CUDA_NORMALIZATION_H_
 #define TOPI_CUDA_NORMALIZATION_H_
 
-#include "tvm/operation.h"
+#include "tvm/top/operation.h"
 #include "tvm/build_module.h"
 #include "topi/tags.h"
 
 namespace topi {
 using namespace tvm;
+using namespace tvm::top;
 namespace cuda {
 /*!
 * \brief Create a CUDA schedule for LRN
index df4ab33626345c94edaf3848b45b6f1fbc392c79..d2a5c1f4511f0e3bcc691aabb9e655ea2beea1eb 100644 (file)
 #include "topi/tags.h"
 #include "topi/detail/fuse.h"
 #include "topi/detail/array_utils.h"
-#include "tvm/operation.h"
+#include "tvm/top/operation.h"
 #include "tvm/build_module.h"
 
 namespace topi {
 using namespace tvm;
+using namespace tvm::top;
 
 namespace cuda {
 
index a82b363066174caaf4d4714f954192dbed91bd1b..244567499d3fbed5395b018f1f2308e83300be6d 100644 (file)
 
 #include "topi/tags.h"
 #include "topi/detail/fuse.h"
-#include "tvm/operation.h"
+#include "tvm/top/operation.h"
 #include "tvm/build_module.h"
 
 namespace topi {
 using namespace tvm;
+using namespace tvm::top;
 
 namespace cuda {
 /*!
index 33be89922fea63f0797463694a15bef06bef6fd3..6f12de000bf82e68b36cfa624ba6128d8038d413 100644 (file)
 
 #include "topi/tags.h"
 #include "topi/detail/fuse.h"
-#include "tvm/operation.h"
+#include "tvm/top/operation.h"
 #include "tvm/build_module.h"
 
 namespace topi {
 using namespace tvm;
+using namespace tvm::top;
 
 namespace cuda {
 
@@ -50,9 +51,9 @@ inline Schedule schedule_softmax(const Target &target, const Array<Tensor>& outs
   auto s = create_schedule(out_ops);
 
   auto softmax = outs[0];
-  tvm::Tensor max_elem;
-  tvm::Tensor expsum;
-  tvm::Tensor exp;
+  tvm::top::Tensor max_elem;
+  tvm::top::Tensor expsum;
+  tvm::top::Tensor exp;
   bool has_exp = false;
 
   auto tag = softmax->op.as<ComputeOpNode>()->tag;
index 93e7c3e156fb18b51a5e756bcb222e43118d67e4..0c0feec2eaa78cb5100cb86401eb0eeef59d162c 100644 (file)
 #ifndef TOPI_DETAIL_ARRAY_UTILS_H_
 #define TOPI_DETAIL_ARRAY_UTILS_H_
 
-#include "tvm/operation.h"
+#include "tvm/top/operation.h"
 
 namespace topi {
 namespace detail {
 using namespace tvm;
+using namespace tvm::top;
 
 /*!
  * \brief Search an array for a specific item
index 2d326e7e1e072f0db666808e5c8af2f469a3f753..2e644eebdc8c9251d8e4c2c619cbb17ebda1184e 100644 (file)
@@ -29,7 +29,7 @@
 #include <string>
 
 #include "tvm/ir_pass.h"
-#include "tvm/operation.h"
+#include "tvm/top/operation.h"
 #include "tvm/expr_operator.h"
 #include "topi/detail/constant_utils.h"
 
@@ -100,7 +100,7 @@ inline BroadcastHelper BroadcastShape(const tvm::Array<tvm::PrimExpr>& shape1,
 
 inline tvm::Array<tvm::PrimExpr> InputIndexFromBroadcast(
     const tvm::Array<tvm::Var>& ovars,
-    const tvm::Tensor& T,
+    const tvm::top::Tensor& T,
     const std::deque<tvm::Var>& my_vars,
     const std::deque<tvm::Var>& all_vars) {
   tvm::Array<tvm::PrimExpr> ivars;
@@ -127,9 +127,9 @@ inline tvm::Array<tvm::PrimExpr> InputIndexFromBroadcast(
 }
 
 template <typename FBinaryExpr>
-inline tvm::Tensor WithBroadcast(FBinaryExpr op,
-                                 const tvm::Tensor& A,
-                                 const tvm::Tensor& B,
+inline tvm::top::Tensor WithBroadcast(FBinaryExpr op,
+                                 const tvm::top::Tensor& A,
+                                 const tvm::top::Tensor& B,
                                  const std::string& name = "tensor",
                                  const std::string& tag = "") {
   auto bh = BroadcastShape(A->shape, B->shape);
@@ -137,7 +137,7 @@ inline tvm::Tensor WithBroadcast(FBinaryExpr op,
     return op(A(InputIndexFromBroadcast(ovars, A, bh.vars1, bh.all_vars)),
               B(InputIndexFromBroadcast(ovars, B, bh.vars2, bh.all_vars)));
   };
-  return tvm::compute(
+  return tvm::top::compute(
       tvm::Array<tvm::PrimExpr>(bh.common_shape.begin(), bh.common_shape.end()),
       l,
       name,
index e6de76f20881309dbcf63de2167f582de3d133a5..210049344404cf8a2eb7f405e801d060505e4a2a 100644 (file)
@@ -33,6 +33,7 @@
 namespace topi {
 namespace detail {
 using namespace tvm;
+using namespace tvm::top;
 
 /*!
  * \brief Test whether the given Expr is a constant integer
index 8bdda802ad7a005ba53f81425837a1eb8331627b..717ce4d46d334f20bed12a5fb8be8f6f7f2d2bc0 100644 (file)
@@ -24,7 +24,7 @@
 #ifndef TOPI_DETAIL_EXTERN_H_
 #define TOPI_DETAIL_EXTERN_H_
 
-#include <tvm/operation.h>
+#include <tvm/top/operation.h>
 #include <vector>
 #include <string>
 
@@ -32,6 +32,7 @@
 namespace topi {
 namespace detail {
 using namespace tvm;
+using namespace tvm::top;
 
 /*!
  * \brief Construct a buffer to pass to an external function
index d573a3fcb18a7a677402a3da22f0b89772b66cd8..5a77db6b31f295f62cad7cb5e762c8fcd234fa8b 100644 (file)
 #ifndef TOPI_DETAIL_FUSE_H_
 #define TOPI_DETAIL_FUSE_H_
 
-#include "tvm/operation.h"
+#include "tvm/top/operation.h"
 
 namespace topi {
 namespace detail {
 using namespace tvm;
+using namespace tvm::top;
 
 /*!
  * \brief Fuse all of the given args
index ec757e9954fe0b4d9a55d7f631e3f6e687762fd6..12b15413e72af0422f02c4b1e10103ea092353b0 100644 (file)
@@ -32,6 +32,7 @@
 namespace topi {
 namespace detail {
 using namespace tvm;
+using namespace tvm::top;
 
 /*!
  * \brief Get padding size for each side given padding height and width
index 5526a7dac7c2eeebad1e9752fd582d983e196642..c8da45d918b26c53ce1f833eee78ea99072dbc11 100644 (file)
 
 #include <vector>
 
-#include "tvm/operation.h"
+#include "tvm/top/operation.h"
 #include "tvm/expr_operator.h"
 
 namespace topi {
 namespace detail {
 using namespace tvm;
+using namespace tvm::top;
 
 /*!
 * \brief Flatten the indices to 1D
index e52452e89d33e91d1cfaf327381d2cfb9b28b870..e306880284db6ec5a2673ac032d1e8464bf4a1f3 100644 (file)
@@ -28,6 +28,7 @@
 namespace topi {
 namespace detail {
 using namespace tvm;
+using namespace tvm::top;
 
 /*!
  * \brief Check whether input shape has dimension of size 0;
index df7cff0643830306785db125a9dd6c7fc2829ba2..46515e7632267d2e90b69eae5472558d02639135 100644 (file)
@@ -33,6 +33,7 @@
 
 namespace topi {
 using namespace tvm;
+using namespace tvm::top;
 
 // Unary intrinsic operators
 #define TOPI_DECLARE_UNARY_OP(OpName)                           \
index 9dff8d53d6afdcc3b83fa17ca4da8c4250630a91..540e4a17f148e267d2e9479cda8de916c2636f5c 100644 (file)
 
 #include "topi/tags.h"
 #include "topi/detail/fuse.h"
-#include "tvm/operation.h"
+#include "tvm/top/operation.h"
 #include "tvm/build_module.h"
 
 namespace topi {
 using namespace tvm;
+using namespace tvm::top;
 
 namespace generic {
 /*!
@@ -66,7 +67,7 @@ inline Schedule default_schedule_auto_inline(const Target& target, Array<Tensor>
   }
   auto s = create_schedule(out_ops);
   auto x = outs[0];
-  tvm::schedule::AutoInlineInjective(s);
+  tvm::top::AutoInlineInjective(s);
   auto axis = s[x]->op.as<ComputeOpNode>()->axis;
   if (axis.size() > 0) {
     detail::Fuse(s[x], axis);
index 03e362c399df4479525fa8c73cec141bac67dae8..ea1e7b6531a64ce35d39dd7f95e10e8e1e8665e9 100644 (file)
 
 #include "topi/tags.h"
 #include "topi/detail/fuse.h"
-#include "tvm/operation.h"
+#include "tvm/top/operation.h"
 #include "tvm/build_module.h"
 #include "injective.h"
 
 namespace topi {
 using namespace tvm;
+using namespace tvm::top;
 
 namespace generic {
 /*!
@@ -49,7 +50,7 @@ inline Schedule schedule_extern(const Target& target, Array<Tensor> outs) {
   }
   auto s = create_schedule(out_ops);
 
-  tvm::schedule::AutoInlineInjective(s);
+  tvm::top::AutoInlineInjective(s);
   for (auto out : outs) {
     if (out->op->IsInstance<ExternOpNode>()) {
       continue;
index fa7df4ccfd460f82a6d89ad1b00976e7316bbc1e..7df2e15391510cd8ba2707cb48e945af66eff5af 100644 (file)
 
 #include "topi/tags.h"
 #include "topi/detail/fuse.h"
-#include "tvm/operation.h"
+#include "tvm/top/operation.h"
 #include "tvm/build_module.h"
 
 namespace topi {
 using namespace tvm;
+using namespace tvm::top;
 
 namespace generic {
 
@@ -39,7 +40,7 @@ namespace generic {
  *
  * \param sch The schedule to update.
  * \param out The tensor representing the injective op.
- * 
+ *
  * \return The updated schedule.
  */
 inline Schedule schedule_injective_from_existing(Schedule sch, const Tensor& out) {
@@ -61,7 +62,7 @@ inline Schedule schedule_injective(const Target &target, const Array<Tensor>& ou
     out_ops.push_back(t->op);
   }
   auto s = create_schedule(out_ops);
-  tvm::schedule::AutoInlineInjective(s);
+  tvm::top::AutoInlineInjective(s);
   auto x = outs[0];
   schedule_injective_from_existing(s, x);
 
index f169ec94663678a82a0ad3dff82a3f206a825222..7c1bad3dc9a73b212199a3452c027af32ffe1dea 100644 (file)
 #include "topi/elemwise.h"
 #include "topi/detail/ravel_unravel.h"
 #include "topi/detail/constant_utils.h"
-#include "tvm/operation.h"
+#include "tvm/top/operation.h"
 #include "tvm/expr_operator.h"
 
 namespace topi {
 namespace image {
 using namespace tvm;
+using namespace tvm::top;
 
 /*!
  * \brief Sample a point in a tensor using bilinear interpolation.
index 3f65c75a02bba2eb83956519b23a81d78ba1b6f3..b86c00c60ae36361392886ad4f9a6658efdda896 100644 (file)
 #include "topi/detail/constant_utils.h"
 #include "tvm/ir.h"
 #include "tvm/ir_pass.h"
-#include "tvm/operation.h"
+#include "tvm/top/operation.h"
 #include "tvm/expr_operator.h"
 
 namespace topi {
 using namespace tvm;
+using namespace tvm::top;
 namespace detail {
 
 template <typename T>
@@ -61,11 +62,11 @@ tvm::PrimExpr Map(const tvm::Array<tvm::PrimExpr>& exprs, T op) {
  * \return A Tensor whose op member is the relu operation
  */
 template <typename T>
-inline tvm::Tensor relu(const tvm::Tensor& t,
+inline tvm::top::Tensor relu(const tvm::top::Tensor& t,
                         T threshold = static_cast<T>(0),
                         std::string name = "T_relu",
                         std::string tag = kElementWise) {
-  return tvm::compute(
+  return tvm::top::compute(
       t->shape,
       [&](const tvm::Array<tvm::Var>& i) {
         auto threshold_const = tvm::make_const(t->dtype, threshold);
@@ -85,11 +86,11 @@ inline tvm::Tensor relu(const tvm::Tensor& t,
 *
 * \return A Tensor whose op member is the leaky relu operation
 */
-inline tvm::Tensor leaky_relu(const tvm::Tensor& t,
+inline tvm::top::Tensor leaky_relu(const tvm::top::Tensor& t,
                               double alpha = 0.1,
                               std::string name = "T_leaky_relu",
                               std::string tag = kElementWise) {
-  return tvm::compute(
+  return tvm::top::compute(
     t->shape,
     [&](const tvm::Array<tvm::Var>& i) {
       auto value = t(i);
@@ -111,8 +112,8 @@ inline tvm::Tensor leaky_relu(const tvm::Tensor& t,
  *
  * \return A Tensor whose op member is the parametric relu operation
  */
-inline tvm::Tensor prelu(const tvm::Tensor &x,
-                         const tvm::Tensor &slope,
+inline tvm::top::Tensor prelu(const tvm::top::Tensor &x,
+                         const tvm::top::Tensor &slope,
                          const int axis = 1,
                          std::string name = "T_prelu",
                          std::string tag = kBroadcast) {
@@ -122,7 +123,7 @@ inline tvm::Tensor prelu(const tvm::Tensor &x,
         topi::detail::GetConstInt(x->shape[axis]))
         << "Wrong slope shape received.";
 
-  return tvm::compute(x->shape,
+  return tvm::top::compute(x->shape,
                      [&](const tvm::Array<tvm::Var> &indices) {
                         auto xval = x(indices);
                         return tvm::ir::SelectNode::make(
@@ -171,7 +172,7 @@ inline tvm::Tensor prelu(const tvm::Tensor &x,
  *
  *
  */
-inline tvm::Tensor pad(const tvm::Tensor& t,
+inline tvm::top::Tensor pad(const tvm::top::Tensor& t,
                        const tvm::Array<tvm::PrimExpr>& pad_before,
                        tvm::Array<tvm::PrimExpr> pad_after = tvm::Array<tvm::PrimExpr>(),
                        PrimExpr pad_value = PrimExpr(),
@@ -251,7 +252,7 @@ inline tvm::Tensor pad(const tvm::Tensor& t,
     }
     return t(indices);
   };
-  return tvm::compute(output_shape, l, name, tag);
+  return tvm::top::compute(output_shape, l, name, tag);
 }
 
 /*!
@@ -274,8 +275,8 @@ inline tvm::Tensor pad(const tvm::Tensor& t,
  * \return A Tensor whose op member is the 2-D convolution operation (NCHW
  * layout)
  */
-inline tvm::Tensor conv2d_nchw(const tvm::Tensor& I,
-                               const tvm::Tensor& W,
+inline tvm::top::Tensor conv2d_nchw(const tvm::top::Tensor& I,
+                               const tvm::top::Tensor& W,
                                int pad_h = 0,
                                int pad_w = 0,
                                int stride_h = 1,
@@ -303,7 +304,7 @@ inline tvm::Tensor conv2d_nchw(const tvm::Tensor& I,
         T(b, i, stride_h * h + kh, stride_w * w + kw) * W(o, i, kh, kw),
         {i, kh, kw});
   };
-  return tvm::compute(output_shape, l, name, tag);
+  return tvm::top::compute(output_shape, l, name, tag);
 }
 
 /*!
@@ -325,8 +326,8 @@ inline tvm::Tensor conv2d_nchw(const tvm::Tensor& I,
  * \return A Tensor whose op member is the 2-D convolution operation
  * (HWCN layout)
  */
-inline tvm::Tensor conv2d_hwcn(const tvm::Tensor& I,
-                               const tvm::Tensor& W,
+inline tvm::top::Tensor conv2d_hwcn(const tvm::top::Tensor& I,
+                               const tvm::top::Tensor& W,
                                int pad_h = 0,
                                int pad_w = 0,
                                int stride_h = 1,
@@ -352,7 +353,7 @@ inline tvm::Tensor conv2d_hwcn(const tvm::Tensor& I,
         T(stride_h * h + kh, stride_w * w + kw, i, b) * W(kh, kw, i, o),
         {i, kh, kw});
   };
-  return tvm::compute(output_shape, l, name, tag);
+  return tvm::top::compute(output_shape, l, name, tag);
 }
 
 
@@ -376,8 +377,8 @@ inline tvm::Tensor conv2d_hwcn(const tvm::Tensor& I,
  * \return A Tensor whose op member is the 2-D depthwise convolution operation
  * (NCHW layout)
  */
-inline tvm::Tensor depthwise_conv2d_nchw(const tvm::Tensor& I,
-                                         const tvm::Tensor& W,
+inline tvm::top::Tensor depthwise_conv2d_nchw(const tvm::top::Tensor& I,
+                                         const tvm::top::Tensor& W,
                                          int pad_h = 0,
                                          int pad_w = 0,
                                          int stride_h = 1,
@@ -406,11 +407,11 @@ inline tvm::Tensor depthwise_conv2d_nchw(const tvm::Tensor& I,
                     W(indexdiv(i, pCM), indexmod(o, pCM), kh, kw),
                     {i, kh, kw});
   };
-  return tvm::compute(output_shape, l, name, tag);
+  return tvm::top::compute(output_shape, l, name, tag);
 }
 
-inline tvm::Tensor depthwise_conv2d_nhwc(const tvm::Tensor& I,
-                                         const tvm::Tensor& W,
+inline tvm::top::Tensor depthwise_conv2d_nhwc(const tvm::top::Tensor& I,
+                                         const tvm::top::Tensor& W,
                                          int pad_h = 0,
                                          int pad_w = 0,
                                          int stride_h = 1,
@@ -439,7 +440,7 @@ inline tvm::Tensor depthwise_conv2d_nhwc(const tvm::Tensor& I,
                     W(kh, kw, indexdiv(i, pCM), indexmod(o, pCM)),
                     {kh, kw, i});
   };
-  return tvm::compute(output_shape, l, name, tag);
+  return tvm::top::compute(output_shape, l, name, tag);
 }
 
 /*!
@@ -462,8 +463,8 @@ inline tvm::Tensor depthwise_conv2d_nhwc(const tvm::Tensor& I,
  * \return A Tensor whose op member is the 2-D groupconvolution operation
  * (NCHW layout)
  */
-inline tvm::Tensor group_conv2d_ngchw(const tvm::Tensor& I,
-                                      const tvm::Tensor& W,
+inline tvm::top::Tensor group_conv2d_ngchw(const tvm::top::Tensor& I,
+                                      const tvm::top::Tensor& W,
                                       int pad_h = 0,
                                       int pad_w = 0,
                                       int stride_h = 1,
@@ -498,7 +499,7 @@ inline tvm::Tensor group_conv2d_ngchw(const tvm::Tensor& I,
         I(b, g, i, stride_h * h + kh, stride_w * w + kw) * W(g, i, o, kh, kw),
         {i, kh, kw});
   };
-  return tvm::compute(output_shape, l, name, tag);
+  return tvm::top::compute(output_shape, l, name, tag);
 }
 
 }  // namespace topi
index 60b4175d65f83251c7ac8c13d96946908f2d6d64..a3bd96df77d820f821d27eddd1fe6429844fb203 100644 (file)
 #include <string>
 
 #include "topi/tags.h"
-#include "tvm/operation.h"
+#include "tvm/top/operation.h"
 
 namespace topi {
 namespace nn {
 using namespace tvm;
+using namespace tvm::top;
 
 /*!
 * \brief Creates an operation that calculates matrix multiplication in batch.
@@ -41,8 +42,8 @@ using namespace tvm;
 *
 * \return Tensor with shape [batch, M, N]
 */
-inline tvm::Tensor batch_matmul(const tvm::Tensor& x,
-                                const tvm::Tensor& y) {
+inline tvm::top::Tensor batch_matmul(const tvm::top::Tensor& x,
+                                const tvm::top::Tensor& y) {
   CHECK_EQ(x->shape.size(), 3) << "batch_matmul requires 3-D data";
   CHECK_EQ(y->shape.size(), 3) << "batch_matmul requires 3-D data";
 
@@ -52,7 +53,7 @@ inline tvm::Tensor batch_matmul(const tvm::Tensor& x,
   auto N = y->shape[1];
 
   auto k = tvm::reduce_axis(Range(0, K), "k");
-  auto result = tvm::compute(
+  auto result = tvm::top::compute(
       { batch, M, N },
       [&](Var b, Var i, Var j) {
         return tvm::sum(x(b, i, k) * y(b, j, k), { k });
index 148d892229db6fff4cff54b6472cfea64f185831..2d6f47ca8b3e0b178a2102f8eb64bf4e9338a556 100644 (file)
@@ -29,7 +29,7 @@
 #include "topi/tags.h"
 #include "topi/broadcast.h"
 #include "topi/transform.h"
-#include "tvm/operation.h"
+#include "tvm/top/operation.h"
 #include "tvm/expr_operator.h"
 
 namespace topi {
@@ -43,7 +43,9 @@ namespace nn {
 * \param axis The axis to add the bias to.
 * \return Tensor with shape [batch, in_dim]
 */
-inline tvm::Tensor bias_add(const tvm::Tensor& data, const tvm::Tensor& bias, int axis) {
+inline tvm::top::Tensor bias_add(const tvm::top::Tensor& data,
+                                 const tvm::top::Tensor& bias,
+                                 int axis) {
   int data_ndim = data->shape.size();
   if (axis < 0) {
     axis += data_ndim;
index e2af3ae61518680ab7273e48e3dd1b2cd01dc965..7c920347b68fb55e5171051f7592a18c4ee3422a 100644 (file)
@@ -26,7 +26,7 @@
 
 #include <string>
 
-#include "tvm/operation.h"
+#include "tvm/top/operation.h"
 #include "tvm/ir_pass.h"
 #include "topi/tags.h"
 #include "topi/detail/constant_utils.h"
@@ -34,6 +34,7 @@
 namespace topi {
 namespace nn {
 using namespace tvm;
+using namespace tvm::top;
 
 /*!
 * \brief Binarization and bit-packing along a certain axis.
@@ -46,7 +47,7 @@ using namespace tvm;
 *
 * \return Output tensor with dtype uint32
 */
-inline tvm::Tensor binarize_pack(const tvm::Tensor& data,
+inline tvm::top::Tensor binarize_pack(const tvm::top::Tensor& data,
                                  int axis,
                                  std::string name = "PackedInput",
                                  std::string tag = "binarize_pack") {
@@ -62,7 +63,7 @@ inline tvm::Tensor binarize_pack(const tvm::Tensor& data,
                      ishape[i]);
   }
 
-  return tvm::compute(
+  return tvm::top::compute(
     oshape,
     [&](const Array<Var>& indices) {
       Array<PrimExpr> start_idx;
@@ -98,8 +99,8 @@ inline tvm::Tensor binarize_pack(const tvm::Tensor& data,
 *
 * \return Tensor with shape [batch, out_dim], dtype is float32
 */
-inline tvm::Tensor binary_dense(const tvm::Tensor& data,
-                                const tvm::Tensor& weight) {
+inline tvm::top::Tensor binary_dense(const tvm::top::Tensor& data,
+                                const tvm::top::Tensor& weight) {
   CHECK_EQ(data->shape.size(), 2) << "binary_dense requires 2-D data";
   CHECK_EQ(weight->shape.size(), 2) << "binary_dense requires 2-D weight";
   CHECK_EQ(data->dtype, DataType::UInt(32)) << "binary_dense requires uint32 data";
@@ -110,13 +111,13 @@ inline tvm::Tensor binary_dense(const tvm::Tensor& data,
   auto out_dim = weight->shape[0];
 
   auto k = tvm::reduce_axis(Range(0, in_dim), "k");
-  auto matmul = tvm::compute(
+  auto matmul = tvm::top::compute(
     { batch, out_dim },
     [&](Var i, Var j) {
       return tvm::sum(popcount(data(i, k) ^ weight(j, k)), { k });
     }, "tensor", "binary_dense");
 
-  return tvm::compute(
+  return tvm::top::compute(
     { batch, out_dim },
     [&](Var i, Var j) {
       return 32 * in_dim - 2.0f * matmul(i, j);
index f1bb07abf262081b88299d783e6722d70c24e010..7cdc8d7b5be334ccf7592583a7d3bdf778ae7f9c 100644 (file)
 #include <string>
 
 #include "topi/tags.h"
-#include "tvm/operation.h"
+#include "tvm/top/operation.h"
 
 namespace topi {
 namespace nn {
 using namespace tvm;
+using namespace tvm::top;
 
 /*!
 * \brief Creates an operation that calculates data * weight^T + bias
@@ -43,9 +44,9 @@ using namespace tvm;
 *
 * \return Tensor with shape [batch, out_dim]
 */
-inline tvm::Tensor dense(const tvm::Tensor& data,
-                         const tvm::Tensor& weight,
-                         const tvm::Tensor& bias,
+inline tvm::top::Tensor dense(const tvm::top::Tensor& data,
+                         const tvm::top::Tensor& weight,
+                         const tvm::top::Tensor& bias,
                          const DataType& out_dtype) {
   CHECK_EQ(data->shape.size(), 2) << "dense requires 2-D data";
   CHECK_EQ(weight->shape.size(), 2) << "dense requires 2-D weight";
@@ -58,7 +59,7 @@ inline tvm::Tensor dense(const tvm::Tensor& data,
   auto out_dim = weight->shape[0];
 
   auto k = tvm::reduce_axis(Range(0, in_dim), "k");
-  auto matmul = tvm::compute(
+  auto matmul = tvm::top::compute(
     { batch, out_dim },
     [&](Var i, Var j) {
       return tvm::sum(tvm::cast(out_dtype, data(i, k)) *
@@ -66,7 +67,7 @@ inline tvm::Tensor dense(const tvm::Tensor& data,
     }, "tensor", "dense");
 
   if (bias.defined()) {
-    matmul = tvm::compute(
+    matmul = tvm::top::compute(
       { batch, out_dim },
       [&](Var i, Var j) {
         return matmul(i, j) + tvm::cast(out_dtype, bias(j));
index 334b17054c812f6f47695654b91d537161520f59..6ffb3da25e52ac50192b374d98219d6fb1858326 100644 (file)
 
 #include <string>
 
-#include "tvm/operation.h"
+#include "tvm/top/operation.h"
 #include "tvm/ir_pass.h"
 #include "topi/tags.h"
 
 namespace topi {
 namespace nn {
 using namespace tvm;
+using namespace tvm::top;
 
 /*!
 * \brief Create a new expression of the logical and of all
@@ -79,7 +80,7 @@ inline Tensor dilate(const Tensor& x,
       (x->shape[i] - 1) * cast(DataType::Int(32), strides[i] + 1)));
   }
 
-  return tvm::compute(
+  return tvm::top::compute(
     out_shape,
     [&](const Array<Var>& indices) {
       Array<PrimExpr> not_zero;
index 6b542f7c2afe3b32001382b13f93e18fe0a019fc..de11b6dd797ebc088539a3efc2602ee91181110d 100644 (file)
 
 #include "topi/tags.h"
 #include "topi/detail/constant_utils.h"
-#include "tvm/operation.h"
+#include "tvm/top/operation.h"
 #include "tvm/expr_operator.h"
 
 
 namespace topi {
 namespace nn {
 using namespace tvm;
+using namespace tvm::top;
 
 /*!
 * \brief Flattens the input tensor into a 2-D tensor by collapsing higher dimensions.
@@ -64,7 +65,7 @@ inline Tensor flatten(const Tensor& x,
   }
   std::reverse(extra_shape.begin(), extra_shape.end());
 
-  return tvm::compute(
+  return tvm::top::compute(
     oshape, [&](Var i, Var j) {
       PrimExpr idx = j;
       std::vector<PrimExpr> index;
index 3f603594c5a9ecf345f595f8938e2d17aa8db1bb..6670e6da3fc30402fde7c026d9522fb90789285c 100644 (file)
 #include <string>
 #include <algorithm>
 #include "topi/tags.h"
-#include "tvm/operation.h"
+#include "tvm/top/operation.h"
 namespace topi {
 namespace nn {
 using namespace tvm;
+using namespace tvm::top;
 
 /*!
 * \brief L2 normalization inference operator
@@ -59,7 +60,7 @@ inline Tensor l2_normalize(const Tensor& data,
   Tensor sum_value = topi::sum(dot_value, axis, true);
   Tensor expand_sum = topi::broadcast_to(sum_value, input_shape);
   return topi::divide(data,
-                      topi::sqrt(tvm::compute(expand_sum->shape,
+                      topi::sqrt(tvm::top::compute(expand_sum->shape,
                                               [&](const Array<Var>& i){
                                                 return (max(expand_sum(i), eps));
                                               }, name, tag)));
index 0cce997c200947b1e8aad19f68b034de7c6051cd..cd3b9b2456e14724510ec2f1f4d2d7dfa63b9a51 100644 (file)
 #include <string>
 
 #include "topi/tags.h"
-#include "tvm/operation.h"
+#include "tvm/top/operation.h"
 
 namespace topi {
 namespace nn {
 using namespace tvm;
+using namespace tvm::top;
 
 /*!
 * \brief Local response normalization inference operator
@@ -67,21 +68,21 @@ inline Tensor lrn(const Tensor& data,
   auto rxs = tvm::reduce_axis(Range(0, size), "rxs");
   Tensor sqr_sum;
   if (axis == 1) {
-    sqr_sum = tvm::compute(input_shape,
+    sqr_sum = tvm::top::compute(input_shape,
                            [&](Var i, Var l, Var j, Var k) {
                            return tvm::sum(pad_data(i, l + rxs, j, k) *
                                            pad_data(i, l + rxs, j, k),
                                            {rxs});
                            });
   } else if (axis == 3) {
-    sqr_sum = tvm::compute(input_shape,
+    sqr_sum = tvm::top::compute(input_shape,
                            [&](Var i, Var l, Var j, Var k) {
                            return tvm::sum(pad_data(i, l, j, k + rxs) *
                                            pad_data(i, l, j, k + rxs),
                                            {rxs});
                            });
   }
-  auto sqrt_sum_up = tvm::compute(
+  auto sqrt_sum_up = tvm::top::compute(
       input_shape,
       [&](Var i, Var j, Var k, Var l) {
         return tvm::pow(bias +
index 03043ff9fe19a4966f72bbe5f8848f95d292477a..4cd2fe1b77087ea4e3b8128b6c45a0afb3fdaa37 100644 (file)
 #include <string>
 
 #include "topi/tags.h"
-#include "tvm/operation.h"
+#include "tvm/top/operation.h"
 
 namespace topi {
 namespace nn {
 using namespace tvm;
+using namespace tvm::top;
 
 /*!
 * \brief Scale and shift with NCHW order
@@ -49,7 +50,7 @@ inline Tensor scale_shift_nchw(const Tensor& x,
                                const Tensor& shift,
                                std::string name = "ScaleShift",
                                std::string tag = kBroadcast) {
-  return tvm::compute(
+  return tvm::top::compute(
     x->shape,
     [&](Var b, Var c, Var h, Var w) {
       return x(b, c, h, w) * scale(c) + shift(w);
@@ -72,7 +73,7 @@ inline Tensor scale_shift_nhwc(const Tensor& x,
                                const Tensor& shift,
                                std::string name = "ScaleShift",
                                std::string tag = kBroadcast) {
-  return tvm::compute(
+  return tvm::top::compute(
     x->shape,
     [&](Var b, Var h, Var w, Var c) {
       return x(b, h, w, c) * scale(c) + shift(w);
index a074ee1f6ef99f508520bc397b9f86f5e45cd0c7..ac284a0e01fe1fb675b49692e0f6a4635c3a842a 100644 (file)
@@ -37,6 +37,7 @@
 namespace topi {
 namespace nn {
 using namespace tvm;
+using namespace tvm::top;
 
 /*! \brief Pooling type */
 enum PoolType : int {
@@ -124,7 +125,7 @@ inline Tensor pool_impl(const Tensor& x,
   if (pool_type == kMaxPool) {
     auto temp = do_pad ? pad(
         x, pad_before, pad_after, tvm::min_value(x->dtype), "pad_temp") : x;
-    return tvm::compute(out_shape, [&](const Array<Var>& output) {
+    return tvm::top::compute(out_shape, [&](const Array<Var>& output) {
       Array<PrimExpr> indices;
       for (const Var& var : output) indices.push_back(var);
       indices.Set(height_axis, output[height_axis] * stride_height + dheight);
@@ -136,7 +137,7 @@ inline Tensor pool_impl(const Tensor& x,
     auto temp = do_pad ? pad(x, pad_before, pad_after, 0, "pad_temp") : x;
 
     // TVM compute for summing the pooling window.
-    auto pool_sum = tvm::compute(out_shape,
+    auto pool_sum = tvm::top::compute(out_shape,
     [&](const Array<Var>& output) {
       Array<PrimExpr> indices;
       for (const Var& var : output) indices.push_back(var);
@@ -146,7 +147,7 @@ inline Tensor pool_impl(const Tensor& x,
     }, "tensor", "pool_sum");
 
     // TVM compute for dividing the reduced window sum by kernel size.
-    return tvm::compute(out_shape,
+    return tvm::top::compute(out_shape,
     [&](const Array<Var>& output) {
       Array<PrimExpr> indices;
       for (const Var& var : output) indices.push_back(var);
@@ -244,7 +245,7 @@ inline Tensor pool_grad_impl(const Tensor& out_grad,
         x, pad_before, pad_after, tvm::min_value(x->dtype), "pad_temp") : x;
 
     auto mp_argmax =
-        tvm::compute(out_shape,
+        tvm::top::compute(out_shape,
                      [&](const Array<Var>& inds) {
                        Array<PrimExpr> window_inds{inds.begin(), inds.end()};
                        window_inds.Set(height_axis, inds[height_axis] * stride_height + dheight);
@@ -256,7 +257,7 @@ inline Tensor pool_grad_impl(const Tensor& out_grad,
 
     auto mp_inds = mp_argmax[0];
 
-    return tvm::compute(
+    return tvm::top::compute(
         x->shape,
         [&](const Array<Var>& inds) {
           Array<PrimExpr> pad_inds {inds.begin(), inds.end()};
@@ -287,7 +288,7 @@ inline Tensor pool_grad_impl(const Tensor& out_grad,
   } else if (pool_type == kAvgPool) {
     auto windowh = tvm::reduce_axis(Range(0, (kernel_height + stride_height - 1) / stride_height));
     auto windoww = tvm::reduce_axis(Range(0, (kernel_width + stride_width - 1) / stride_width));
-    return tvm::compute(
+    return tvm::top::compute(
         x->shape,
         [&](const Array<Var>& inds) {
           PrimExpr pad_h_idx = inds[height_axis] + pad_top;
@@ -512,7 +513,7 @@ inline Tensor adaptive_pool_impl(const Tensor& x,
   out_shape.Set(width_axis, out_width);
 
   if (pool_type == kMaxPool) {
-    return tvm::compute(out_shape, [&](const Array<Var>& output) {
+    return tvm::top::compute(out_shape, [&](const Array<Var>& output) {
       Array<PrimExpr> indices;
       for (const Var& var : output) indices.push_back(var);
       auto i_start_h = start_index(output[height_axis], out_height, height);
@@ -526,7 +527,7 @@ inline Tensor adaptive_pool_impl(const Tensor& x,
       return tvm::max(x(indices), { dheight, dwidth });  // NOLINT(*)
     }, "tensor", "adaptive_pool_max");
   } else if (pool_type == kAvgPool) {
-    auto pool_sum = tvm::compute(out_shape, [&](const Array<Var>& output) {
+    auto pool_sum = tvm::top::compute(out_shape, [&](const Array<Var>& output) {
       Array<PrimExpr> indices;
       for (const Var& var : output) indices.push_back(var);
       auto i_start_h = start_index(output[height_axis], out_height, height);
@@ -542,7 +543,7 @@ inline Tensor adaptive_pool_impl(const Tensor& x,
       return tvm::sum(x(indices), { dheight, dwidth });
     }, "tensor", "adaptive_pool_sum");
 
-    return tvm::compute(out_shape, [&](const Array<Var>& output) {
+    return tvm::top::compute(out_shape, [&](const Array<Var>& output) {
       Array<PrimExpr> indices;
       for (const Var& var : output) indices.push_back(var);
       auto i_start_h = start_index(output[height_axis], out_height, height);
@@ -696,7 +697,7 @@ inline Tensor pool_impl_nd(const Tensor& x,
   if (pool_type == kMaxPool) {
     auto temp = do_pad ? pad(
         x, pad_before, pad_after, tvm::min_value(x->dtype), "pad_temp") : x;
-    return tvm::compute(out_shape, [&](const Array<Var>& output) {
+    return tvm::top::compute(out_shape, [&](const Array<Var>& output) {
       Array<PrimExpr> indices;
       for (const Var& var : output) indices.push_back(var);
 
@@ -712,7 +713,7 @@ inline Tensor pool_impl_nd(const Tensor& x,
     auto temp = do_pad ? pad(x, pad_before, pad_after, 0, "pad_temp") : x;
 
     // TVM compute for summing the pooling window.
-    auto pool_sum = tvm::compute(out_shape,
+    auto pool_sum = tvm::top::compute(out_shape,
     [&](const Array<Var>& output) {
       Array<PrimExpr> indices;
       for (const Var& var : output) indices.push_back(var);
@@ -725,7 +726,7 @@ inline Tensor pool_impl_nd(const Tensor& x,
     }, "tensor", "pool_sum");
 
     // TVM compute for dividing the reduced window sum by kernel size.
-    return tvm::compute(out_shape,
+    return tvm::top::compute(out_shape,
     [&](const Array<Var>& output) {
       Array<PrimExpr> indices;
       for (const Var& var : output) indices.push_back(var);
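Note: the two-stage average pooling above (a pool_sum reduction followed by an
elementwise divide) survives the rename unchanged. A minimal 1-D sketch of
that pattern under the new namespace, assuming a hypothetical float input x,
window size k, stride 1, and no padding:

    #include <tvm/expr_operator.h>
    #include <tvm/top/operation.h>

    tvm::top::Tensor AvgPool1DSketch(const tvm::top::Tensor& x, int k) {
      auto out_len = x->shape[0] - k + 1;
      auto dk = tvm::reduce_axis(tvm::Range(0, k), "dk");
      // Stage 1: sum over the pooling window.
      auto pool_sum = tvm::top::compute(
          {out_len},
          [&](tvm::Var i) { return tvm::sum(x(i + dk), {dk}); },
          "tensor", "pool_sum");
      // Stage 2: divide the window sum by the kernel size.
      return tvm::top::compute(
          {out_len},
          [&](tvm::Var i) { return tvm::div(pool_sum(i), k); },
          "tensor", "pool_avg");
    }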
index 58ecc956964d073d4f8b2293955cf8eb0f442c8c..72e17454b724932336b9311c128bf20ee726b471 100644 (file)
 
 #include "topi/reduction.h"
 #include "topi/tags.h"
-#include "tvm/operation.h"
+#include "tvm/top/operation.h"
 #include "tvm/expr_operator.h"
 
 namespace topi {
 namespace nn {
 using namespace tvm;
+using namespace tvm::top;
 
 /*!
 * \brief Softmax activation
@@ -109,14 +110,14 @@ inline Tensor softmax(const Tensor &x,
     return exp(indices) / expsum(non_reduce_indices);
   };
 
-  auto max_elem = tvm::compute(reduced_shape, _compute_max);
-  auto exp = tvm::compute(input_shape, [&](const Array<Var> &indices) {
+  auto max_elem = tvm::top::compute(reduced_shape, _compute_max);
+  auto exp = tvm::top::compute(input_shape, [&](const Array<Var> &indices) {
       return _compute_exp(max_elem, indices);
   });
-  auto expsum = tvm::compute(reduced_shape, [&](const Array<Var> &indices) {
+  auto expsum = tvm::top::compute(reduced_shape, [&](const Array<Var> &indices) {
       return _compute_expsum(exp, indices);
   });
-  return tvm::compute(input_shape, [&](const Array<Var> &indices) {
+  return tvm::top::compute(input_shape, [&](const Array<Var> &indices) {
       return _normalize(exp, expsum, indices);
   }, name, tag, attrs);
 }
@@ -139,16 +140,16 @@ inline Tensor log_softmax(const Tensor& x,
   PrimExpr n = x->shape[1];
 
   auto k = tvm::reduce_axis(Range(0, n), "k");
-  auto max_elem = tvm::compute(
+  auto max_elem = tvm::top::compute(
     { m }, [&](Var i) {
       return tvm::max(x(i, k), Array<IterVar>{ k }); });
   k = tvm::reduce_axis(Range(0, n), "k");
 
-  auto expsum = tvm::compute(
+  auto expsum = tvm::top::compute(
     { m }, [&](Var i) {
       return tvm::sum(tvm::exp(x(i, k) - max_elem(i)), { k }); });
 
-  return tvm::compute(
+  return tvm::top::compute(
     x->shape, [&](Var i, Var j) {
       return x(i, j) - max_elem(i) - tvm::log(expsum(i));
     }, name, tag);
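Note (unrelated to the rename): the max_elem subtraction in both softmax
variants is the standard numerical-stability trick. The computes implement
the identity

    \log\mathrm{softmax}(x)_{ij} = x_{ij} - m_i - \log\sum_k e^{x_{ik} - m_i},
    \quad m_i = \max_k x_{ik},

so exp() is only ever applied to non-positive values and cannot overflow.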
index b6230c7c017cd38a7cc6370fffb2cb6d8825b4d7..f624f3687a40f99ecdf1cc92467d9bbaa8a07ae7 100644 (file)
@@ -34,6 +34,7 @@
 namespace topi {
 namespace nn {
 using namespace tvm;
+using namespace tvm::top;
 using namespace topi::image;
 
 /*!
index ac843b1a707791b899328652b2854e6b3a53b6d8..197ef2b152c415eab8fcafc76fcb5ae430ef8175 100644 (file)
 #include "topi/transform.h"
 #include "topi/detail/ravel_unravel.h"
 #include "topi/detail/constant_utils.h"
-#include "tvm/operation.h"
+#include "tvm/top/operation.h"
 #include "tvm/expr_operator.h"
 
 
 namespace topi {
 using namespace tvm;
+using namespace tvm::top;
 
 /*! \brief The operation to use for CommReduce */
 using FReduce = std::function<PrimExpr(PrimExpr source, const Array<IterVar>& axis)>;
@@ -167,7 +168,7 @@ inline Tensor DoCommReduce(const Tensor& data,
     return func(data(eval_range), r_axes);
   };
 
-  return tvm::compute(target_shape, compute, data->op->name + "_red", kCommReduce);
+  return tvm::top::compute(target_shape, compute, data->op->name + "_red", kCommReduce);
 }
 
 /*!
@@ -251,11 +252,11 @@ inline Tensor CommReduceIdx(const Tensor& data,
     return func({ idx, data(eval_range) }, reduce_axes, nullptr);
   };
 
-  auto temp_idx_val = tvm::compute(target_shape, compute,
+  auto temp_idx_val = tvm::top::compute(target_shape, compute,
                                    data->op->name + "_red_temp", kCommReduceIdx);
   auto temp_idx = temp_idx_val[0];
   auto temp_val = temp_idx_val[1];
-  return tvm::compute(
+  return tvm::top::compute(
     target_shape,
     [&temp_idx](const Array<Var>& indices) { return temp_idx(indices); },
     data->op->name + "_red",
index 7eb9f9a519625dae95b99e3da2c7803066ac1abd..00297489cd281f251d98ceae05d1a0fe6e625257 100644 (file)
@@ -24,7 +24,7 @@
 #ifndef TOPI_ROCM_DENSE_H_
 #define TOPI_ROCM_DENSE_H_
 
-#include "tvm/operation.h"
+#include "tvm/top/operation.h"
 #include "tvm/build_module.h"
 #include "topi/tags.h"
 #include "topi/detail/array_utils.h"
@@ -35,6 +35,7 @@
 
 namespace topi {
 using namespace tvm;
+using namespace tvm::top;
 
 namespace rocm {
 /*!
@@ -48,10 +49,10 @@ namespace rocm {
 *
 * \return Tensor with shape [batch, out_dim]
 */
-inline tvm::Tensor dense_rocm(const Target& target,
-                              const tvm::Tensor& data,
-                              const tvm::Tensor& weight,
-                              const tvm::Tensor& bias,
+inline tvm::top::Tensor dense_rocm(const Target& target,
+                              const tvm::top::Tensor& data,
+                              const tvm::top::Tensor& weight,
+                              const tvm::top::Tensor& bias,
                               const DataType& out_dtype) {
   CHECK_EQ(data->shape.size(), 2) << "dense requires 2-D data";
   CHECK_EQ(weight->shape.size(), 2) << "dense requires 2-D weight";
@@ -67,7 +68,7 @@ inline tvm::Tensor dense_rocm(const Target& target,
     CHECK_EQ(data->dtype, out_dtype) << "Mixed precision not supported.";
     auto mm = topi::contrib::rocblas_matmul(data, weight, false, true);
     if (bias.defined()) {
-      mm = tvm::compute({ batch, out_dim },
+      mm = tvm::top::compute({ batch, out_dim },
                         [&](Var i, Var j) {
                           return mm(i, j) + bias(j);
                         }, "tensor", kBroadcast);
index 7b367b713584f19455eb6b623dc9c0f6f512e5dc..b2999f4d8bb5b6558a8e2a4413373899a4e5f55c 100644 (file)
 
 #include "topi/tags.h"
 #include "topi/detail/fuse.h"
-#include "tvm/operation.h"
+#include "tvm/top/operation.h"
 #include "tvm/build_module.h"
 
 #include "topi/cuda/injective.h"
 
 namespace topi {
 using namespace tvm;
+using namespace tvm::top;
 
 namespace rocm {
 
@@ -41,7 +42,7 @@ namespace rocm {
  *
  * \param sch The schedule to update.
  * \param out The tensor representing the injective op.
- * 
+ *
  * \return The updated schedule.
  */
 inline Schedule schedule_injective_from_existing(Schedule sch, const Tensor& out) {
index c6ffd328528a88cd78742b359f96dd84a9e110ed..7dde9dc77140f84acaf8672db9d8a9b5346de2c0 100644 (file)
 #ifndef TOPI_ROCM_NORMALIZATION_H_
 #define TOPI_ROCM_NORMALIZATION_H_
 
-#include "tvm/operation.h"
+#include "tvm/top/operation.h"
 #include "tvm/build_module.h"
 #include "topi/tags.h"
 
 namespace topi {
 using namespace tvm;
+using namespace tvm::top;
 namespace rocm {
 /*!
 * \brief Create a rocm schedule for LRN
index b0ab7cd40650633e030be9d9cb729d544de33882..14f0aa0c66ce771e604c7ff4f0087cfdcf80efbd 100644 (file)
 #include "topi/tags.h"
 #include "topi/detail/fuse.h"
 #include "topi/detail/array_utils.h"
-#include "tvm/operation.h"
+#include "tvm/top/operation.h"
 #include "tvm/build_module.h"
 
 #include "topi/cuda/pooling.h"
 
 namespace topi {
 using namespace tvm;
+using namespace tvm::top;
 
 namespace rocm {
 
index 22e7c7d48d3413dac00844d4f1bd491063da7e56..4b788ee603cdc8b5ed86a35706ad9ec8e900c46c 100644 (file)
 
 #include "topi/tags.h"
 #include "topi/detail/fuse.h"
-#include "tvm/operation.h"
+#include "tvm/top/operation.h"
 #include "tvm/build_module.h"
 
 #include "topi/cuda/reduction.h"
 
 namespace topi {
 using namespace tvm;
+using namespace tvm::top;
 
 namespace rocm {
 /*!
index fbc67eaf9f2ddef44ddbc8b4d7838153aed42534..43f27318a4df927e371c7bcd62b5249d5abcf9f8 100644 (file)
 
 #include "topi/tags.h"
 #include "topi/detail/fuse.h"
-#include "tvm/operation.h"
+#include "tvm/top/operation.h"
 #include "tvm/build_module.h"
 
 #include "topi/cuda/softmax.h"
 
 namespace topi {
 using namespace tvm;
+using namespace tvm::top;
 
 namespace rocm {
 
index 66e2773ded7e2e3a8b17acb8d76640a9ce706279..41a64ebb45ad12aa14dcad1f1c12ef482116a321 100644 (file)
 #include "topi/detail/ravel_unravel.h"
 #include "topi/detail/constant_utils.h"
 #include "topi/detail/tensor_utils.h"
-#include "tvm/operation.h"
+#include "tvm/top/operation.h"
 #include "tvm/expr_operator.h"
 #include "tvm/data_layout.h"
 
 namespace topi {
 using namespace tvm;
+using namespace tvm::top;
 using namespace topi::detail;
 
 /*!
@@ -1041,8 +1042,8 @@ inline Tensor gather_nd(const Tensor& data,
  *
  * \return A Tensor whose op member is the matmul operation
  */
-inline tvm::Tensor matmul(const tvm::Tensor& A,
-                           const tvm::Tensor& B,
+inline tvm::top::Tensor matmul(const tvm::top::Tensor& A,
+                           const tvm::top::Tensor& B,
                            bool trans_a = false,
                            bool trans_b = false,
                            std::string name = "T_matmul",
@@ -1054,7 +1055,7 @@ inline tvm::Tensor matmul(const tvm::Tensor& A,
     return tvm::sum((trans_a ? A[k][i] : A[i][k]) * (trans_b ? B[j][k] : B[k][j]),
                     {k});
   };
-  return tvm::compute(output_shape, l, name, tag);
+  return tvm::top::compute(output_shape, l, name, tag);
 }
 
 /*!
@@ -1069,7 +1070,7 @@ inline tvm::Tensor matmul(const tvm::Tensor& A,
  * \return A Tensor computing the result
  */
 inline Tensor tensordot(const Tensor& A,
-                        const tvm::Tensor& B,
+                        const tvm::top::Tensor& B,
                         int axes = 2,
                         std::string name = "T_tensordot",
                         std::string tag = kMatMul) {
@@ -1124,7 +1125,7 @@ inline Tensor tensordot(const Tensor& A,
  * \return A Tensor computing the result
  */
 inline Tensor tensordot(const Tensor& A,
-                        const tvm::Tensor& B,
+                        const tvm::top::Tensor& B,
                         Array<PrimExpr> A_axes,
                         Array<PrimExpr> B_axes,
                         std::string name = "T_tensordot",
index df3fadef7d751aa62b2d97ac1778fae2ed8f92ac..c5ddea9fdeb058b8d2c0c51e0da289b3a5bd91b7 100644 (file)
 #include "topi/reduction.h"
 #include "topi/tags.h"
 #include "topi/transform.h"
-#include "tvm/operation.h"
+#include "tvm/top/operation.h"
 #include "tvm/expr_operator.h"
 
 namespace topi {
 namespace vision {
 using namespace tvm;
+using namespace tvm::top;
 
 /*!
 * \brief Reorg operation
@@ -60,7 +61,7 @@ inline Tensor reorg(const Tensor &data,
   int w_in = GetConstInt(input_shape[3]);
   int out_c = c_in / (stride * stride);
 
-  auto out = tvm::compute(input_shape,
+  auto out = tvm::top::compute(input_shape,
                           [&](Var b, Var k, Var j, Var i) {
                           return data(b * stride * stride,
                                       indexmod(k, out_c) * stride * stride,
index 1158749520c53da84c72caffb39436c80fe24dd9..578360e9a010315210a7cfdd2fd78a9ce36cd2bf 100644 (file)
 
 #include "topi/tags.h"
 #include "topi/detail/fuse.h"
-#include "tvm/operation.h"
+#include "tvm/top/operation.h"
 #include "tvm/build_module.h"
 
 namespace topi {
 using namespace tvm;
+using namespace tvm::top;
 
 namespace x86 {
 /*!
index 294871b535df175a5cef88078e77b70c39bffb19..e3ac0110ea492ff34d3b1fa116648317f6810062 100644 (file)
 
 #include "topi/tags.h"
 #include "topi/detail/fuse.h"
-#include "tvm/operation.h"
+#include "tvm/top/operation.h"
 #include "tvm/build_module.h"
 
 namespace topi {
 using namespace tvm;
+using namespace tvm::top;
 
 namespace x86 {
 /*!
@@ -54,7 +55,7 @@ inline Schedule MakeDefaultSchedule(const Target &target,
   auto axis = s[x]->op.as<ComputeOpNode>()->axis;
 
   if (auto_inline) {
-    tvm::schedule::AutoInlineInjective(s);
+    tvm::top::AutoInlineInjective(s);
     if (axis.size() > 0) {
       detail::Fuse(s[x], axis);
     }
index 7cb79ae322c2d0c4add54fcd1b35a94189ff4ba6..12c413e30f11ac97cd598f0d5ea3754d74dd0285 100644 (file)
 
 #include "topi/tags.h"
 #include "topi/detail/fuse.h"
-#include "tvm/operation.h"
+#include "tvm/top/operation.h"
 #include "tvm/build_module.h"
 
 namespace topi {
 using namespace tvm;
+using namespace tvm::top;
 
 namespace x86 {
 
@@ -39,7 +40,7 @@ namespace x86 {
  *
  * \param sch The schedule to update.
  * \param out The tensor representing the injective op.
- * 
+ *
  * \return The updated schedule.
  */
 inline Schedule schedule_injective_from_existing(Schedule sch, const Tensor& out) {
@@ -69,7 +70,7 @@ inline Schedule schedule_injective(const Target &target, const Array<Tensor>& ou
     out_ops.push_back(t->op);
   }
   auto s = create_schedule(out_ops);
-  tvm::schedule::AutoInlineInjective(s);
+  tvm::top::AutoInlineInjective(s);
 
   auto x = outs[0];
   schedule_injective_from_existing(s, x);
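Note: the schedule passes moved along with compute: tvm::schedule::
AutoInlineInjective is now tvm::top::AutoInlineInjective. A minimal sketch
of the updated call sequence for a single output tensor (hypothetical helper,
mirroring schedule_injective above):

    #include <tvm/top/operation.h>
    #include <tvm/top/schedule.h>
    #include <tvm/top/schedule_pass.h>

    tvm::top::Schedule ScheduleInjectiveSketch(const tvm::top::Tensor& out) {
      auto s = tvm::top::create_schedule({out->op});
      tvm::top::AutoInlineInjective(s);  // was tvm::schedule::AutoInlineInjective
      return s;
    }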
index 21e61cfa4816fe9670e9049df728ea9b3556aadf..8197e891d15472a04f45c882aeeffe3bcc66cb01 100644 (file)
@@ -93,7 +93,7 @@ Array<Integer> ArrayOrInt(TVMArgValue arg) {
 inline bool IsTensorType(TVMArgValue arg) {
   return (arg.type_code() == kTVMObjectHandle &&
           static_cast<Object*>(
-              arg.value().v_handle)->IsInstance<tvm::TensorNode>());
+              arg.value().v_handle)->IsInstance<tvm::top::TensorNode>());
 }
 
 
@@ -109,13 +109,13 @@ TVM_REGISTER_GLOBAL("topi.TEST_create_target")
       bool lhs_is_tensor = IsTensorType(args[0]);                       \
       bool rhs_is_tensor = IsTensorType(args[1]);                       \
       if (lhs_is_tensor && rhs_is_tensor) {                             \
-        *rv = Op(args[0].operator tvm::Tensor(),                        \
-                 args[1].operator tvm::Tensor());                       \
+        *rv = Op(args[0].operator tvm::top::Tensor(),                        \
+                 args[1].operator tvm::top::Tensor());                       \
       } else if (!lhs_is_tensor && rhs_is_tensor) {                     \
         *rv = Op(args[0].operator tvm::PrimExpr(),                          \
-                 args[1].operator tvm::Tensor());                       \
+                 args[1].operator tvm::top::Tensor());                       \
       } else if (lhs_is_tensor && !rhs_is_tensor) {                     \
-        *rv = Op(args[0].operator tvm::Tensor(),                        \
+        *rv = Op(args[0].operator tvm::top::Tensor(),                        \
                  args[1].operator tvm::PrimExpr());                         \
       } else if (!lhs_is_tensor && !rhs_is_tensor) {                    \
         *rv = Op(args[0].operator tvm::PrimExpr(),                          \
@@ -757,7 +757,7 @@ TVM_REGISTER_GLOBAL("topi.util.is_empty_shape")
 
 /*! \brief Builder function for instantiating schedules. */
 using FTVMScheduleBuilder = std::function<
-  tvm::Schedule(const tvm::Target& target, const tvm::Array<tvm::Tensor>& outs)>;
+  tvm::top::Schedule(const tvm::Target& target, const tvm::Array<tvm::top::Tensor>& outs)>;
 
 /*!
  * \brief Helper function for registering generic functions matching the
@@ -826,7 +826,7 @@ TVM_REGISTER_GENERIC_FUNC(schedule_binary_dense)
 
 /*! \brief Builder function for instantiating schedules from existing schedules. */
 using FTVMScheduleFromExistingBuilder = std::function<
-  tvm::Schedule(tvm::Schedule sch, const tvm::Tensor& out)>;
+  tvm::top::Schedule(tvm::top::Schedule sch, const tvm::top::Tensor& out)>;
 
 /*!
  * \brief Helper function for registering generic functions matching the
@@ -850,10 +850,10 @@ TVM_REGISTER_GENERIC_FUNC(schedule_injective_from_existing)
   topi::cuda::schedule_injective_from_existing));
 
 /*! \brief Builder function for instantiating dense ops. */
-using FTVMDenseOpBuilder = std::function<tvm::Tensor(const Target& target,
-                                                     const tvm::Tensor& data,
-                                                     const tvm::Tensor& weight,
-                                                     const tvm::Tensor& bias,
+using FTVMDenseOpBuilder = std::function<tvm::top::Tensor(const Target& target,
+                                                     const tvm::top::Tensor& data,
+                                                     const tvm::top::Tensor& weight,
+                                                     const tvm::top::Tensor& bias,
                                                      const DataType& out_dtype)>;
 
 /*!
@@ -879,9 +879,9 @@ inline PackedFunc WrapDenseOp(FTVMDenseOpBuilder builder) {
 
 TVM_REGISTER_GENERIC_FUNC(dense)
 .set_default(WrapDenseOp([](const Target& target,
-                            const tvm::Tensor& data,
-                            const tvm::Tensor& weight,
-                            const tvm::Tensor& bias,
+                            const tvm::top::Tensor& data,
+                            const tvm::top::Tensor& weight,
+                            const tvm::top::Tensor& bias,
                             const DataType& out_dtype) {
   return topi::nn::dense(data, weight, bias, out_dtype);
 }))
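Note: downstream PackedFunc bodies follow the same pattern as the binary-op
macro above; argument conversions now spell the namespaced type. A minimal
sketch with a hypothetical registration name, assuming the usual registry and
extension headers:

    #include <tvm/packed_func_ext.h>
    #include <tvm/runtime/registry.h>
    #include <tvm/top/operation.h>

    TVM_REGISTER_GLOBAL("topi.example.identity")
    .set_body([](tvm::runtime::TVMArgs args, tvm::runtime::TVMRetValue* rv) {
      // Was: args[0].operator tvm::Tensor()
      tvm::top::Tensor x = args[0].operator tvm::top::Tensor();
      // Identity compute over the input, using the Array<Var> overload.
      *rv = tvm::top::compute(
          x->shape,
          [&](const tvm::Array<tvm::Var>& i) { return x(i); },
          "tensor", "identity");
    });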