[Relay][AutoTVM] Relay op strategy (#4644)
authorHaichen Shen <shenhaichen@gmail.com>
Mon, 24 Feb 2020 21:12:03 +0000 (13:12 -0800)
committerGitHub <noreply@github.com>
Mon, 24 Feb 2020 21:12:03 +0000 (13:12 -0800)
* relay op strategy

fix lint

bitpack strategy

bitserial_dense (#6)

* update strategy

* address comments

fix a few topi test

Dense strategy (#5)

* dense

* add biforst; remove comments

* address comment

Refactor x86 conv2d_NCHWc (#4)

* Refactor x86 conv2d

* Add x86 depthwise_conv2d_NCHWc

* Add back topi x86 conv2d_nchw

* Merge x86 conv2d_nchw and conv2d_NCHWc

* Minor fix for x86 conv2d

fix more strategy

Add x86 conv2d_NCHWc_int8 strategy (#8)

* Add x86 conv2d_NCHWc_int8 strategy

* Remove contrib_conv2d_nchwc_int8

* Fix generic conv2d_NCHWc for int8

* Fix topi arm_cpu conv2d_NCHWc_int8

update x86 conv2d

enable specify relay ops to be tuned for autotvm

add cuda conv2d strategy

add conv2d strategy for rocm

add conv2d strategy for hls

add conv2d strategy for arm cpu

add conv2d strategy for mali

add conv2d strategy for bifrost

add conv2d strategy for intel graphics

clean up and fix lint

remove template keys from autotvm

remove 2 in the func name

address comments

fix

* fix bugs

* lint

* address comments

* add name to op implement

* Modify topi tests (#9)

* Add pooling, reorg, softmax and vision

* Add lrn

* fix topi test

* fix more topi test

* lint

* address comments

* x

* fix more tests & bugs

* Modify more tests (#10)

* Modify tests for bitserial_conv2d, bitserial_dense, bitserial_conv2d_rasp and bnn

* Minor fix

* More minor fix

* fix more test

* try to update vta using strategy

* fix cpptest

* x

* fix rebase err

* Fix two tests (#11)

* change autotvm log format

* lint

* minor fix

* try fix vta test

* fix rebase err

* tweak

* tmp hack for vta pass

* fix tutorial

* fix

* fix more tutorials

* fix vta tutorial

* minor

* address comments

* fix

* address comments

* fix cpptest

* fix docs

* change data structure name and api

* address comments

* lint

* fix rebase err

* updates

* fix winograd test

* fix doc

* rebase

* upgrade tophub version number

* fix bug

* re-enable vta tsim test after tophub is upgraded

* fix vta test to use the correct args so the config can be found in tophub

Co-authored-by: Yao Wang <kevinthesunwy@gmail.com>
270 files changed:
include/tvm/relay/op_attr_types.h
include/tvm/relay/op_strategy.h [new file with mode: 0644]
include/tvm/te/schedule.h
python/tvm/autotvm/__init__.py
python/tvm/autotvm/database.py
python/tvm/autotvm/feature.py
python/tvm/autotvm/graph_tuner/base_graph_tuner.py
python/tvm/autotvm/graph_tuner/utils/traverse_graph.py
python/tvm/autotvm/graph_tuner/utils/utils.py
python/tvm/autotvm/record.py
python/tvm/autotvm/task/__init__.py
python/tvm/autotvm/task/dispatcher.py
python/tvm/autotvm/task/relay_integration.py
python/tvm/autotvm/task/space.py
python/tvm/autotvm/task/task.py
python/tvm/autotvm/task/topi_integration.py
python/tvm/autotvm/tophub.py
python/tvm/autotvm/tuner/xgboost_cost_model.py
python/tvm/relay/backend/compile_engine.py
python/tvm/relay/expr_functor.py
python/tvm/relay/frontend/tensorflow.py
python/tvm/relay/frontend/tflite.py
python/tvm/relay/memory_alloc.py
python/tvm/relay/op/__init__.py
python/tvm/relay/op/_algorithm.py
python/tvm/relay/op/_reduce.py
python/tvm/relay/op/_tensor.py
python/tvm/relay/op/_transform.py
python/tvm/relay/op/annotation/annotation.py
python/tvm/relay/op/contrib/_contrib.py
python/tvm/relay/op/image/_image.py
python/tvm/relay/op/nn/_nn.py
python/tvm/relay/op/nn/nn.py
python/tvm/relay/op/op.py
python/tvm/relay/op/strategy/__init__.py [new file with mode: 0644]
python/tvm/relay/op/strategy/arm_cpu.py [new file with mode: 0644]
python/tvm/relay/op/strategy/bifrost.py [new file with mode: 0644]
python/tvm/relay/op/strategy/cuda.py [new file with mode: 0644]
python/tvm/relay/op/strategy/generic.py [new file with mode: 0644]
python/tvm/relay/op/strategy/hls.py [new file with mode: 0644]
python/tvm/relay/op/strategy/intel_graphics.py [new file with mode: 0644]
python/tvm/relay/op/strategy/mali.py [new file with mode: 0644]
python/tvm/relay/op/strategy/opengl.py [new file with mode: 0644]
python/tvm/relay/op/strategy/rocm.py [new file with mode: 0644]
python/tvm/relay/op/strategy/x86.py [new file with mode: 0644]
python/tvm/relay/op/vision/_rcnn.py
python/tvm/relay/op/vision/_vision.py
python/tvm/relay/op/vision/_yolo.py
python/tvm/relay/quantize/_annotate.py
python/tvm/relay/testing/mobilenet.py
python/tvm/target/generic_func.py
python/tvm/te/__init__.py
python/tvm/te/schedule.py
python/tvm/tir/expr.py
src/relay/backend/compile_engine.cc
src/relay/backend/compile_engine.h
src/relay/ir/op_strategy.cc [new file with mode: 0644]
src/relay/op/annotation/annotation.cc
src/relay/op/debug.cc
src/relay/op/memory/memory.cc
src/relay/op/nn/convolution.cc
src/relay/op/nn/convolution.h
src/relay/op/nn/nn.cc
src/relay/op/nn/pad.cc
src/relay/op/nn/pooling.cc
src/relay/op/tensor/binary.cc
src/relay/op/tensor/reduce.cc
src/relay/op/tensor/transform.cc
src/relay/op/tensor/unary.cc
src/relay/op/vision/yolo.cc
src/relay/pass/alter_op_layout.cc
src/te/schedule/schedule_lang.cc
tests/cpp/relay_build_module_test.cc
tests/python/frontend/mxnet/test_forward.py
tests/python/integration/test_tuning.py
tests/python/relay/test_any.py
tests/python/relay/test_autotvm_task_extraction.py
tests/python/relay/test_backend_compile_engine.py
tests/python/relay/test_op_level2.py
tests/python/relay/test_op_qnn_conv2d.py
tests/python/relay/test_pass_alter_op_layout.py
tests/python/relay/test_pass_auto_quantize.py
tests/python/relay/test_pass_fold_scale_axis.py
tests/python/unittest/test_autotvm_common.py
tests/python/unittest/test_autotvm_dispatch_context.py
tests/python/unittest/test_autotvm_measure.py
tests/python/unittest/test_codegen_blob.py
tests/python/unittest/test_codegen_cuda.py
tests/python/unittest/test_graph_tuner_core.py
tests/python/unittest/test_graph_tuner_utils.py
tests/python/unittest/test_lang_tensor_overload_op.py
topi/include/topi/cuda/normalization.h
topi/include/topi/rocm/normalization.h
topi/python/topi/__init__.py
topi/python/topi/argwhere.py
topi/python/topi/arm_cpu/__init__.py
topi/python/topi/arm_cpu/bitserial_conv2d.py
topi/python/topi/arm_cpu/bitserial_dense.py
topi/python/topi/arm_cpu/conv2d.py
topi/python/topi/arm_cpu/conv2d_alter_op.py [new file with mode: 0644]
topi/python/topi/arm_cpu/conv2d_int8.py
topi/python/topi/arm_cpu/conv2d_spatial_pack.py
topi/python/topi/arm_cpu/conv2d_transpose.py
topi/python/topi/arm_cpu/depthwise_conv2d.py
topi/python/topi/arm_cpu/injective.py
topi/python/topi/bifrost/conv2d.py
topi/python/topi/bifrost/dense.py
topi/python/topi/bifrost/depthwise_conv2d.py
topi/python/topi/cuda/__init__.py
topi/python/topi/cuda/batch_matmul.py
topi/python/topi/cuda/conv1d.py
topi/python/topi/cuda/conv1d_transpose_ncw.py
topi/python/topi/cuda/conv2d.py
topi/python/topi/cuda/conv2d_alter_op.py [new file with mode: 0644]
topi/python/topi/cuda/conv2d_direct.py
topi/python/topi/cuda/conv2d_hwcn.py
topi/python/topi/cuda/conv2d_int8.py
topi/python/topi/cuda/conv2d_transpose_nchw.py
topi/python/topi/cuda/conv2d_winograd.py
topi/python/topi/cuda/conv3d.py
topi/python/topi/cuda/conv3d_direct.py
topi/python/topi/cuda/deformable_conv2d.py
topi/python/topi/cuda/dense.py
topi/python/topi/cuda/depthwise_conv2d.py
topi/python/topi/cuda/group_conv2d_nchw.py
topi/python/topi/cuda/injective.py
topi/python/topi/cuda/nms.py
topi/python/topi/cuda/nn.py
topi/python/topi/cuda/pooling.py
topi/python/topi/cuda/rcnn/__init__.py
topi/python/topi/cuda/rcnn/proposal.py
topi/python/topi/cuda/reduction.py
topi/python/topi/cuda/softmax.py
topi/python/topi/cuda/sort.py
topi/python/topi/cuda/ssd/multibox.py
topi/python/topi/cuda/vision.py
topi/python/topi/generic/conv2d.py
topi/python/topi/generic/extern.py
topi/python/topi/generic/injective.py
topi/python/topi/generic/nn.py
topi/python/topi/generic/search.py
topi/python/topi/generic/sort.py
topi/python/topi/generic/vision.py
topi/python/topi/hls/injective.py
topi/python/topi/hls/nn.py
topi/python/topi/intel_graphics/__init__.py
topi/python/topi/intel_graphics/conv2d.py
topi/python/topi/intel_graphics/conv2d_alter_op.py [new file with mode: 0644]
topi/python/topi/intel_graphics/depthwise_conv2d.py
topi/python/topi/mali/conv2d.py
topi/python/topi/mali/dense.py
topi/python/topi/mali/depthwise_conv2d.py
topi/python/topi/nn/batch_matmul.py
topi/python/topi/nn/bitserial_conv2d.py
topi/python/topi/nn/bitserial_dense.py
topi/python/topi/nn/conv1d.py
topi/python/topi/nn/conv1d_transpose.py
topi/python/topi/nn/conv2d.py
topi/python/topi/nn/conv2d_transpose.py
topi/python/topi/nn/conv3d.py
topi/python/topi/nn/deformable_conv2d.py
topi/python/topi/nn/dense.py
topi/python/topi/nn/depthwise_conv2d.py
topi/python/topi/nn/local_response_norm.py
topi/python/topi/nn/sparse.py
topi/python/topi/nn/util.py
topi/python/topi/opengl/conv2d_nchw.py
topi/python/topi/opengl/dense.py
topi/python/topi/opengl/injective.py
topi/python/topi/opengl/pooling.py
topi/python/topi/opengl/softmax.py
topi/python/topi/rocm/conv2d.py
topi/python/topi/rocm/dense.py
topi/python/topi/rocm/nn.py
topi/python/topi/sort.py
topi/python/topi/testing/__init__.py
topi/python/topi/testing/common.py [new file with mode: 0644]
topi/python/topi/vision/nms.py
topi/python/topi/vision/rcnn/proposal.py
topi/python/topi/vision/rcnn/roi_align.py
topi/python/topi/vision/rcnn/roi_pool.py
topi/python/topi/vision/reorg.py
topi/python/topi/vision/ssd/multibox.py
topi/python/topi/x86/__init__.py
topi/python/topi/x86/batch_matmul.py
topi/python/topi/x86/binarize_pack.py
topi/python/topi/x86/binary_dense.py
topi/python/topi/x86/bitserial_conv2d.py
topi/python/topi/x86/bitserial_dense.py
topi/python/topi/x86/conv1d.py
topi/python/topi/x86/conv2d.py
topi/python/topi/x86/conv2d_alter_op.py
topi/python/topi/x86/conv2d_avx_1x1.py
topi/python/topi/x86/conv2d_avx_common.py
topi/python/topi/x86/conv2d_int8.py
topi/python/topi/x86/conv2d_transpose.py
topi/python/topi/x86/conv3d.py
topi/python/topi/x86/dense.py
topi/python/topi/x86/depthwise_conv2d.py
topi/python/topi/x86/injective.py
topi/python/topi/x86/nn.py
topi/python/topi/x86/pooling.py
topi/python/topi/x86/reduction.py
topi/python/topi/x86/roi_align.py
topi/python/topi/x86/sparse.py
topi/src/topi.cc
topi/tests/python/common.py
topi/tests/python/test_fifo_buffer.py
topi/tests/python/test_topi_batch_matmul.py
topi/tests/python/test_topi_bitserial_conv2d.py
topi/tests/python/test_topi_bitserial_conv2d_rasp.py
topi/tests/python/test_topi_bitserial_dense.py
topi/tests/python/test_topi_bnn.py
topi/tests/python/test_topi_broadcast.py
topi/tests/python/test_topi_clip.py
topi/tests/python/test_topi_conv1d.py
topi/tests/python/test_topi_conv1d_transpose_ncw.py
topi/tests/python/test_topi_conv2d_NCHWc.py
topi/tests/python/test_topi_conv2d_hwcn.py
topi/tests/python/test_topi_conv2d_int8.py
topi/tests/python/test_topi_conv2d_nchw.py
topi/tests/python/test_topi_conv2d_nhwc.py
topi/tests/python/test_topi_conv2d_nhwc_pack_int8.py
topi/tests/python/test_topi_conv2d_transpose_nchw.py
topi/tests/python/test_topi_conv2d_winograd.py
topi/tests/python/test_topi_conv3d_ncdhw.py
topi/tests/python/test_topi_conv3d_ndhwc.py
topi/tests/python/test_topi_deformable_conv2d.py
topi/tests/python/test_topi_dense.py
topi/tests/python/test_topi_depth_to_space.py
topi/tests/python/test_topi_depthwise_conv2d.py
topi/tests/python/test_topi_group_conv2d.py
topi/tests/python/test_topi_group_conv2d_NCHWc_int8.py
topi/tests/python/test_topi_image.py
topi/tests/python/test_topi_lrn.py
topi/tests/python/test_topi_math.py
topi/tests/python/test_topi_pooling.py
topi/tests/python/test_topi_reduce.py
topi/tests/python/test_topi_relu.py
topi/tests/python/test_topi_reorg.py
topi/tests/python/test_topi_softmax.py
topi/tests/python/test_topi_sort.py
topi/tests/python/test_topi_space_to_depth.py
topi/tests/python/test_topi_tensor.py
topi/tests/python/test_topi_transform.py
topi/tests/python/test_topi_upsampling.py
topi/tests/python/test_topi_vision.py
tutorials/autotvm/tune_conv2d_cuda.py
tutorials/autotvm/tune_relay_arm.py
tutorials/autotvm/tune_relay_cuda.py
tutorials/autotvm/tune_relay_mobile_gpu.py
tutorials/autotvm/tune_relay_x86.py
tutorials/autotvm/tune_simple_template.py
tutorials/dev/relay_pass_infra.py
tutorials/optimize/opt_matmul_auto_tensorcore.py
tutorials/topi/intro_topi.py
vta/python/vta/ir_pass.py
vta/python/vta/top/__init__.py
vta/python/vta/top/bitpack.py
vta/python/vta/top/op.py
vta/python/vta/top/vta_conv2d.py
vta/python/vta/top/vta_conv2d_transpose.py
vta/python/vta/top/vta_dense.py
vta/python/vta/top/vta_group_conv2d.py
vta/scripts/tune_resnet.py
vta/tests/python/integration/test_benchmark_topi_conv2d.py
vta/tests/python/integration/test_benchmark_topi_conv2d_transpose.py
vta/tests/python/integration/test_benchmark_topi_dense.py
vta/tests/python/integration/test_benchmark_topi_group_conv2d.py
vta/tutorials/autotvm/tune_relay_vta.py

index 88e948f..1a2263e 100644 (file)
@@ -29,6 +29,7 @@
 #include <tvm/relay/type.h>
 #include <tvm/relay/expr.h>
 #include <tvm/target/target.h>
+#include <tvm/target/generic_func.h>
 #include <tvm/tir/data_layout.h>
 #include <string>
 
@@ -105,9 +106,8 @@ using TShapeDataDependant = bool;
  */
 using FTVMCompute = runtime::TypedPackedFunc<
   Array<te::Tensor>(const Attrs& attrs,
-                     const Array<te::Tensor>& inputs,
-                     const Type& out_type,
-                     const Target& target)>;
+                    const Array<te::Tensor>& inputs,
+                    const Type& out_type)>;
 
 /*!
  * \brief Build the computation schedule for
@@ -120,8 +120,18 @@ using FTVMCompute = runtime::TypedPackedFunc<
  */
 using FTVMSchedule = runtime::TypedPackedFunc<
   te::Schedule(const Attrs& attrs,
-                const Array<te::Tensor>& outs,
-                const Target& target)>;
+               const Array<te::Tensor>& outs,
+               const Target& target)>;
+
+/*!
+ * \brief Generate the strategy of operators. This function is a generic
+ * function and can be re-defined for different targets.
+ *
+ * The function signature of generic function is:
+ *   OpStrategy(const Attrs& attrs, const Array<Tensor>& inputs,
+ *              const Type& out_type, const Target& target)
+ */
+using FTVMStrategy = GenericFunc;
 
 /*!
  * \brief Alternate the layout of operators or replace the
@@ -136,7 +146,8 @@ using FTVMSchedule = runtime::TypedPackedFunc<
 using FTVMAlterOpLayout = runtime::TypedPackedFunc<
   Expr(const Attrs& attrs,
        const Array<Expr>& args,
-       const Array<te::Tensor>& tinfos)>;
+       const Array<te::Tensor>& tinfos,
+       const Type& out_type)>;
 
 /*!
  * \brief Convert the layout of operators or replace the
@@ -191,9 +202,7 @@ using FForwardRewrite = runtime::TypedPackedFunc<
  * \brief Gradient for a specific op.
  *
  * \param orig_call the original Expr.
- *
  * \param output_grad the gradient of the Expr.
- *
  * \return the gradient for each parameters.
  */
 using FPrimalGradient = runtime::TypedPackedFunc<tvm::Array<Expr>(const Expr& orig_call,
@@ -207,13 +216,13 @@ enum AnyCodegenStrategy {
   kVariableDimensions
 };
 
-/* \brief A runtime representation of shape. */
+/*! \brief A runtime representation of shape. */
 using Shape = Array<IndexExpr>;
 
 using FShapeFunc = runtime::TypedPackedFunc<
   Array<te::Tensor>(const Attrs& attrs,
-                     const Array<te::Tensor>& inputs,
-                     const Array<IndexExpr>& out_ndims)>;
+                    const Array<te::Tensor>& inputs,
+                    const Array<IndexExpr>& out_ndims)>;
 
 }  // namespace relay
 }  // namespace tvm
diff --git a/include/tvm/relay/op_strategy.h b/include/tvm/relay/op_strategy.h
new file mode 100644 (file)
index 0000000..a4da95a
--- /dev/null
@@ -0,0 +1,164 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relay/op_strategy.h
+ * \brief The Relay operator Strategy and related data structure.
+ */
+
+#ifndef TVM_RELAY_OP_STRATEGY_H_
+#define TVM_RELAY_OP_STRATEGY_H_
+
+#include <tvm/te/tensor.h>
+#include <tvm/te/schedule.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/target/target.h>
+#include <string>
+
+namespace tvm {
+namespace relay {
+
+/*!
+ * \brief Operator implementation that includes compute and schedule function.
+ */
+class OpImplementationNode : public Object {
+ public:
+  /*! \brief Compute function */
+  FTVMCompute fcompute;
+  /*! \brief Schedule function */
+  FTVMSchedule fschedule;
+  /*! \brief Name of the implementation */
+  std::string name;
+  /*! \brief Priority level */
+  int plevel;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("name", &name);
+    v->Visit("plevel", &plevel);
+  }
+
+  static constexpr const char* _type_key = "relay.OpImplementation";
+  TVM_DECLARE_FINAL_OBJECT_INFO(OpImplementationNode, Object);
+};
+
+/*!
+ * \brief Operator implementation class.
+ */
+class OpImplementation : public ObjectRef {
+ public:
+  /*!
+   * \brief Invoke the operator compute function.
+   * \param attrs The attribute of the primitive
+   * \param inputs The input tensors.
+   * \param out_type The output type information.
+   * \return The output compute description of the operator.
+   */
+  TVM_DLL Array<te::Tensor> Compute(const Attrs& attrs,
+                                    const Array<te::Tensor>& inputs,
+                                    const Type& out_type);
+  /*!
+   * \brief Build the computation schedule.
+   * \param attrs The attribute of the node.
+   * \param outs The output tensors.
+   * \param target The build target.
+   * \return The computation schedule.
+   */
+  TVM_DLL te::Schedule Schedule(const Attrs& attrs,
+                                const Array<te::Tensor>& outs,
+                                const Target& target);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(OpImplementation, ObjectRef, OpImplementationNode);
+};
+
+/*!
+ * \brief Specialized implementations for operators under certain conditions.
+ */
+class OpSpecializationNode : public Object {
+ public:
+  /*! \brief List of implementations. */
+  Array<OpImplementation> implementations;
+  /*! \brief Condition to enable the specialization.
+   *    Could be undefined to represent generic case. */
+  te::SpecializedCondition condition;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("condition", &condition);
+    v->Visit("implementations", &implementations);
+  }
+
+  static constexpr const char* _type_key = "relay.OpSpecialization";
+  TVM_DECLARE_FINAL_OBJECT_INFO(OpSpecializationNode, ExprNode);
+};
+
+/*!
+ * \brief Operator specialization class.
+ */
+class OpSpecialization : public ObjectRef {
+ public:
+  /*!
+   * \brief Add an implementation.
+   * \param fcompute Compute function
+   * \param fschedule Schedule function
+   * \param name Name of the implementation
+   * \param plevel Priority level of the implementation
+   */
+  TVM_DLL void AddImplementation(FTVMCompute fcompute, FTVMSchedule fschedule,
+                                 std::string name, int plevel);
+
+  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(OpSpecialization, ObjectRef, OpSpecializationNode);
+};
+
+/*!
+ * \brief Operator strategy to choose implementation.
+ */
+class OpStrategyNode : public Object {
+ public:
+  /*! \brief List of operator specializations. */
+  Array<OpSpecialization> specializations;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("specializations", &specializations);
+  }
+
+  static constexpr const char* _type_key = "relay.OpStrategy";
+  TVM_DECLARE_FINAL_OBJECT_INFO(OpStrategyNode, ExprNode);
+};
+
+/*!
+ * \brief Operator strategy class.
+ */
+class OpStrategy : public ObjectRef {
+ public:
+  /*!
+   * \brief Add an implementation.
+   * \param fcompute Compute function
+   * \param fschedule Schedule function
+   * \param name Name of the implementation
+   * \param plevel Priority level of the implementation
+   */
+  TVM_DLL void AddImplementation(FTVMCompute fcompute, FTVMSchedule fschedule,
+                                 std::string name, int plevel);
+
+  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(OpStrategy, ObjectRef, OpStrategyNode);
+};
+
+}  // namespace relay
+}  // namespace tvm
+#endif  // TVM_RELAY_OP_STRATEGY_H_
index e99b54a..a8a0236 100644 (file)
@@ -28,6 +28,7 @@
 #include <tvm/tir/expr.h>
 #include <tvm/te/tensor.h>
 #include <tvm/te/tensor_intrin.h>
+#include <tvm/support/with.h>
 
 #include <string>
 #include <unordered_map>
@@ -742,6 +743,53 @@ class SingletonNode : public IterVarRelationNode {
   TVM_DECLARE_FINAL_OBJECT_INFO(SingletonNode, IterVarRelationNode);
 };
 
+/*! \brief Container for specialization conditions. */
+class SpecializedConditionNode : public Object {
+ public:
+  /*!
+   * \brief List of conditions in conjunctive joint form (CNF).
+   *   Each condition should be a simple expression, e.g., n > 16, m % 8 == 0, etc.,
+   *   where n, m are tvm::Var that represents a dimension in the tensor shape.
+   */
+  Array<PrimExpr> clauses;
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("clauses", &clauses);
+  }
+
+  static constexpr const char* _type_key = "SpecializedCondition";
+  TVM_DECLARE_FINAL_OBJECT_INFO(SpecializedConditionNode, Object);
+};
+
+/*!
+ * \brief Specialized condition to enable op specialization
+ */
+class SpecializedCondition : public ObjectRef {
+ public:
+  /*!
+   * \brief construct from conditions
+   * \param conditions The clauses in the specialized condition.
+   */
+  TVM_DLL SpecializedCondition(Array<PrimExpr> conditions);  // NOLINT(*)
+
+  /*!
+   * \brief Get the current specialized condition.
+   * \return the current specialized condition.
+   */
+  TVM_DLL static SpecializedCondition Current();
+
+  TVM_DEFINE_OBJECT_REF_METHODS(SpecializedCondition, ObjectRef, SpecializedConditionNode);
+  class Internal;
+
+ private:
+  // enable with syntax.
+  friend class Internal;
+  friend class With<SpecializedCondition>;
+  /*! \brief Push a new specialized condition onto the thread local stack. */
+  TVM_DLL void EnterWithScope();
+  /*! \brief Pop a specialized condition off the thread local context stack. */
+  TVM_DLL void ExitWithScope();
+};
 
 // implementations
 inline const StageNode* Stage::operator->() const {
@@ -765,6 +813,7 @@ inline const IterVarRelationNode* IterVarRelation::operator->() const {
 inline const IterVarAttrNode* IterVarAttr::operator->() const {
   return static_cast<const IterVarAttrNode*>(get());
 }
+
 }  // namespace te
 }  // namespace tvm
 #endif  // TVM_TE_SCHEDULE_H_
index cf8362a..eab4ddf 100644 (file)
@@ -41,8 +41,8 @@ from . import tophub
 from .measure import measure_option, MeasureInput, MeasureResult, MeasureErrorNo, \
     LocalBuilder, LocalRunner, RPCRunner
 from .tuner import callback
-from .task import template, get_config, create, ConfigSpace, ConfigEntity, \
-    register_topi_compute, register_topi_schedule, \
+from .task import get_config, create, ConfigSpace, ConfigEntity, \
+    register_topi_compute, register_topi_schedule, register_customized_task, \
     DispatchContext, FallbackContext, ApplyHistoryBest as apply_history_best, \
     ApplyGraphBest as apply_graph_best
 from .env import GLOBAL_SCOPE
index 55d4180..963f7e5 100644 (file)
@@ -125,7 +125,7 @@ class RedisDatabase(Database):
         current = self.get(measure_str_key(inp))
         if current is not None:
             records = [decode(x) for x in current.split(RedisDatabase.MAGIC_SPLIT)]
-            results = [rec[1] for rec in records]
+            results = [rec[1] for rec in records if rec is not None]
             if get_all:
                 return results
             return max(results, key=lambda result: result.timestamp)
@@ -167,9 +167,12 @@ class RedisDatabase(Database):
             current = self.get(key)
             try:
                 records = [decode(x) for x in current.split(RedisDatabase.MAGIC_SPLIT)]
+                records = [rec for rec in records if rec is not None]
             except TypeError: # got a badly formatted/old format record
                 continue
 
+            if not records:
+                continue
             inps, results = zip(*records)
             inp = inps[0]
             if not func(inp, results):
index b7d1c44..4ff1139 100644 (file)
@@ -153,7 +153,10 @@ def get_flatten_name(fea):
         from .record import decode
         # flatten line to feature
         line = fea
-        inp, _ = decode(line)
+        ret = decode(line)
+        if ret is None:
+            raise ValueError("Unsupported AutoTVM log format")
+        inp, _ = ret
         target = _target.create(inp.target)
         with target:
             s, args = inp.template.instantiate(inp.config)
index b02c289..3e85e93 100644 (file)
@@ -25,7 +25,6 @@ import topi
 import tvm
 from tvm import autotvm, relay
 from tvm.autotvm.task import get_config
-from tvm.autotvm.task.topi_integration import deserialize_args, serialize_args
 from tvm.autotvm.record import encode, load_from_file
 from tvm.autotvm.measure import MeasureResult, MeasureInput
 
@@ -35,18 +34,16 @@ from .utils import is_boundary_node, get_in_nodes, get_out_nodes, has_multiple_i
 from ._base import INVALID_LAYOUT_TIME
 
 
-# Setup topi_op_name -> layout function
-# NOTE: To add more ops, change the following dictionary.
-OP2LAYOUT = {
-    "topi_nn_conv2d": topi.nn.conv2d_infer_layout,
-    "topi_nn_depthwise_conv2d_nchw": topi.nn.depthwise_conv2d_infer_layout,
-}
+def get_infer_layout(task_name):
+    if task_name.startswith("conv2d"):
+        return topi.nn.conv2d_infer_layout
+    if task_name.startswith("depthwise_conv2d"):
+        return topi.nn.depthwise_conv2d_infer_layout
+    raise ValueError("Cannot find infer layout for task %s" % task_name)
 
-
-@autotvm.template
+@autotvm.register_customized_task("layout_transform")
 def layout_transform(*args):
     """Autotvm layout transform template."""
-    args = deserialize_args(args)
     cfg = get_config()
     cfg.add_flop(-1)
     data = args[0]
@@ -82,7 +79,7 @@ class BaseGraphTuner(object):
                        Each row of this file is an encoded record pair.
             Otherwise, it is an iterator.
 
-        target_ops : List of str
+        target_ops : List of relay.op.Op
             Target tuning operators.
 
         target : str or tvm.target
@@ -104,7 +101,7 @@ class BaseGraphTuner(object):
         self._layout_transform_perf_records = {}
         self._layout_transform_interlayer_cost = {}
         self._input_shapes = input_shapes
-        self._target_ops = [op.__name__ for op in target_ops]
+        self._target_ops = target_ops
 
         self._name = name
         self._max_sch_num = max_sch_num
@@ -179,7 +176,7 @@ class BaseGraphTuner(object):
                         dtype = first_tensor[-1]
                         new_shape = tuple([val.value for val in node_entry["types"][0].shape])
                         actual_workload = (input_workload[0],) + \
-                                          ((new_shape + (dtype,)),) + input_workload[2:]
+                                          (("TENSOR", new_shape, dtype),) + input_workload[2:]
                         node_entry["workloads"].append(actual_workload)
                         if "record_candidates" not in node_entry:
                             node_entry["record_candidates"] = input_node["record_candidates"]
@@ -212,7 +209,7 @@ class BaseGraphTuner(object):
                 node_entry["record_candidates"] = cache_dict[workload]
                 continue
             record_candidates = []
-            infer_layout_func = OP2LAYOUT[node_entry["topi_op"][0]]
+            infer_layout_func = get_infer_layout(node_entry["topi_op"][0])
             layout_tracking_dict = {}
             for record in cfg_dict[workload]:
                 in_measure, out_measure = record
@@ -264,7 +261,7 @@ class BaseGraphTuner(object):
 
                 if node_entry["op"] in self._target_ops:
                     o_idx = key
-                    o_infer_layout_func = OP2LAYOUT[node_entry["topi_op"][0]]
+                    o_infer_layout_func = get_infer_layout(node_entry["topi_op"][0])
                     o_wkl = node_entry["workloads"][0]
                     i_topi_op = in_node_entry["topi_op"][0]
                     i_wkl = in_node_entry["workloads"][0]
@@ -273,14 +270,14 @@ class BaseGraphTuner(object):
                         pivot += 1
                         i_topi_op = in_node_entry["topi_op"][pivot]
                         i_wkl = in_node_entry["workloads"][pivot]
-                    i_infer_layout_func = OP2LAYOUT[i_topi_op]
+                    i_infer_layout_func = get_infer_layout(i_topi_op)
                 else:
                     o_idx = target_input_idx
                     if i <= target_input_pos:
                         continue
-                    o_infer_layout_func = OP2LAYOUT[node_entry["topi_op"][0]]
+                    o_infer_layout_func = get_infer_layout(node_entry["topi_op"][0])
                     o_wkl = node_entry["workloads"][target_input_pos]
-                    i_infer_layout_func = OP2LAYOUT[node_entry["topi_op"][i]]
+                    i_infer_layout_func = get_infer_layout(node_entry["topi_op"][i])
                     i_wkl = node_entry["workloads"][i]
 
                 if (i_idx, o_idx) in pair_tracker:
@@ -314,9 +311,8 @@ class BaseGraphTuner(object):
                                 to_sch_idx, args):
         """Create dictionary containing matrix format of layout transformation
         between nodes."""
-        sargs = serialize_args(args)
         in_layout, out_layout = args[1], args[2]
-        ltf_workload = ('layout_transform',) + autotvm.task.args_to_workload(sargs)
+        ltf_workload = autotvm.task.args_to_workload(args, 'layout_transform')
         idx_pair_key = (from_node_idx, to_node_idx)
 
         if in_layout == out_layout:
@@ -449,9 +445,8 @@ class BaseGraphTuner(object):
         measure_option = autotvm.measure_option(builder=builder, runner=runner)
         for args in args_list:
             data, in_layout, out_layout = args
-            args = serialize_args(args)
-            ltf_workload = ('layout_transform',) + autotvm.task.args_to_workload(args)
-            if ltf_workload in  self._layout_transform_perf_records:
+            ltf_workload = autotvm.task.args_to_workload(args, 'layout_transform')
+            if ltf_workload in self._layout_transform_perf_records:
                 continue
 
             if infer_layout:
@@ -478,9 +473,8 @@ class BaseGraphTuner(object):
                 continue
 
             records = []
-            task = autotvm.task.create(layout_transform, args=args, target=self._target,
+            task = autotvm.task.create("layout_transform", args=args, target=self._target,
                                        target_host=target_host)
-            task.workload = ltf_workload
             tuner = autotvm.tuner.GridSearchTuner(task)
             tuner.tune(n_trial=1, measure_option=measure_option,
                        callbacks=[_log_to_list(records)])
index 7648322..17450ca 100644 (file)
@@ -18,8 +18,6 @@
 """API for graph traversing."""
 import threading
 
-import topi
-
 import tvm
 from tvm import relay, autotvm
 from tvm.relay import transform
@@ -30,13 +28,6 @@ from tvm.autotvm.task import TaskExtractEnv
 from .utils import has_multiple_inputs, is_boundary_node, is_skipped_node
 
 
-# Setup relay op base name -> topi compute functions
-# NOTE: To add more ops, change the following dictionary.
-OP2COMPUTE = {
-    "conv2d" : [topi.nn.conv2d, topi.nn.depthwise_conv2d_nchw],
-}
-
-
 def expr2graph(expr, target_ops, node_dict, node_list):
     """Convert relay expr to graph data structure
     and fetch workloads of target operators.
@@ -46,8 +37,8 @@ def expr2graph(expr, target_ops, node_dict, node_list):
     expr : tvm.relay.Expr.Function
         Input relay function expression.
 
-    target_ops: List of str
-        List of target relay base op name
+    target_ops: List of relay.op.Op
+        List of target relay ops
 
     node_dict : dictionary from tvm.relay.Expr to int
         Dictionary to record node index
@@ -58,14 +49,11 @@ def expr2graph(expr, target_ops, node_dict, node_list):
         {"op": str, "node": tvm.relay.expr, "inputs": [int], "types": [tvm.relay.Type],
          "name": str, "workloads": [tuple], "topi_op": [function]}
     """
+    # TODO(@kevinthesun, @icemelon9): Currently graph tuning pass relies on the fact
+    #   that # autotvm tasks == # ops. But this won't be true after having relay op
+    #   strategy. We need to find a solution to fix this.
     env = TaskExtractEnv.get(allow_duplicate=True)
-    topi_funcs = []
-    for op_name in target_ops:
-        if op_name not in OP2COMPUTE:
-            raise RuntimeError("Not supported relay op in graph tuner: %s"
-                               % op_name)
-        topi_funcs += OP2COMPUTE[op_name]
-    env.reset(topi_funcs)
+    env.reset(target_ops)
     # pylint: disable=not-context-manager
     with env:
         _expr2graph_impl(expr, target_ops, node_dict, node_list)
@@ -75,8 +63,7 @@ def expr2graph(expr, target_ops, node_dict, node_list):
                 task_name, args = env.task_collection[task_pos]
                 task = autotvm.task.create(task_name, args,
                                            target="llvm",
-                                           target_host=None,
-                                           template_key='direct')
+                                           target_host=None)
                 node_entry["workloads"] = [task.workload]
                 node_entry["topi_op"] = [task_name]
                 task_pos += 1
@@ -98,11 +85,11 @@ def _expr2graph_impl(expr, target_ops, node_dict, node_list):
             return
         node_index = len(node_list)
         node_entry = {"node": node, "inputs": [], "types": [],
-                      "op": "null", "name": None}
+                      "op": None, "name": None}
 
         if isinstance(node, Call):
-            op_name = node.op.name.split(".")[-1]
-            node_entry["op"] = op_name
+            op = node.op
+            node_entry["op"] = node.op
             for arg in node.args:
                 in_node_idx = node_dict[arg]
                 if isinstance(arg, (Tuple, TupleGetItem)):
@@ -118,12 +105,12 @@ def _expr2graph_impl(expr, target_ops, node_dict, node_list):
                     node_entry["types"].append(tupe_type)
             else:
                 raise RuntimeError("Unsupported output type %s in operator %s"
-                                   % (type(out_type), op_name))
+                                   % (type(out_type), op.name))
 
             # Utilize tracing target to fetch workload with topo-order.
             # Since we only need workload, dummy target can be used to
             # create task.
-            if op_name in target_ops:
+            if op in target_ops:
                 params = []
                 for i, input_idx in enumerate(node_entry["inputs"]):
                     input_node_entry = node_list[input_idx[0]]
@@ -133,7 +120,7 @@ def _expr2graph_impl(expr, target_ops, node_dict, node_list):
                                            "operators with input node of type "
                                            "relay.expr.Var/Constant/Call. Now "
                                            "find a target op %s with input type %s"
-                                           % (op_name, str(type(input_node_entry["node"]))))
+                                           % (op, str(type(input_node_entry["node"]))))
                     free_var = relay.Var("var_%d" % i, input_type)
                     params.append(free_var)
                 call = relay.Call(node.op, params, node.attrs)
@@ -155,11 +142,9 @@ def _expr2graph_impl(expr, target_ops, node_dict, node_list):
                 _expr2graph_impl(node, target_ops, node_dict, node_list)
             return
         elif isinstance(node, TupleGetItem):
-            node_entry["op"] = "TupleGetItem"
             in_node_idx = node_dict[node.tuple_value]
             node_entry["inputs"].append([in_node_idx, node.index, 0])
         elif isinstance(node, Tuple):
-            node_entry["op"] = "Tuple"
             for tuple_item in node:
                 in_node_idx = node_dict[tuple_item]
                 if isinstance(tuple_item, TupleGetItem):
index 137ccbe..2486d0c 100644 (file)
@@ -47,7 +47,7 @@ def has_multiple_inputs(node_list, node_idx, input_names):
         in_idx = in_idx[0]
         in_node = node_list[in_idx]
         # Exclude parameter nodes
-        if in_node["op"] != "null" or \
+        if in_node["op"] is not None or \
                 ("name" in in_node and in_node["name"] in input_names):
             num_inputs += 1
     return num_inputs > 1
@@ -72,9 +72,10 @@ def is_boundary_node(node_entry, input_names):
         whether node is a boundary node.
     """
     # Operators dependent on original layouts.
-    _LAYOUT_FIXED_OP = ["batch_flatten", "transpose", "reshape",
-                        "multibox_prior", "multibox_transform_loc", "where",
-                        "non_max_suppression", "strided_slice"]
+    _LAYOUT_FIXED_OP = [relay.op.get(name) for name in (
+        "nn.batch_flatten", "transpose", "reshape", "vision.multibox_prior",
+        "vision.multibox_transform_loc", "where", "vision.non_max_suppression",
+        "strided_slice")]
 
     out = node_entry["op"] in _LAYOUT_FIXED_OP or \
           ("name" in node_entry and node_entry["name"] in input_names)
@@ -95,9 +96,7 @@ def is_skipped_node(node_entry):
         whether node is skipped.
     """
     # Operators not counted in graph tuner.
-    _SKIPPED_OP = ["Tuple"]
-
-    return node_entry["op"] in _SKIPPED_OP
+    return isinstance(node_entry["node"], relay.Tuple)
 
 
 def bind_inputs(expr, input_shapes=None, input_dtypes="float32"):
index fbf4a08..416b2cd 100644 (file)
@@ -28,14 +28,16 @@ import time
 import os
 import itertools
 from collections import OrderedDict
+import numpy as np
 
 from .. import build, lower, target as _target
-
+from .. import __version__
 from . import task
 from .task import ConfigEntity, ApplyHistoryBest
 from .measure import MeasureInput, MeasureResult
 
-AUTOTVM_LOG_VERSION = 0.1
+AUTOTVM_LOG_VERSION = 0.2
+_old_version_warning = True
 logger = logging.getLogger('autotvm')
 
 try:  # convert unicode to str for python2
@@ -88,27 +90,30 @@ def encode(inp, result, protocol='json'):
 
     if protocol == 'json':
         json_dict = {
-            "i": (str(inp.target),
-                  inp.task.name, inp.task.args, inp.task.kwargs,
-                  inp.task.workload,
-                  inp.config.to_json_dict()),
+            "input": (str(inp.target),
+                      inp.task.name, inp.task.args, inp.task.kwargs),
+
+            "config": inp.config.to_json_dict(),
+
+            "result": (result.costs if result.error_no == 0 else (1e9,),
+                       result.error_no,
+                       result.all_cost,
+                       result.timestamp),
 
-            "r": (result.costs if result.error_no == 0 else (1e9,),
-                  result.error_no,
-                  result.all_cost,
-                  result.timestamp),
+            "version": AUTOTVM_LOG_VERSION,
 
-            "v": AUTOTVM_LOG_VERSION
+            "tvm_version": __version__
         }
         return json.dumps(json_dict)
     if protocol == 'pickle':
         row = (str(inp.target),
                str(base64.b64encode(pickle.dumps([inp.task.name,
                                                   inp.task.args,
-                                                  inp.task.kwargs,
-                                                  inp.task.workload])).decode()),
+                                                  inp.task.kwargs])).decode()),
                str(base64.b64encode(pickle.dumps(inp.config)).decode()),
-               str(base64.b64encode(pickle.dumps(tuple(result))).decode()))
+               str(base64.b64encode(pickle.dumps(tuple(result))).decode()),
+               str(AUTOTVM_LOG_VERSION),
+               str(__version__))
         return '\t'.join(row)
 
     raise RuntimeError("Invalid log protocol: " + protocol)
@@ -119,20 +124,29 @@ def decode(row, protocol='json'):
 
     Parameters
     ----------
-    row: str
+    row : str
         a row in the logger file
-    protocol: str
+
+    protocol : str
         log protocol, json or pickle
 
     Returns
     -------
-    input: autotvm.tuner.MeasureInput
-    result: autotvm.tuner.MeasureResult
+    ret : tuple(autotvm.tuner.MeasureInput, autotvm.tuner.MeasureResult), or None
+        The tuple of input and result, or None if input uses old version log format.
     """
     # pylint: disable=unused-variable
+    global _old_version_warning
+
     if protocol == 'json':
         row = json.loads(row)
-        tgt, task_name, task_args, task_kwargs, workload, config = row['i']
+        if 'v' in row and row['v'] == 0.1:
+            if _old_version_warning:
+                logger.warning("AutoTVM log version 0.1 is no longer supported.")
+                _old_version_warning = False
+            return None
+
+        tgt, task_name, task_args, task_kwargs = row["input"]
         tgt = _target.create(str(tgt))
 
         def clean_json_to_python(x):
@@ -148,22 +162,27 @@ def decode(row, protocol='json'):
             return x
 
         tsk = task.Task(clean_json_to_python(task_name), clean_json_to_python(task_args))
-        tsk.workload = clean_json_to_python(workload)
-        config = ConfigEntity.from_json_dict(config)
+        config = ConfigEntity.from_json_dict(row["config"])
         inp = MeasureInput(tgt, tsk, config)
-        result = MeasureResult(*[tuple(x) if isinstance(x, list) else x for x in row["r"]])
+        result = MeasureResult(*[tuple(x) if isinstance(x, list) else x for x in row["result"]])
+        config.cost = np.mean(result.costs)
 
         return inp, result
     if protocol == 'pickle':
         items = row.split("\t")
+        if len(items) == 4:
+            if _old_version_warning:
+                logger.warning("AutoTVM log version 0.1 is no longer supported.")
+                _old_version_warning = False
+            return None
         tgt = _target.create(items[0])
         task_tuple = pickle.loads(base64.b64decode(items[1].encode()))
         config = pickle.loads(base64.b64decode(items[2].encode()))
-        result = pickle.loads(base64.b64decode(items[3].encode()))
+        result = MeasureResult(*pickle.loads(base64.b64decode(items[3].encode())))
+        config.cost = np.mean(result.costs)
 
         tsk = task.Task(task_tuple[0], task_tuple[1])
-        tsk.workload = task_tuple[3]
-        return MeasureInput(tgt, tsk, config), MeasureResult(*result)
+        return MeasureInput(tgt, tsk, config), result
 
     raise RuntimeError("Invalid log protocol: " + protocol)
 
@@ -183,7 +202,10 @@ def load_from_file(filename):
     """
     for row in open(filename):
         if row and not row.startswith('#'):
-            inp, res = decode(row)
+            ret = decode(row)
+            if ret is None:
+                continue
+            inp, res = ret
             # Avoid loading the record with an empty config. The TOPI schedule with no entities
             # will result in an empty entity map (e.g., depthwise_conv2d_nchw on x86).
             # Using an empty config will cause problems when applying alter op like NCHW to NCHWc.
@@ -208,7 +230,7 @@ def split_workload(in_file, clean=True):
 
     logger.info("start converting...")
     pool = multiprocessing.Pool()
-    lines = pool.map(decode, lines)
+    lines = [rec for rec in pool.map(decode, lines) if rec is not None]
     logger.info("map done %.2f", time.time() - tic)
 
     wkl_dict = OrderedDict()
index f249f6b..29313d4 100644 (file)
@@ -22,12 +22,13 @@ This module defines the task data structure, as well as a collection(zoo)
 of typical tasks of interest.
 """
 
-from .task import Task, create, register, template, get_config, args_to_workload
+from .task import Task, create, get_config, args_to_workload, \
+    register_customized_task
 from .space import ConfigSpace, ConfigEntity
 from .code_hash import attach_code_hash, attach_code_hash_to_arg
-from .dispatcher import dispatcher, DispatchContext, ApplyConfig, ApplyHistoryBest, \
+from .dispatcher import DispatchContext, ApplyConfig, ApplyHistoryBest, \
     FallbackContext, clear_fallback_cache, ApplyGraphBest
 
 from .topi_integration import register_topi_compute, register_topi_schedule, \
-    TaskExtractEnv
+    TaskExtractEnv, get_workload
 from .relay_integration import extract_from_program, extract_from_multiple_program
index e7022fa..97ee538 100644 (file)
@@ -33,9 +33,6 @@ from __future__ import absolute_import as _abs
 import logging
 
 import numpy as np
-from decorator import decorate
-
-from tvm import target as _target
 
 from .space import FallbackConfigEntity
 
@@ -152,79 +149,6 @@ class DispatchContext(object):
         DispatchContext.current = self._old_ctx
 
 
-def dispatcher(fworkload):
-    """Wrap a workload dispatcher function.
-
-    Parameters
-    ----------
-    fworkload : function
-        The workload extraction function from arguments.
-
-    Returns
-    -------
-    fdispatcher : function
-        A wrapped dispatcher function, which will
-        dispatch based on DispatchContext and
-        the current workload.
-    """
-    dispatch_dict = {}
-    func_name = fworkload.__name__
-
-    def register(key, func=None, override=False):
-        """Register template function.
-
-        Parameters
-        ----------
-        key : str or List of str
-            The template key to identify the template
-            under this dispatcher.
-        func : function
-            The function to be registered.
-            The first argument of the function is always
-            cfg returned by DispatchContext,
-            the rest arguments are the same as the fworkload.
-        override : bool
-            Whether override existing registration.
-
-        Returns
-        -------
-        The register function if necessary.
-        """
-        if isinstance(key, str):
-            key = [key]
-
-        def _do_reg(myf):
-            for x in key:
-                if x in dispatch_dict and not override:
-                    raise ValueError(
-                        "Key %s is already registered for %s" % (x, func_name))
-                dispatch_dict[x] = myf
-            return myf
-
-        if func:
-            return _do_reg(func)
-        return _do_reg
-
-    def dispatch_func(func, *args, **kwargs):
-        """The wrapped dispatch function"""
-        tgt = _target.Target.current()
-        workload = func(*args, **kwargs)
-        cfg = DispatchContext.current.query(tgt, workload)
-        if cfg.is_fallback and not cfg.template_key:
-            # first try 'direct' template
-            if 'direct' in dispatch_dict:
-                return dispatch_dict['direct'](cfg, *args, **kwargs)
-            # otherwise pick a random template
-            for v in dispatch_dict.values():
-                return v(cfg, *args, **kwargs)
-        else:
-            return dispatch_dict[cfg.template_key](cfg, *args, **kwargs)
-
-    fdecorate = decorate(fworkload, dispatch_func)
-    fdecorate.register = register
-    return fdecorate
-
-
 class ApplyConfig(DispatchContext):
     """Apply a deterministic config entity for all queries.
 
@@ -334,7 +258,8 @@ class ApplyHistoryBest(DispatchContext):
         if key in self._best_user_defined:
             return self._best_user_defined[key]
         if key in self.best_by_model:
-            return self.best_by_model[key][0].config
+            inp, _ = self.best_by_model[key]
+            return inp.config
 
         # then try matching by target key
         for k in target.keys:
@@ -342,13 +267,16 @@ class ApplyHistoryBest(DispatchContext):
             if key in self._best_user_defined:
                 return self._best_user_defined[key]
             if key in self.best_by_targetkey:
-                return self.best_by_targetkey[key][0].config
+                inp, _ = self.best_by_targetkey[key]
+                return inp.config
 
         return None
 
     def update(self, target, workload, cfg):
         model = target.model
         key = (model, workload)
+        # assume user provided config is the best
+        cfg.cost = 0
         self._best_user_defined[key] = cfg
 
         for k in target.keys:
@@ -481,8 +409,12 @@ class ApplyGraphBest(DispatchContext):
         """
         if self._counter < len(self._records):
             cfg = self._records[self._counter][0].config
+            wkl = self._records[self._counter][0].task.workload
+            if workload is not None:
+                assert wkl == workload
             self._counter += 1
-            self.update(target, workload, cfg)
+            self.update(target, wkl, cfg)
+            cfg.workload = wkl
             return cfg
         key = (str(target), workload)
         if key not in self._global_cfg_dict:
index b39c8d4..cd8d32f 100644 (file)
@@ -21,10 +21,9 @@ Decorator and utilities for the integration with TOPI and Relay
 
 """
 import threading
-import warnings
 import logging
 
-
+import tvm
 from .task import create
 from .topi_integration import TaskExtractEnv
 
@@ -55,8 +54,7 @@ def _lower(mod,
     compiler.lower(mod, target=target)
 
 
-def extract_from_program(mod, params, ops, target, target_host=None,
-                         template_keys=None):
+def extract_from_program(mod, params, target, target_host=None, ops=None):
     """ Extract tuning tasks from a relay program.
 
     This function is the single program version of extract_from_multiple_program.
@@ -67,27 +65,22 @@ def extract_from_program(mod, params, ops, target, target_host=None,
         The module or function to tune
     params: dict of str to numpy array
         The associated parameters of the program
-    ops: List of relay op
-        List of relay ops to be tuned
     target: tvm.target.Target
         The compilation target
     target_host: tvm.target.Target
         The host compilation target
-    template_keys: dict of topi op to str
-        The tuning template keys map for schedules, default to None.
-        Example: {topi.nn.conv2d: 'direct'}
+    ops: List[relay.op.Op] or None
+        List of relay ops to be tuned. If not specified, all tunable ops will be extracted.
 
     Returns
     -------
     task: Array of autotvm.task.Task
         collected tasks
     """
-    return extract_from_multiple_program([mod], [params], ops, target, target_host,
-                                         template_keys)
+    return extract_from_multiple_program([mod], [params], target, target_host, ops)
 
 
-def extract_from_multiple_program(mods, params, ops, target, target_host=None,
-                                  template_keys=None):
+def extract_from_multiple_program(mods, params, target, target_host=None, ops=None):
     """ Extract tuning tasks from multiple relay programs.
 
     This function collects tuning tasks by building a list of programs
@@ -99,15 +92,12 @@ def extract_from_multiple_program(mods, params, ops, target, target_host=None,
         The list of modules or functions to tune
     params: List of dict of str to numpy array
         The associated parameters of the programs
-    ops: List of relay op
-        List of relay ops to be tuned
     target: tvm.target.Target
         The compilation target
     target_host: tvm.target.Target
         The host compilation target
-    template_keys: dict of topi op to str
-        The tuning template keys map for schedules, default to None.
-        Example: {topi.nn.conv2d: 'direct'}
+    ops: List[relay.op.Op] or None
+        List of relay ops to be tuned.  If not specified, all tunable ops will be extracted.
 
     Returns
     -------
@@ -115,36 +105,13 @@ def extract_from_multiple_program(mods, params, ops, target, target_host=None,
         collected tasks
     """
     # pylint: disable=import-outside-toplevel
-    import tvm.relay.op
     from tvm import relay
     import topi
 
     env = TaskExtractEnv.get()
 
-    # NOTE: To add more ops, you only need to change the following lists
-    # relay op -> topi compute
-    OP2TOPI = {
-        tvm.relay.op.nn.conv2d: [topi.nn.conv2d, topi.nn.depthwise_conv2d_nchw,
-                                 topi.nn.group_conv2d_nchw,
-                                 topi.nn.conv2d_NCHWc,
-                                 topi.nn.conv2d_NCHWc_int8],
-        tvm.relay.op.nn.conv2d_transpose: [topi.nn.conv2d_transpose_nchw],
-        tvm.relay.op.nn.dense: [topi.nn.dense],
-        tvm.relay.op.nn.batch_matmul: [topi.nn.batch_matmul],
-        tvm.relay.op.nn.deformable_conv2d: [topi.nn.deformable_conv2d_nchw],
-        tvm.relay.op.nn.conv1d_transpose: [topi.nn.conv1d_transpose_ncw],
-        tvm.relay.op.nn.conv3d: [topi.nn.conv3d],
-    }
-
-    topi_funcs = []
-    for op_name in ops:
-        if op_name in OP2TOPI:
-            topi_funcs.extend(OP2TOPI[op_name])
-        else:
-            warnings.warn("Op %s is not tunable, ignored" % op_name)
-
     # run compiler to collect all TOPI calls during compilation
-    env.reset(topi_funcs)
+    env.reset(ops)
     with env:
         # disable logger temporarily
         old_state = logger.disabled
@@ -164,24 +131,12 @@ def extract_from_multiple_program(mods, params, ops, target, target_host=None,
 
         logger.disabled = old_state
 
-    # convert *topi op to template key* map to *task name to template key* map
-    task_name_to_keys = {}
-    if template_keys is not None:
-        for op in template_keys.keys():
-            if op in env.topi_to_task:
-                task_name_to_keys[env.topi_to_task[op]] = template_keys[op]
-            else:
-                logger.warning("Invalid template key, fallback to direct")
-                task_name_to_keys[env.topi_to_task[op]] = 'direct'
-
     # create tasks for target
     tasks = []
     for task_name, args in env.get_tasks():
         try:
-            key = task_name_to_keys[task_name] if task_name in task_name_to_keys else 'direct'
             tsk = create(task_name, args,
-                         target=target, target_host=target_host,
-                         template_key=key)
+                         target=target, target_host=target_host)
             tasks.append(tsk)
         except topi.InvalidShapeError:
             logger.warning("Invalid shape during AutoTVM task creation")
index fbdd34e..47c2270 100644 (file)
@@ -613,9 +613,9 @@ class ConfigSpace(object):
         self._entity_map = OrderedDict()  # name -> entity
         self._constraints = []
         self.errors = []
-        self.template_key = None
         self.code_hash = None
         self.flop = 0
+        self.cost = None
         self.is_fallback = False
 
     @staticmethod
@@ -796,7 +796,7 @@ class ConfigSpace(object):
         for name, space in self.space_map.items():
             entities[name] = space[t % len(space)]
             t //= len(space)
-        ret = ConfigEntity(index, self.code_hash, self.template_key, entities, self._constraints)
+        ret = ConfigEntity(index, self.code_hash, entities, self._constraints)
         return ret
 
     def __iter__(self):
@@ -836,17 +836,14 @@ class ConfigEntity(ConfigSpace):
         index of this config in space
     code_hash: str
         hash of schedule code
-    template_key : str
-        The specific template key
     entity_map: dict
         map name to transform entity
     constraints : list
         List of constraints
     """
-    def __init__(self, index, code_hash, template_key, entity_map, constraints):
+    def __init__(self, index, code_hash, entity_map, constraints):
         super(ConfigEntity, self).__init__()
         self.index = index
-        self.template_key = template_key
         self._collect = False
         self._entity_map = entity_map
         self._space_map = None
@@ -896,9 +893,8 @@ class ConfigEntity(ConfigSpace):
             a json serializable dictionary
         """
         ret = {}
-        ret['i'] = int(self.index)
-        ret['t'] = self.template_key
-        ret['c'] = self.code_hash
+        ret['index'] = int(self.index)
+        ret['code_hash'] = self.code_hash
         entity_map = []
         for k, v in self._entity_map.items():
             if isinstance(v, SplitEntity):
@@ -911,7 +907,7 @@ class ConfigEntity(ConfigSpace):
                 entity_map.append((k, 'ot', v.val))
             else:
                 raise RuntimeError("Invalid entity instance: " + v)
-        ret['e'] = entity_map
+        ret['entity'] = entity_map
         return ret
 
     @staticmethod
@@ -930,13 +926,12 @@ class ConfigEntity(ConfigSpace):
             The corresponding config object
 
         """
-        index = json_dict["i"]
-        code_hash = json_dict["c"]
-        template_key = json_dict["t"]
+        index = json_dict["index"]
+        code_hash = json_dict["code_hash"]
         constraints = []
         entity_map = OrderedDict()
 
-        for item in json_dict["e"]:
+        for item in json_dict["entity"]:
             key, knob_type, knob_args = item
             if knob_type == 'sp':
                 entity = SplitEntity(knob_args)
@@ -950,11 +945,10 @@ class ConfigEntity(ConfigSpace):
                 raise RuntimeError("Invalid config knob type: " + knob_type)
             entity_map[str(key)] = entity
 
-        return ConfigEntity(index, code_hash, template_key, entity_map, constraints)
+        return ConfigEntity(index, code_hash, entity_map, constraints)
 
     def __repr__(self):
-        return "%s,%s,%s,%d" % (str(self._entity_map)[12:-1], self.template_key,
-                                self.code_hash, self.index)
+        return "%s,%s,%d" % (str(self._entity_map)[12:-1], self.code_hash, self.index)
 
 
 class FallbackConfigEntity(ConfigSpace):
@@ -1068,4 +1062,4 @@ class FallbackConfigEntity(ConfigSpace):
         self._entity_map[name] = entity
 
     def __repr__(self):
-        return "%s,%s,%s" % (str(self._entity_map)[12:-1], self.template_key, self.code_hash)
+        return "%s,%s" % (str(self._entity_map)[12:-1], self.code_hash)
index 9ff8b24..ca1ae0e 100644 (file)
@@ -14,7 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=unused-variable
+# pylint: disable=unused-variable,not-callable
 """Definition of task function.
 
 Task can be constructed from tuple of func, args, and kwargs.
@@ -24,10 +24,10 @@ registers the standard task.
 
 import numpy as np
 
-from ... import tensor, expr, container, target as _target
+from ... import tensor, expr, container, placeholder, target as _target
 
-from ..util import get_const_int, get_const_tuple, get_func_name
-from .dispatcher import DispatchContext, ApplyConfig, dispatcher
+from ..util import get_const_int, get_const_tuple
+from .dispatcher import DispatchContext, ApplyConfig
 from .space import ConfigSpace
 
 def _raise_error(*args, **kwargs):  # pylint: disable=unused-argument
@@ -35,6 +35,70 @@ def _raise_error(*args, **kwargs):  # pylint: disable=unused-argument
                        "of this task is registered in another python file "
                        "which is not imported in this run")
 
+
+def serialize_args(args):
+    """serialize arguments of a topi function to a hashable tuple.
+
+    Parameters
+    ----------
+    args: list of hashable or Tensor
+    """
+    def _encode(x):
+        if isinstance(x, tensor.Tensor):
+            return ('TENSOR', get_const_tuple(x.shape), x.dtype)
+        if isinstance(x, (tuple, list, container.Array)):
+            return tuple([_encode(a) for a in x])
+        if isinstance(x, (str, int, float, np.int, np.float, expr.Var)):
+            return x
+        if isinstance(x, (expr.StringImm, expr.IntImm, expr.FloatImm)):
+            return x.value
+        if x is None:
+            return None
+        raise RuntimeError('Do not support type "%s" in argument. Consider to use'
+                           'primitive types or tvm.tir.Var only' % type(x))
+    ret = []
+    for t in args:
+        ret.append(_encode(t))
+    return tuple(ret)
+
+
+def deserialize_args(args):
+    """The inverse function of :code:`serialize_args`.
+
+    Parameters
+    ----------
+    args: list of hashable or Tensor
+    """
+    ret = []
+    for t in args:
+        if isinstance(t, tuple) and t[0] == 'TENSOR':
+            ret.append(placeholder(shape=t[1], dtype=t[2]))
+        else:
+            ret.append(t)
+    return ret
+
+
+def args_to_workload(args, task_name=None):
+    """Convert argument list to hashable workload tuple.
+    This function will convert list to tuple, tvm node to python value and
+    flatten tvm.tensor.Tensor to a tuple
+
+    Parameters
+    ----------
+    task_name : str
+        The AutoTVM task name
+
+    args : list of args
+        The arguments to the function
+
+    Returns
+    -------
+    ret: hashable
+        The hashable value
+    """
+    return (task_name,) + serialize_args(args) if task_name is not None else serialize_args(args)
+
+
 class Task(object):
     """A Tunable Task
 
@@ -55,11 +119,14 @@ class Task(object):
         self.func = TASK_TABLE.get(name, _raise_error)
 
         # auxiliary info, available after `init_space` is called
-        self.workload = None
         self.flop = None
         self.target = None
         self.target_host = None
 
+    @property
+    def workload(self):
+        return (self.name,) + serialize_args(self.args)
+
     def instantiate(self, config):
         """Instantiate this task function (template) with a config.
         Returns corresponding schedule.
@@ -94,7 +161,6 @@ class Task(object):
             "args": self.args,
             "kwargs": self.kwargs,
             "config_space": self.config_space,
-            "workload": self.workload,
             "flop": self.flop,
             "target": self.target,
             "target_host": self.target_host
@@ -106,7 +172,6 @@ class Task(object):
         self.kwargs = state["kwargs"]
         self.config_space = state["config_space"]
         self.func = TASK_TABLE.get(state["name"], _raise_error)
-        self.workload = state["workload"]
         self.flop = state["flop"]
         self.target = state["target"]
         self.target_host = state["target_host"]
@@ -116,135 +181,119 @@ class Task(object):
             self.name, self.args, self.kwargs, self.workload
         )
 
-TASK_TABLE = {
-}
+TASK_TABLE = {}
+
+class TopiTemplate(object):
+    """Topi template that holds the topi compute and schedule function"""
+    def __init__(self):
+        self.compute = None
+        self.schedule = None
+        self.customized_func = None
+
+    def __call__(self, *args, **kwargs):
+        args = deserialize_args(args)
+        if self.customized_func is None:
+            return self._default_func(*args, **kwargs)
+        assert callable(self.customized_func)
+        return self.customized_func(*args, **kwargs)
+
+    def _default_func(self, *args, **kwargs):
+        assert callable(self.compute) and callable(self.schedule)
+        out = self.compute(*args, **kwargs)
+        arg_bufs = [out] + self.get_inputs(out)
+        s = self.schedule([out])
+        return s, arg_bufs
+
+    def get_inputs(self, out):
+        inputs = []
+        queue = [out]
+        while queue:
+            t = queue.pop(0)
+            if isinstance(t.op, tensor.PlaceholderOp):
+                inputs.append(t)
+            else:
+                queue.extend(t.op.input_tensors)
+        return inputs
 
-def register(name, func=None, override=False):
-    """Register a task function.
+def register_task_compute(name, func=None):
+    """Register compute function to autotvm task
 
     Parameters
     ----------
-    name : str
-        The name to identify the task.
-    func : callable
-        The function to be registered.
-    override : bool
-        Whether override existing registration.
+    name: str
+        The task name
+
+    func: None or callable
+        If it is None, return a decorator.
+        If is callable, decorate this function.
 
     Returns
     -------
-    func: callable
-        The registered function
+    decorator: callable
+        A decorator
     """
-    def _do_reg(myf):
-        if name in TASK_TABLE and not override:
-            raise ValueError(
-                "Key %s is already registered" % name)
-        TASK_TABLE[name] = myf
-        return myf
+    def _do_reg(f):
+        if name not in TASK_TABLE:
+            TASK_TABLE[name] = TopiTemplate()
+        tmpl = TASK_TABLE[name]
+        if tmpl.compute is not None:
+            raise ValueError("Compute is already registered in autoTVM task %s" % name)
+        tmpl.compute = f
+        return f
     if func:
         return _do_reg(func)
     return _do_reg
 
-def create(func_name, args, target, target_host=None, template_key=None):
-    """Create a tuning task and initialize its search space
+def register_task_schedule(name, func=None):
+    """Register schedule function to autotvm task
 
     Parameters
     ----------
-    func_name : str or callable
-        The task function
-    args : List
-        Positional arguments
-    target : Target
-        The compilation target
-    target_host: Target, optional
-        The compilation target for host side
+    name: str
+        The task name
+
+    func: None or callable
+        If it is None, return a decorator.
+        If is callable, decorate this function.
 
     Returns
     -------
-    tsk: Task
-        a task object
+    decorator: callable
+        A decorator
     """
-    if callable(func_name):
-        # register this function if it is not registered before
-        func = func_name
-        func_name = func.func_name if hasattr(func, 'func_name') else func.__name__
-        if func_name in TASK_TABLE:
-            assert func == TASK_TABLE[func_name], "Find name conflict in task registration. " \
-                                                  "Consider to choose another name for this task"
-        else:
-            register(func_name, func=func)
-
-    func = TASK_TABLE[func_name]
-    ret = Task(func_name, args)
-
-    if isinstance(target, str):
-        target = _target.create(target)
-
-    # init config space
-    ret.config_space = ConfigSpace()
-    ret.config_space.template_key = template_key or ""
-
-    ctx = ApplyConfig(ret.config_space)
-    with ctx:
-        with target:
-            sch, _ = func(*args)
-            ret.config_space.code_hash = getattr(sch, 'code_hash', None)
+    def _do_reg(f):
+        if name not in TASK_TABLE:
+            TASK_TABLE[name] = TopiTemplate()
+        tmpl = TASK_TABLE[name]
+        if tmpl.schedule is not None:
+            raise ValueError("Schedule is already registered in autoTVM task %s" % name)
+        tmpl.schedule = f
+        return f
+    if func:
+        return _do_reg(func)
+    return _do_reg
 
-    ret.workload = ctx.workload
-    ret.flop = ret.config_space.flop or compute_flop(sch)
-    ret.target = target
-    ret.target_host = target_host
+def register_customized_task(name, func=None):
+    """Register a customized function to AutoTVM task.
 
-    return ret
-
-def args_to_workload(x, topi_compute_func=None):
-    """Convert argument list to hashable workload tuple.
-    This function will convert list to tuple, tvm node to python value and
-    flatten tvm.tensor.Tensor to a tuple
+    In most cases, you can just use register_topi_compute and register_topi_schedule
+    with the same task name to define an AutoTVM task. However, you can also
+    create a customized AutoTVM task that defines a tunable template or performs
+    extra layout transform before invoking compute/schedule function.
 
     Parameters
     ----------
-    x: primitive hashable types or tensor.Tensor
-        The original value
-    topi_compute_func: topi compute function
-        The function name will be added as first element of the workload tuple
-
-    Returns
-    -------
-    ret: hashable
-        The hashable value
-    """
-    if isinstance(x, tensor.Tensor):
-        workload = get_const_tuple(x.shape) + (x.dtype, )
-    elif isinstance(x, (tuple, list, container.Array)):
-        workload = tuple([args_to_workload(a) for a in x])
-    elif isinstance(x, (str, int, float, np.int, np.float, expr.Var)):
-        workload = x
-    elif isinstance(x, (expr.StringImm, expr.IntImm, expr.FloatImm)):
-        workload = x.value
-    elif x is None:
-        workload = 0
-    else:
-        raise RuntimeError('Do not support type "%s" in argument. Consider to use'
-                           'primitive types or tvm.tir.Var only' % type(x))
-    return (get_func_name(topi_compute_func), ) + workload  if topi_compute_func else workload
-
-def template(func):
-    """
-    Decorate a function as a tunable schedule template
+    name: str
+        The task name
 
-    Parameters
-    ----------
-    func: callable
-        A callable template function.
-        Its argument should be hashable values.
-        Its return value should be a Tuple(Schedule, Array of Tensor)
+    func: None or callable
+        If it is None, return a decorator.
+        If is callable, decorate this function.
 
     Returns
     -------
-    func: callable
-        The decorated function
+    decorator: callable
+        A decorator
 
     Examples
     --------
@@ -252,7 +301,7 @@ def template(func):
 
     .. code-block:: python
 
-        @autotvm.template
+        @autotvm.register_customized_task("matmul")
         def matmul(N, L, M, dtype):
             A = tvm.placeholder((N, L), name='A', dtype=dtype)
             B = tvm.placeholder((L, M), name='B', dtype=dtype)
@@ -279,24 +328,57 @@ def template(func):
 
             return s, [A, B, C]
     """
-    # pylint: disable=unused-variable
+    def _do_reg(f):
+        if name not in TASK_TABLE:
+            TASK_TABLE[name] = TopiTemplate()
+        tmpl = TASK_TABLE[name]
+        if tmpl.customized_func is not None:
+            raise ValueError("Customized func is already registered in autoTVM task %s" % name)
+        tmpl.customized_func = f
+        return f
+    if func:
+        return _do_reg(func)
+    return _do_reg
+
+def create(task_name, args, target, target_host=None):
+    """Create a tuning task and initialize its search space
+
+    Parameters
+    ----------
+    task_name : str
+        The AutoTVM task name
+    args : List
+        Positional arguments
+    target : Target
+        The compilation target
+    target_host: Target, optional
+        The compilation target for host side
+
+    Returns
+    -------
+    tsk: Task
+        a task object
+    """
+    args = serialize_args(args)
+    ret = Task(task_name, args)
 
-    fname = get_func_name(func)
+    if isinstance(target, str):
+        target = _target.create(target)
 
-    @register(fname)
-    @dispatcher
-    def config_dispatcher(*args, **kwargs):
-        assert not kwargs, "Do not support kwargs in template function call"
-        return (fname, ) + args_to_workload(args)
+    # init config space
+    ret.config_space = ConfigSpace()
+
+    ctx = ApplyConfig(ret.config_space)
+    with ctx:
+        with target:
+            sch, _ = ret.func(*args)
+            ret.config_space.code_hash = getattr(sch, 'code_hash', None)
 
-    @config_dispatcher.register("")
-    def template_call(cfg, *args, **kwargs):
-        assert not kwargs, "Do not support kwargs in template function call"
-        with ApplyConfig(cfg):
-            return func(*args, **kwargs)
+    ret.flop = ret.config_space.flop or compute_flop(sch)
+    ret.target = target
+    ret.target_host = target_host
 
-    config_dispatcher.func_name = fname
-    return config_dispatcher
+    return ret
 
 def get_config():
     """Get current config object
@@ -306,7 +388,8 @@ def get_config():
     cfg: ConfigSpace or ConfigEntity
         The current config
     """
-    return DispatchContext.current.query(None, None)
+    tgt = _target.Target.current(allow_none=True)
+    return DispatchContext.current.query(tgt, None)
 
 class FlopCalculationError(RuntimeError):
     """Error happens when estimating FLOP for a compute op"""
index 3d3a1d3..45385fb 100644 (file)
@@ -27,47 +27,11 @@ tuple.
 See tvm/topi/python/topi/arm_cpu/depthwise_conv2d.py for example usage.
 """
 import tvm.te._ffi_api
+from tvm import target as _target
 
-from ... import tensor, placeholder
-
-from .task import args_to_workload, dispatcher, register
-from ..util import get_const_tuple
-
-# A table that records all registered dispatcher for all targets
-_REGISTERED_DISPATCHER = {
-}
-
-
-def serialize_args(args):
-    """serialize arguments of a topi function to a hashable tuple.
-
-    Parameters
-    ----------
-    args: list of hashable or Tensor
-    """
-    ret = []
-    for t in args:
-        if isinstance(t, tensor.Tensor):
-            ret.append(('TENSOR', get_const_tuple(t.shape), t.dtype))
-        else:
-            ret.append(t)
-    return tuple(ret)
-
-
-def deserialize_args(args):
-    """The inverse function of :code:`serialize_args`.
-
-    Parameters
-    ----------
-    args: list of hashable or Tensor
-    """
-    ret = []
-    for t in args:
-        if isinstance(t, tuple) and t[0] == 'TENSOR':
-            ret.append(placeholder(shape=t[1], dtype=t[2]))
-        else:
-            ret.append(t)
-    return ret
+from ... import tensor
+from .task import args_to_workload, DispatchContext, \
+    register_task_compute, register_task_schedule, serialize_args
 
 
 # Task extractor for relay program
@@ -77,250 +41,46 @@ class TaskExtractEnv:
     registered = None
 
     def __init__(self, allow_duplicate=False):
-        # pylint: disable=import-outside-toplevel
-        import topi
-
-        # topi compute -> autotvm task name
-        self.topi_to_task = {
-            topi.nn.conv2d: "topi_nn_conv2d",
-            topi.nn.depthwise_conv2d_nchw: "topi_nn_depthwise_conv2d_nchw",
-            topi.nn.group_conv2d_nchw: "topi_nn_group_conv2d_nchw",
-            topi.nn.conv2d_transpose_nchw: "topi_nn_conv2d_transpose_nchw",
-            topi.nn.conv2d_NCHWc: "topi_x86_conv2d_NCHWc",
-            topi.nn.conv2d_NCHWc_int8: "topi_x86_conv2d_NCHWc_int8",
-            topi.nn.dense: "topi_nn_dense",
-            topi.nn.batch_matmul: "topi_nn_batch_matmul",
-            topi.nn.bitserial_conv2d_nchw: "topi_nn_bitserial_conv2d_nchw",
-            topi.nn.bitserial_conv2d_nhwc: "topi_nn_bitserial_conv2d_nhwc",
-            topi.nn.bitserial_dense: "topi_nn_bitserial_dense",
-            topi.nn.deformable_conv2d_nchw: "topi_nn_deformable_conv2d_nchw",
-            topi.nn.conv1d_transpose_ncw: "topi_nn_conv1d_transpose_ncw",
-            topi.nn.conv3d: "topi_nn_conv3d",
-        }
-
-        self.topi_to_schedule = {
-            topi.nn.conv2d: [topi.generic.schedule_conv2d_nchw,
-                             topi.generic.schedule_conv2d_nhwc],
-            topi.nn.depthwise_conv2d_nchw: [topi.generic.schedule_depthwise_conv2d_nchw,
-                                            topi.generic.schedule_depthwise_conv2d_nhwc],
-            topi.nn.group_conv2d_nchw: [topi.generic.schedule_group_conv2d_nchw],
-            topi.nn.conv2d_transpose_nchw: [topi.generic.schedule_conv2d_transpose_nchw],
-            topi.nn.conv2d_NCHWc: [topi.generic.schedule_conv2d_NCHWc],
-            topi.nn.conv2d_NCHWc_int8: [topi.generic.schedule_conv2d_NCHWc_int8],
-            topi.nn.dense: [topi.generic.schedule_dense],
-            topi.nn.batch_matmul: [topi.generic.schedule_batch_matmul],
-            topi.nn.bitserial_conv2d_nchw: [topi.generic.schedule_bitserial_conv2d_nchw],
-            topi.nn.bitserial_conv2d_nhwc: [topi.generic.schedule_bitserial_conv2d_nhwc],
-            topi.nn.bitserial_dense: [topi.generic.schedule_bitserial_dense],
-            topi.nn.deformable_conv2d_nchw: [topi.generic.schedule_deformable_conv2d_nchw],
-            topi.nn.conv1d_transpose_ncw: [topi.generic.schedule_conv1d_transpose_ncw],
-            topi.nn.conv3d: [topi.generic.schedule_conv3d_ndhwc],
-        }
-
-        # function reflection for tracing
-        self.func_to_reflection = {
-            topi.nn.conv2d:                 lambda x: setattr(topi.nn, 'conv2d', x),
-            topi.nn.conv2d_NCHWc:           lambda x: setattr(topi.nn, 'conv2d_NCHWc', x),
-            topi.nn.conv2d_NCHWc_int8:      lambda x: setattr(topi.nn, 'conv2d_NCHWc_int8', x),
-            topi.nn.depthwise_conv2d_nchw:  lambda x: setattr(topi.nn, 'depthwise_conv2d_nchw', x),
-            topi.nn.group_conv2d_nchw:      lambda x: setattr(topi.nn, 'group_conv2d_nchw', x),
-            topi.nn.conv2d_transpose_nchw:  lambda x: setattr(topi.nn, 'conv2d_transpose_nchw', x),
-            topi.nn.dense:                  lambda x: setattr(topi.nn, 'dense', x),
-            topi.nn.batch_matmul:           lambda x: setattr(topi.nn, 'batch_matmul', x),
-            topi.nn.bitserial_conv2d_nchw:  lambda x: setattr(topi.nn, 'bitserial_conv2d_nchw', x),
-            topi.nn.bitserial_conv2d_nhwc:  lambda x: setattr(topi.nn, 'bitserial_conv2d_nhwc', x),
-            topi.nn.bitserial_dense:        lambda x: setattr(topi.nn, 'bitserial_dense', x),
-            topi.nn.deformable_conv2d_nchw: lambda x: setattr(topi.nn, 'deformable_conv2d_nchw', x),
-            topi.nn.conv1d_transpose_ncw:   lambda x: setattr(topi.nn, 'conv1d_transpose_ncw', x),
-            topi.nn.conv3d:                 lambda x: setattr(topi.nn, 'conv3d', x),
-        }
-
         self.allow_duplicate = allow_duplicate
-        self._register_topi_task()
         self.task_collection = []
-        self.wanted_topi_funcs = list(self.topi_to_task.keys())
+        self.wanted_relay_ops = None
         self.modified_funcs = []
+        self.tracing = False
 
     def __enter__(self):
         self.task_collection = []
-        self.modified_funcs = []
-
-        for topi_compute in self.wanted_topi_funcs:
-            def _local_scope(compute_func):
-                """start a scope to hold the local function in for loop"""
-
-                def _tracing_wrapper(*args, **kwargs):
-                    assert not kwargs, "Do not support extracting tuning tasks when " \
-                                       "kwargs is used in TOPI function call. " \
-                                       "Please modify it to use only positional args."
-                    key = (self.topi_to_task[compute_func], serialize_args(args))
-                    if self.allow_duplicate or key not in self.task_collection:
-                        self.task_collection.append(key)
-
-                    return compute_func(*args, **kwargs)
-
-                self.func_to_reflection[compute_func](_tracing_wrapper)
-                self.modified_funcs.append(compute_func)
-
-            _local_scope(topi_compute)
+        self.tracing = True
 
         return self
 
     def __exit__(self, exc_type, exc_val, exc_tb):
-        # revert modification
-        for func in self.modified_funcs:
-            self.func_to_reflection[func](func)
-
-    def _register_topi_task(self):
-        """register tuning wrapper for topi function"""
-        # pylint: disable=import-outside-toplevel
-        import topi
-
-        # Avoid double registration for certain targets
-        if TaskExtractEnv.registered:
-            return
-        TaskExtractEnv.registered = True
-
-        # Tuning wrapper for topi functions
-        @register("topi_nn_conv2d")
-        def _topi_nn_conv2d(*args, **kwargs):
-            assert not kwargs, "Do not support kwargs in template function call"
-            args = deserialize_args(args)
-            A, W = args[:2]
-            layout = args[-2]
-            C = topi.nn.conv2d(*args, **kwargs)
-            if layout == 'NCHW':
-                s = topi.generic.schedule_conv2d_nchw([C])
-            elif layout == 'HWCN':
-                s = topi.generic.schedule_conv2d_hwcn([C])
-            elif layout == 'NHWC':
-                s = topi.generic.schedule_conv2d_nhwc([C])
-            else:
-                raise ValueError("Unsupported layout {}".format(layout))
-            return s, [A, W, C]
-
-        @register("topi_nn_depthwise_conv2d_nchw")
-        def _topi_nn_depthwise_conv2d_nchw(*args, **kwargs):
-            assert not kwargs, "Do not support kwargs in template function call"
-            args = deserialize_args(args)
-            A, W = args[:2]
-            C = topi.nn.depthwise_conv2d_nchw(*args, **kwargs)
-            s = topi.generic.schedule_depthwise_conv2d_nchw([C])
-            return s, [A, W, C]
-
-        @register("topi_nn_group_conv2d_nchw")
-        def _topi_nn_group_conv2d_nchw(*args, **kwargs):
-            assert not kwargs, "Do not support kwargs in template function call"
-            args = deserialize_args(args)
-            A, W = args[:2]
-            C = topi.nn.group_conv2d_nchw(*args, **kwargs)
-            s = topi.generic.schedule_group_conv2d_nchw([C])
-            return s, [A, W, C]
-
-        @register("topi_nn_conv2d_transpose_nchw")
-        def _topi_nn_conv2d_transpose_nchw(*args, **kwargs):
-            assert not kwargs, "Do not support kwargs in template function call"
-            args = deserialize_args(args)
-            A, W = args[:2]
-            C = topi.nn.conv2d_transpose_nchw(*args, **kwargs)
-            s = topi.generic.schedule_conv2d_transpose_nchw([C])
-            return s, [A, W, C]
-
-        @register("topi_nn_conv1d_transpose_ncw")
-        def _topi_nn_conv1d_transpose_ncw(*args, **kwargs):
-            assert not kwargs, "Do not support kwargs in template function call"
-            args = deserialize_args(args)
-            A, W = args[:2]
-            C = topi.nn.conv1d_transpose_ncw(*args, **kwargs)
-            s = topi.generic.schedule_conv1d_transpose_ncw([C])
-            return s, [A, W, C]
-
-        @register("topi_nn_conv3d")
-        def _topi_nn_conv3d(*args, **kwargs):
-            assert not kwargs, "Do not support kwargs in template function call"
-            args = deserialize_args(args)
-            A, W = args[:2]
-            C = topi.nn.conv3d(*args, **kwargs)
-            s = topi.generic.schedule_conv3d_ndhwc([C])
-            return s, [A, W, C]
-
-        @register("topi_nn_dense")
-        def _topi_nn_dense(*args, **kwargs):
-            assert not kwargs, "Do not support kwargs in template function call"
-            args = deserialize_args(args)
-            if len(args) > 2:
-                data, weight, bias = args[:3]
-            else:
-                data, weight = args
-                bias = None
-            C = topi.nn.dense(*args, **kwargs)
-            s = topi.generic.schedule_dense([C])
-            if bias is not None:
-                return s, [data, weight, bias, C]
-            return s, [data, weight, C]
-
-        @register("topi_nn_batch_matmul")
-        def _topi_nn_batch_matmul(*args, **kwargs):
-            assert not kwargs, "Do not support kwargs in template function call"
-            args = deserialize_args(args)
-            A, B = args
-            C = topi.nn.batch_matmul(A, B)
-            s = topi.generic.schedule_batch_matmul([C])
-            return s, [A, B, C]
-
-        @register("topi_nn_bitserial_conv2d_nhwc")
-        def _topi_bitserial_conv2d_nhwc(*args, **kwargs):
-            args = deserialize_args(args)
-            C = topi.nn.bitserial_conv2d_nhwc(*args, **kwargs)
-            s = topi.generic.nn.schedule_bitserial_conv2d_nhwc([C])
-            A, W = args[:2]
-            return s, [A, W, C]
-
-        @register("topi_nn_bitserial_conv2d_nchw")
-        def _topi_bitserial_conv2d_nchw(*args, **kwargs):
-            args = deserialize_args(args)
-            C = topi.nn.bitserial_conv2d_nchw(*args, **kwargs)
-            s = topi.generic.nn.schedule_bitserial_conv2d_nchw([C])
-            A, W = args[:2]
-            return s, [A, W, C]
-
-        @register("topi_nn_bitserial_dense")
-        def _topi_nn_bitserial_dense(*args, **kwargs):
-            assert not kwargs, "Do not support kwargs in template function call"
-            args = deserialize_args(args)
-            A, W = args[:2]
-            C = topi.nn.bitserial_dense(*args, **kwargs)
-            s = topi.generic.schedule_bitserial_dense([C])
-            return s, [A, W, C]
-
-        @register("topi_nn_deformable_conv2d_nchw")
-        def _topi_nn_deformable_conv2d_nchw(*args, **kwargs):
-            assert not kwargs, "Do not support kwargs in template function call"
-            args = deserialize_args(args)
-            A, Offset, W = args[:3]
-            C = topi.nn.deformable_conv2d_nchw(*args, **kwargs)
-            s = topi.generic.schedule_deformable_conv2d_nchw([C])
-            return s, [A, Offset, W, C]
-
-        @register("topi_nn_conv2d_NCHWc")
-        def _topi_nn_conv2d_NCHWc(*args, **kwargs):
-            assert not kwargs, "Do not support kwargs in template function call"
-            args = deserialize_args(args)
-            A, W = args[:2]
-            C = topi.nn.conv2d_NCHWc(*args, **kwargs)
-            s = topi.generic.schedule_conv2d_NCHWc([C])
-            return s, [A, W, C]
+        self.tracing = False
 
-    def reset(self, wanted_topi_funcs):
+    def reset(self, wanted_relay_ops=None):
         """Reset task collections
 
         Parameters
         ----------
-        wanted_topi_funcs: List of function
-            The topi function to be extracted
+        wanted_relay_ops: List of relay.op.Op
+            The relay ops to be extracted
         """
         self.task_collection = []
-        self.wanted_topi_funcs = wanted_topi_funcs
+        self.wanted_relay_ops = wanted_relay_ops
+
+    def add_task(self, task_name, args):
+        """Add AutoTVM task
+
+        Parameters
+        ----------
+        task_name: str
+            AutoTVM task name.
+
+        args: tuple
+            Arguments to the TOPI function.
+        """
+        key = (task_name, serialize_args(args))
+        if self.allow_duplicate or key not in self.task_collection:
+            self.task_collection.append(key)
 
     def get_tasks(self):
         """Get collected tasks
@@ -355,26 +115,19 @@ class TaskExtractEnv:
         return TaskExtractEnv.current
 
 
-def register_topi_compute(topi_compute, target_keys, template_keys, func=None, override=False):
+def register_topi_compute(task_name, func=None):
     """Register a tunable template for a topi compute function.
 
-    After the registration, this topi compute will become a configuration dispatcher. It uses
-    all its argument as workload and dispatches configurations according to the input workload.
-
-    It also stores this "workload" to its final ComputeOp, which can be used to reconstruct
+    The registration will wrap this topi compute to take `cfg` as the first argument,
+    followed by the original argument list. It uses all its argument as workload and
+    stores this "workload" to its final ComputeOp, which can be used to reconstruct
     "workload" in the following topi_schedule call.
 
     Parameters
     ----------
-    topi_compute: GenericFunc
-        The topi compute function that will be overloaded
-    target_keys: str or list of str
-        The compilation target. The same as the argument of GenericFunc.register.
-    template_keys: str or list of str
-        The template key.
-        We might have several strategies for a single operator (e.g. direct, im2col, winograd).
-        The template key is used to identity the algorithm strategy.
-        Every operator must have a "direct" template, which is used by default.
+    task_name: str
+        The AutoTVM task name
+
     func: None or callable
         If it is None, return a decorator.
         If is callable, decorate this function.
@@ -388,81 +141,63 @@ def register_topi_compute(topi_compute, target_keys, template_keys, func=None, o
     --------
     See tvm/topi/python/topi/arm_cpu/depthwise_conv2d.py for example usage.
     """
-    def _decorator(f):
-        targets = [target_keys] if isinstance(target_keys, str) else target_keys
-        for target_key in targets:
-            if target_key not in _REGISTERED_DISPATCHER:
-                _REGISTERED_DISPATCHER[target_key] = {}
-            if topi_compute not in _REGISTERED_DISPATCHER[target_key]:
-                @topi_compute.register(target_key)
-                @dispatcher
-                def config_dispatcher(*args, **kwargs):
-                    """override topi call as a config dispatcher"""
-                    assert not kwargs, "Do not support kwargs in template function call"
-                    return args_to_workload(args, topi_compute)
-                _REGISTERED_DISPATCHER[target_key][topi_compute] = config_dispatcher
-
-            config_dispatcher = _REGISTERED_DISPATCHER[target_key][topi_compute]
-
-            @config_dispatcher.register(template_keys, override=override)
-            def template_call(cfg, *args, **kwargs):
-                """call the topi func and attach workload to compute node"""
-                assert not kwargs, "Do not support kwargs in template function call"
-
-                if f == topi_compute.fdefault:
-                    node = f(*args, **kwargs)
-                else:
-                    node = f(cfg, *args, **kwargs)
-
-                # attach workload to return op
-                op = node.op
-                attrs = {}
-                for k, v in node.op.attrs.items():
-                    attrs[k] = v
-                attrs['workload'] = args_to_workload(args, topi_compute)
-                if isinstance(op, tensor.ComputeOp):
-                    op = tvm.te._ffi_api.ComputeOp(
-                        op.name, op.tag, attrs, op.axis, op.body)
-                elif isinstance(op, tensor.ExternOp):
-                    op = tvm.te._ffi_api.ExternOp(
-                        op.name, op.tag, attrs,
-                        op.inputs, op.input_placeholders,
-                        op.output_placeholders, op.body)
-                else:
-                    raise RuntimeError("Unsupported op type: " + str(type(op)))
-
-                if isinstance(node, tensor.Tensor):
-                    return op.output(0)
-                return [op.output(i) for i in range(len(node))]
-
-        return f
+    def _decorate(topi_compute):
+        @register_task_compute(task_name)
+        def wrapper(*args, **kwargs):
+            """wrapper function for topi compute"""
+            assert not kwargs, "Do not support kwargs in template function call"
+            task_env = TaskExtractEnv.current
+            if task_env is not None and task_env.tracing:
+                task_env.add_task(task_name, args)
+            workload = args_to_workload(args, task_name)
+            tgt = _target.Target.current()
+            cfg = DispatchContext.current.query(tgt, workload)
+            node = topi_compute(cfg, *args)
+
+            # attach workload to return op
+            op = node.op
+            attrs = {}
+            for k, v in node.op.attrs.items():
+                attrs[k] = v
+            attrs['workload'] = workload
+            if isinstance(op, tensor.ComputeOp):
+                op = tvm.te._ffi_api.ComputeOp(
+                    op.name, op.tag, attrs, op.axis, op.body)
+            elif isinstance(op, tensor.ExternOp):
+                op = tvm.te._ffi_api.ExternOp(
+                    op.name, op.tag, attrs,
+                    op.inputs, op.input_placeholders,
+                    op.output_placeholders, op.body)
+            else:
+                raise RuntimeError("Unsupported op type: " + str(type(op)))
 
-    if func:
-        _decorator(func)
+            if isinstance(node, tensor.Tensor):
+                return op.output(0)
+            return [op.output(i) for i in range(len(node))]
 
-    return _decorator
+        return wrapper
+
+    if func:
+        return _decorate(func)
+    return _decorate
 
 
-def register_topi_schedule(topi_schedule, target_keys, template_keys, func=None, override=False):
+def register_topi_schedule(task_name, func=None):
     """Register a tunable template for a topi schedule function.
 
-    After the registration. This topi schedule will become a configuration dispatcher. It dispatches
-    configurations according to the input workload.
+    The registration will wrap this topi schedule to take `cfg` as the first argument,
+    followed by the original argument list.
 
     Note that this function will try to find "workload" from all the ComputeOp in the input.
     You can attach "workload" to your compute op by using :any:`register_topi_compute`.
 
+    The task name has to be the same as that of the corresponding topi compute function.
+
     Parameters
     ----------
-    topi_schedule: GenericFunc
-        The topi schedule function that will be overloaded
-    target_keys: str or list of str
-        The compilation target
-    template_keys: str or list of str
-        The template key.
-        We might have several strategies for a single operator (e.g. direct, im2col, winograd).
-        The template key is used to identity the algorithm strategy.
-        Every operator must have a "direct" template, which is used by default.
+    task_name: str
+        The AutoTVM task name
+
     func: None or callable
         If it is None, return a decorator.
         If is callable, decorate this function.
@@ -476,49 +211,33 @@ def register_topi_schedule(topi_schedule, target_keys, template_keys, func=None,
     --------
     See tvm/topi/python/topi/arm_cpu/depthwise_conv2d.py for example usage.
     """
-    def _decorator(f):
-        targets = [target_keys] if isinstance(target_keys, str) else target_keys
-        for target_key in targets:
-            if target_key not in _REGISTERED_DISPATCHER:
-                _REGISTERED_DISPATCHER[target_key] = {}
-            if topi_schedule not in _REGISTERED_DISPATCHER[target_key]:
-                @topi_schedule.register(target_key)
-                @dispatcher
-                def config_dispatcher(outs, *args, **kwargs):
-                    """override topi call as a workload dispatcher"""
-                    def traverse(tensors):
-                        """traverse all ops to find attached workload"""
-                        for t in tensors:
-                            op = t.op
-                            if 'workload' in op.attrs:
-                                return op.attrs['workload']
-                            wkl = traverse(op.input_tensors)
-                            if wkl:
-                                return wkl
-                        return None
-
-                    outs = [outs] if isinstance(outs, tensor.Tensor) else outs
-                    workload = traverse(outs)
-
-                    if workload is None:
-                        raise RuntimeError("Cannot find workload in attribute of this schedule")
-
-                    return args_to_workload(workload)
-
-                _REGISTERED_DISPATCHER[target_key][topi_schedule] = config_dispatcher
-
-            config_dispatcher = _REGISTERED_DISPATCHER[target_key][topi_schedule]
-
-            @config_dispatcher.register(template_keys, override=override)
-            def template_call(cfg, outs, *args, **kwargs):
-                """call the schedule func"""
-                if f == topi_schedule.fdefault:
-                    return f(outs, *args, **kwargs)
-                return f(cfg, outs, *args, **kwargs)
-
-        return f
-
+    def _decorate(topi_schedule):
+        @register_task_schedule(task_name)
+        def wrapper(outs, *args, **kwargs):
+            """wrapper function for topi schedule"""
+            workload = get_workload(outs)
+            if workload is None:
+                raise RuntimeError("Cannot find workload in attribute of this schedule")
+            tgt = _target.Target.current()
+            cfg = DispatchContext.current.query(tgt, workload)
+            return topi_schedule(cfg, outs, *args, **kwargs)
+        return wrapper
     if func:
-        _decorator(func)
-
-    return _decorator
+        return _decorate(func)
+    return _decorate
+
+
+def get_workload(outs):
+    """Retrieve the workload from outputs"""
+    def traverse(tensors):
+        """traverse all ops to find attached workload"""
+        for t in tensors:
+            op = t.op
+            if 'workload' in op.attrs:
+                return args_to_workload(op.attrs['workload'])
+            wkl = traverse(op.input_tensors)
+            if wkl:
+                return wkl
+        return None
+    outs = [outs] if isinstance(outs, tensor.Tensor) else outs
+    return traverse(outs)
index e1a7d86..f13ba52 100644 (file)
@@ -46,16 +46,16 @@ AUTOTVM_TOPHUB_ROOT_PATH = os.path.join(os.path.expanduser('~'), ".tvm", "tophub
 
 # the version of each package
 PACKAGE_VERSION = {
-    'arm_cpu':          "v0.04",
-    'llvm':             "v0.03",
+    'arm_cpu':          "v0.06",
+    'llvm':             "v0.04",
 
-    'cuda':             "v0.06",
-    'rocm':             "v0.03",
-    'opencl':           "v0.03",
-    'mali':             "v0.05",
-    'intel_graphics':   "v0.01",
+    'cuda':             "v0.08",
+    'rocm':             "v0.04",
+    'opencl':           "v0.04",
+    'mali':             "v0.06",
+    'intel_graphics':   "v0.02",
 
-    'vta':              "v0.06",
+    'vta':              "v0.08",
 }
 
 logger = logging.getLogger('autotvm')
@@ -189,7 +189,7 @@ def download_package(tophub_location, package_name):
 # global cache for load_reference_log
 REFERENCE_LOG_CACHE = {}
 
-def load_reference_log(backend, model, workload_name, template_key):
+def load_reference_log(backend, model, workload_name):
     """ Load reference log from TopHub to support fallback in template.
     Template will use these reference logs to choose fallback config.
 
@@ -201,8 +201,6 @@ def load_reference_log(backend, model, workload_name, template_key):
         The name of the device model
     workload_name: str
         The name of the workload. (The first item in the workload tuple)
-    template_key: str
-        The template key
     """
 
     backend = _alias(backend)
@@ -211,7 +209,7 @@ def load_reference_log(backend, model, workload_name, template_key):
     filename = os.path.join(AUTOTVM_TOPHUB_ROOT_PATH, package_name)
 
     global REFERENCE_LOG_CACHE
-    key = (backend, model, workload_name, template_key)
+    key = (backend, model, workload_name)
 
     if key not in REFERENCE_LOG_CACHE:
         tmp = []
@@ -233,8 +231,7 @@ def load_reference_log(backend, model, workload_name, template_key):
                 model = max(counts.items(), key=lambda k: k[1])[0]
 
             for inp, res in load_from_file(filename):
-                if (model == inp.target.model and inp.task.workload[0] == workload_name and
-                        inp.config.template_key == template_key):
+                if model == inp.target.model and inp.task.workload[0] == workload_name:
                     tmp.append((inp, res))
         REFERENCE_LOG_CACHE[key] = tmp
 
index 882b0ad..3052448 100644 (file)
@@ -219,8 +219,7 @@ class XGBoostCostModel(CostModel):
         # filter data, only pick the data with a same task
         data = []
         for inp, res in records:
-            if inp.task.name == self.task.name and \
-                            inp.config.template_key == self.task.config_space.template_key:
+            if inp.task.name == self.task.name:
                 data.append((inp, res))
 
         logger.debug("XGB load %d entries from history log file", len(data))
index 4eedd23..6466dff 100644 (file)
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+# pylint: disable=len-as-condition,no-else-return,invalid-name
 """Backend code generation engine."""
 from __future__ import absolute_import
 
+import logging
+import numpy as np
+import tvm
 from ..base import register_relay_node, Object
 from ... import target as _target
+from ... import autotvm
 from .. import expr as _expr
+from .. import op as _op
+from .. import ty as _ty
 from . import _backend
 
+logger = logging.getLogger('compile_engine')
+
+
 @register_relay_node
-class CachedFunc(Object):
-    """Low-level tensor function to back a relay primitive function.
-    """
+class LoweredOutput(Object):
+    """Lowered output"""
+    def __init__(self, outputs, implement):
+        self.__init_handle_by_constructor__(
+            _backend._make_LoweredOutput, outputs, implement)
 
 
 @register_relay_node
@@ -63,6 +75,191 @@ def _get_cache_key(source_func, target):
     return source_func
 
 
+def get_shape(shape):
+    """Convert the shape to correct dtype and vars."""
+    ret = []
+    for dim in shape:
+        if isinstance(dim, tvm.expr.IntImm):
+            val = int(dim)
+            assert val <= np.iinfo(np.int32).max
+            ret.append(tvm.expr.IntImm("int32", val))
+        elif isinstance(dim, tvm.expr.Any):
+            ret.append(tvm.var("any_dim", "int32"))
+        else:
+            ret.append(dim)
+    return ret
+
+
+def get_valid_implementations(op, attrs, inputs, out_type, target):
+    """Get all valid implementations from the op strategy.
+
+    Note that this function doesn't support op with symbolic input shapes.
+
+    Parameters
+    ----------
+    op : relay.op.Op
+        Relay operator.
+
+    attrs : object
+        The op attribute.
+
+    inputs : List[tvm.Tensor]
+        Input tensors to the op.
+
+    out_type : relay.Type
+        The output type.
+
+    target : tvm.target.Target
+        The target to compile the op.
+
+    Returns
+    -------
+    ret : List[relay.op.OpImplementation]
+        The list of all valid op implementations.
+    """
+    fstrategy = op.get_attr("FTVMStrategy")
+    assert fstrategy is not None, "%s doesn't have FTVMStrategy registered" % op.name
+    with target:
+        strategy = fstrategy(attrs, inputs, out_type, target)
+    analyzer = tvm.arith.Analyzer()
+    ret = []
+    for spec in strategy.specializations:
+        if spec.condition:
+            # check if all the clauses in the specialized condition are true
+            flag = True
+            for clause in spec.condition.clauses:
+                clause = analyzer.canonical_simplify(clause)
+                if isinstance(clause, tvm.expr.IntImm) and clause.value:
+                    continue
+                flag = False
+                break
+            if flag:
+                for impl in spec.implementations:
+                    ret.append(impl)
+        else:
+            for impl in spec.implementations:
+                ret.append(impl)
+    return ret
+
+
+def select_implementation(op, attrs, inputs, out_type, target, use_autotvm=True):
+    """Select the best implementation from the op strategy.
+
+    If use_autotvm is True, it'll first try to find the best implementation
+    based on AutoTVM profile results. If no AutoTVM profile result is found,
+    it'll choose the implementation with highest plevel.
+
+    If use_autotvm is False, it'll directly choose the implementation with
+    highest plevel.
+
+    Note that this function doesn't support op with symbolic input shapes.
+
+    Parameters
+    ----------
+    op : relay.op.Op
+        Relay operator.
+
+    attrs : object
+        The op attribute.
+
+    inputs : List[tvm.Tensor]
+        Input tensors to the op.
+
+    out_type : relay.Type
+        The output type.
+
+    target : tvm.target.Target
+        The target to compile the op.
+
+    use_autotvm : bool
+        Whether query AutoTVM to pick the best.
+
+    Returns
+    -------
+    ret : tuple(relay.op.OpImplementation, List[tvm.Tensor])
+        The best op implementation and the corresponding output tensors.
+    """
+    all_impls = get_valid_implementations(op, attrs, inputs, out_type, target)
+
+    best_plevel_impl = None
+    for impl in all_impls:
+        if best_plevel_impl is None or impl.plevel > best_plevel_impl.plevel:
+            best_plevel_impl = impl
+    if not use_autotvm:
+        outs = best_plevel_impl.compute(attrs, inputs, out_type)
+        return best_plevel_impl, outs
+
+    outputs = {}
+    best_autotvm_impl = None
+    best_cfg = None
+    dispatch_ctx = autotvm.task.DispatchContext.current
+    for impl in all_impls:
+        outs = impl.compute(attrs, inputs, out_type)
+        outputs[impl] = outs
+        workload = autotvm.task.get_workload(outs)
+        if workload is None:
+            continue
+        cfg = dispatch_ctx.query(target, workload)
+        if cfg.is_fallback:
+            # It's a fallback config
+            continue
+        if best_cfg is None or best_cfg.cost > cfg.cost:
+            best_autotvm_impl = impl
+            best_cfg = cfg
+    if best_autotvm_impl:
+        return best_autotvm_impl, outputs[best_autotvm_impl]
+    return best_plevel_impl, outputs[best_plevel_impl]
+
+
+@tvm._ffi.register_func("relay.backend.lower_call")
+def lower_call(call, inputs, target):
+    """Lower the call expression to op implementation and tensor outputs."""
+    assert isinstance(call.op, _op.Op)
+    op = call.op
+
+    # Prepare the call_node->checked_type(). For the call node inputs, we ensure that
+    # the shape is Int32. Following code ensures the same for the output as well.
+    # TODO(@icemelon9): Support recursive tuple
+    ret_type = call.checked_type
+    if isinstance(ret_type, _ty.TensorType):
+        ret_type = _ty.TensorType(get_shape(ret_type.shape), ret_type.dtype)
+    elif isinstance(ret_type, _ty.TupleType):
+        new_fields = []
+        for field in ret_type.fields:
+            if isinstance(field, _ty.TensorType):
+                new_fields.append(_ty.TensorType(get_shape(field.shape), field.dtype))
+            else:
+                new_fields.append(field)
+        ret_type = _ty.TupleType(new_fields)
+
+    is_dyn = _ty.type_has_any(call.checked_type)
+    for arg in call.args:
+        is_dyn = is_dyn or _ty.type_has_any(arg.checked_type)
+
+    # check if in the AutoTVM tracing mode, and disable if op is not in wanted list
+    env = autotvm.task.TaskExtractEnv.current
+    reenable_tracing = False
+    if env is not None and env.tracing:
+        if env.wanted_relay_ops is not None and op not in env.wanted_relay_ops:
+            env.tracing = False
+            reenable_tracing = True
+
+    if not is_dyn:
+        best_impl, outputs = select_implementation(
+            op, call.attrs, inputs, ret_type, target)
+        logger.info("Use implementation %s for op %s", best_impl.name, op.name)
+    else:
+        # TODO(@icemelon9): Allow tvm to generate multiple kernels for dynamic shapes.
+        #   Currently, we just use the implementation with highest plevel
+        best_impl, outputs = select_implementation(
+            op, call.attrs, inputs, ret_type, target, use_autotvm=False)
+
+    # re-enable AutoTVM tracing
+    if reenable_tracing:
+        env.tracing = True
+    return LoweredOutput(outputs, best_impl)
+
+
 @register_relay_node
 class CompileEngine(Object):
     """CompileEngine to get lowered code.
index f492c74..8d69239 100644 (file)
@@ -131,22 +131,22 @@ class ExprVisitor(ExprFunctor):
 
     The default behavior recursively traverses the AST.
     """
-    def visit_tuple(self, t):
-        for x in t.fields:
+    def visit_tuple(self, tup):
+        for x in tup.fields:
             self.visit(x)
 
-    def visit_call(self, c):
-        self.visit(c.op)
-        for a in c.args:
+    def visit_call(self, call):
+        self.visit(call.op)
+        for a in call.args:
             self.visit(a)
 
-    def visit_var(self, v):
+    def visit_var(self, var):
         pass
 
-    def visit_let(self, l):
-        self.visit(l.var)
-        self.visit(l.value)
-        self.visit(l.body)
+    def visit_let(self, let):
+        self.visit(let.var)
+        self.visit(let.value)
+        self.visit(let.body)
 
     def visit_function(self, f):
         self.visit(f.body)
index 587b076..5532e3a 100644 (file)
@@ -311,6 +311,7 @@ def _conv(opname):
             flip_layout = True
 
         if attr['data_format'] == 'NHWC':
+            in_channels = input_shape[3]
             kernel_h, kernel_w, _, depth_mult = weights_shape
             attr['kernel_shape'] = (weights_shape[0], weights_shape[1])
             if opname == 'conv':
@@ -324,6 +325,7 @@ def _conv(opname):
                 attr['dilations'] = (attr['dilations'][1], attr['dilations'][2])
             attr['strides'] = (attr['strides'][1], attr['strides'][2])
         elif attr['data_format'] == 'NCHW':
+            in_channels = input_shape[1]
             _, depth_mult, kernel_h, kernel_w = weights_shape
             attr['kernel_shape'] = (weights_shape[2], weights_shape[3])
             if opname == 'conv':
@@ -344,7 +346,7 @@ def _conv(opname):
             raise tvm.error.OpAttributeInvalid(msg.format(attr['data_format']))
 
         if opname == 'depthwise':
-            attr['groups'] = attr['channels']
+            attr['groups'] = in_channels
 
         # Fix padding
         attr['padding'] = attr['padding'].decode("utf-8")
index dd35871..352bc63 100644 (file)
@@ -1156,7 +1156,7 @@ class OperatorConverter(object):
 
         if is_depthwise_conv:
             params['channels'] = int(in_channels)
-            params['groups'] = int(in_channels)
+            params['groups'] = int(input_c)
             params['kernel_layout'] = 'HWOI'
         else:
             params['channels'] = int(output_channels)
index d61c6f1..f8e9811 100644 (file)
@@ -28,8 +28,8 @@ from .backend import compile_engine
 
 
 def is_primitive(call):
-    return hasattr(call.op, 'attrs') and hasattr(call.op.attrs, 'Primitive') and \
-        int(call.op.attrs.Primitive) == 1
+    return hasattr(call, 'op') and hasattr(call.op, 'attrs') and \
+           hasattr(call.op.attrs, 'Primitive') and int(call.op.attrs.Primitive) == 1
 
 # TODO(@jroesch): port to c++ and unify with existing code
 class LinearizeRetType:
index bcd58ba..7427c63 100644 (file)
 #pylint: disable=wildcard-import, redefined-builtin
 """Relay core operators."""
 # operator defs
-from .op import get, register, register_schedule, register_compute, register_gradient, \
+from .op import get, register, register_compute, register_gradient, \
     register_pattern, register_alter_op_layout, register_legalize, \
-    schedule_injective, Op, OpPattern, debug
+    Op, OpPattern, OpStrategy, debug
+from . import strategy
 
 # Operators
 from .reduce import *
index 09746be..e1e6fd3 100644 (file)
 # pylint: disable=invalid-name,unused-argument
 from __future__ import absolute_import
 
-import topi
-from topi.util import get_const_int
-from ..op import OpPattern, register_compute, register_schedule, register_pattern
-
-
-@register_schedule("argsort")
-def schedule_argsort(_, outs, target):
-    """Schedule definition of argsort"""
-    with target:
-        return topi.generic.schedule_argsort(outs)
-
-
-@register_compute("argsort")
-def compute_argsort(attrs, inputs, _, target):
-    """Compute definition of argsort"""
-    axis = get_const_int(attrs.axis)
-    is_ascend = bool(get_const_int(attrs.is_ascend))
-    dtype = attrs.dtype
-    return [topi.argsort(inputs[0], axis=axis, is_ascend=is_ascend, dtype=dtype)]
-
+from . import strategy
+from .op import OpPattern, register_pattern
+from .op import register_strategy
 
+# argsort
+register_strategy("argsort", strategy.argsort_strategy)
 register_pattern("argsort", OpPattern.OPAQUE)
 
-
-@register_schedule("topk")
-def schedule_topk(_, outs, target):
-    """Schedule definition of argsort"""
-    with target:
-        return topi.generic.schedule_topk(outs)
-
-
-@register_compute("topk")
-def compute_topk(attrs, inputs, _, target):
-    """Compute definition of argsort"""
-    k = get_const_int(attrs.k)
-    axis = get_const_int(attrs.axis)
-    ret_type = attrs.ret_type
-    is_ascend = bool(get_const_int(attrs.is_ascend))
-    dtype = attrs.dtype
-    out = topi.topk(inputs[0], k, axis, ret_type, is_ascend, dtype)
-    out = out if isinstance(out, list) else [out]
-    return out
-
-
+# topk
+register_strategy("topk", strategy.topk_strategy)
 register_pattern("topk", OpPattern.OPAQUE)
index 43f71c0..9d52ed3 100644 (file)
 """Backend compiler related feature registration"""
 from __future__ import absolute_import
 
-import topi
-
 from topi.util import get_const_int, get_const_tuple
 from . import op as _reg
 from ...api import convert
 from ...hybrid import script
 
-
-def _schedule_reduce(_, outs, target):
-    """Generic schedule for reduce"""
-    with target:
-        return topi.generic.schedule_reduce(outs)
-
-
-_reg.register_schedule("argmax", _schedule_reduce)
-_reg.register_schedule("argmin", _schedule_reduce)
-_reg.register_schedule("sum", _schedule_reduce)
-_reg.register_schedule("all", _schedule_reduce)
-_reg.register_schedule("any", _schedule_reduce)
-_reg.register_schedule("max", _schedule_reduce)
-_reg.register_schedule("min", _schedule_reduce)
-_reg.register_schedule("prod", _schedule_reduce)
-_reg.register_schedule("mean", _schedule_reduce)
-_reg.register_schedule("variance", _schedule_reduce)
-_reg.register_schedule("nn.cross_entropy", _schedule_reduce)
-_reg.register_schedule("nn.cross_entropy_with_logits", _schedule_reduce)
-
+_reg.register_reduce_schedule("argmax")
+_reg.register_reduce_schedule("argmin")
+_reg.register_reduce_schedule("sum")
+_reg.register_reduce_schedule("all")
+_reg.register_reduce_schedule("any")
+_reg.register_reduce_schedule("max")
+_reg.register_reduce_schedule("min")
+_reg.register_reduce_schedule("prod")
+_reg.register_reduce_schedule("mean")
+_reg.register_reduce_schedule("variance")
 
 def _create_axis_record(attrs, inputs):
     axes = attrs.axis if attrs.axis is None else list(get_const_tuple(attrs.axis))
index c1d02bd..7c8ccb7 100644 (file)
 from __future__ import absolute_import
 import topi
 from topi.util import get_const_tuple
-from .op import register_compute, register_schedule, register_pattern, register_shape_func
-from .op import schedule_injective, OpPattern
+from .op import register_compute, register_shape_func
+from .op import register_broadcast_schedule, register_injective_schedule
+from .op import register_pattern, OpPattern
 from ...hybrid import script
 from ...api import convert
 
-schedule_broadcast = schedule_injective
-schedule_elemwise = schedule_injective
-
-register_schedule("log", schedule_broadcast)
-register_schedule("cos", schedule_broadcast)
-register_schedule("sin", schedule_broadcast)
-register_schedule("atan", schedule_broadcast)
-register_schedule("exp", schedule_broadcast)
-register_schedule("erf", schedule_broadcast)
-register_schedule("sqrt", schedule_broadcast)
-register_schedule("rsqrt", schedule_broadcast)
-register_schedule("sigmoid", schedule_broadcast)
-register_schedule("floor", schedule_broadcast)
-register_schedule("ceil", schedule_broadcast)
-register_schedule("trunc", schedule_broadcast)
-register_schedule("round", schedule_broadcast)
-register_schedule("sign", schedule_broadcast)
-register_schedule("abs", schedule_broadcast)
-register_schedule("tanh", schedule_broadcast)
-register_schedule("logical_not", schedule_broadcast)
-register_schedule("bitwise_not", schedule_broadcast)
-register_schedule("negative", schedule_broadcast)
-register_schedule("copy", schedule_broadcast)
-
-register_schedule("add", schedule_broadcast)
-register_schedule("subtract", schedule_broadcast)
-register_schedule("multiply", schedule_broadcast)
-register_schedule("divide", schedule_broadcast)
-register_schedule("floor_divide", schedule_broadcast)
-register_schedule("power", schedule_injective)
-register_schedule("mod", schedule_broadcast)
-register_schedule("floor_mod", schedule_broadcast)
-register_schedule("logical_and", schedule_broadcast)
-register_schedule("logical_or", schedule_broadcast)
-register_schedule("bitwise_and", schedule_broadcast)
-register_schedule("bitwise_or", schedule_broadcast)
-register_schedule("bitwise_xor", schedule_broadcast)
-register_schedule("equal", schedule_broadcast)
-register_schedule("not_equal", schedule_broadcast)
-register_schedule("less", schedule_broadcast)
-register_schedule("less_equal", schedule_broadcast)
-register_schedule("greater", schedule_broadcast)
-register_schedule("greater_equal", schedule_broadcast)
-register_schedule("maximum", schedule_injective)
-register_schedule("minimum", schedule_injective)
-register_schedule("right_shift", schedule_injective)
-register_schedule("left_shift", schedule_injective)
-register_schedule("shape_of", schedule_injective)
+
+register_broadcast_schedule("log")
+register_broadcast_schedule("cos")
+register_broadcast_schedule("sin")
+register_broadcast_schedule("atan")
+register_broadcast_schedule("exp")
+register_broadcast_schedule("erf")
+register_broadcast_schedule("sqrt")
+register_broadcast_schedule("rsqrt")
+register_broadcast_schedule("sigmoid")
+register_broadcast_schedule("floor")
+register_broadcast_schedule("ceil")
+register_broadcast_schedule("trunc")
+register_broadcast_schedule("round")
+register_broadcast_schedule("sign")
+register_broadcast_schedule("abs")
+register_broadcast_schedule("tanh")
+register_broadcast_schedule("add")
+register_broadcast_schedule("subtract")
+register_broadcast_schedule("multiply")
+register_broadcast_schedule("divide")
+register_broadcast_schedule("floor_divide")
+register_broadcast_schedule("power")
+register_broadcast_schedule("copy")
+register_broadcast_schedule("logical_not")
+register_broadcast_schedule("logical_and")
+register_broadcast_schedule("logical_or")
+register_broadcast_schedule("bitwise_not")
+register_broadcast_schedule("bitwise_and")
+register_broadcast_schedule("bitwise_or")
+register_broadcast_schedule("bitwise_xor")
+register_broadcast_schedule("negative")
+register_broadcast_schedule("mod")
+register_broadcast_schedule("floor_mod")
+register_broadcast_schedule("equal")
+register_broadcast_schedule("not_equal")
+register_broadcast_schedule("less")
+register_broadcast_schedule("less_equal")
+register_broadcast_schedule("greater")
+register_broadcast_schedule("greater_equal")
+register_injective_schedule("maximum")
+register_injective_schedule("minimum")
+register_injective_schedule("right_shift")
+register_injective_schedule("left_shift")
+register_injective_schedule("shape_of")
 
 # zeros
 @register_compute("zeros")
-def zeros_compute(attrs, inputs, output_type, target):
+def zeros_compute(attrs, inputs, output_type):
     assert not inputs
     return [topi.full(output_type.shape, output_type.dtype, 0.0)]
 
-register_schedule("zeros", schedule_broadcast)
+register_broadcast_schedule("zeros")
 register_pattern("zeros", OpPattern.ELEMWISE)
 
 # zeros_like
 @register_compute("zeros_like")
-def zeros_like_compute(attrs, inputs, output_type, target):
+def zeros_like_compute(attrs, inputs, output_type):
     assert len(inputs) == 1
     return [topi.full_like(inputs[0], 0.0)]
 
-register_schedule("zeros_like", schedule_broadcast)
+register_broadcast_schedule("zeros_like")
 
 # ones
 @register_compute("ones")
-def ones_compute(attrs, inputs, output_type, target):
+def ones_compute(attrs, inputs, output_type):
     assert not inputs
     return [topi.full(output_type.shape, output_type.dtype, 1.0)]
 
-register_schedule("ones", schedule_broadcast)
+register_broadcast_schedule("ones")
 register_pattern("ones", OpPattern.ELEMWISE)
 
 # ones_like
 @register_compute("ones_like")
-def ones_like(attrs, inputs, output_type, target):
+def ones_like_compute(attrs, inputs, output_type):
     assert len(inputs) == 1
     return [topi.full_like(inputs[0], 1.0)]
 
-register_schedule("ones_like", schedule_broadcast)
+register_broadcast_schedule("ones_like")
 
 # clip
 @register_compute("clip")
-def clip_compute(attrs, inputs, output_type, target):
+def clip_compute(attrs, inputs, output_type):
     assert len(inputs) == 1
     return [topi.clip(inputs[0], attrs.a_min, attrs.a_max)]
 
-register_schedule("clip", schedule_elemwise)
+register_injective_schedule("clip")
 
 @script
 def _cast_shape_function(x):
@@ -198,6 +196,7 @@ register_shape_func("mod", False, broadcast_shape_func)
 register_shape_func("floor_mod", False, broadcast_shape_func)
 register_shape_func("logical_and", False, broadcast_shape_func)
 register_shape_func("logical_or", False, broadcast_shape_func)
+register_shape_func("bitwise_not", False, broadcast_shape_func)
 register_shape_func("bitwise_and", False, broadcast_shape_func)
 register_shape_func("bitwise_or", False, broadcast_shape_func)
 register_shape_func("bitwise_xor", False, broadcast_shape_func)
index e6053b8..42c9434 100644 (file)
@@ -21,52 +21,74 @@ import tvm
 import topi
 from topi.util import get_const_int, get_const_tuple
 from . import op as _reg
-from ._reduce import _schedule_reduce
+from . import strategy
 from .op import OpPattern
 from ...hybrid import script
 from ...api import convert
 
-schedule_injective = _reg.schedule_injective
-schedule_broadcast = _reg.schedule_injective
-schedule_concatenate = _reg.schedule_concatenate
-
-
-_reg.register_schedule("collapse_sum_like", _schedule_reduce)
-_reg.register_schedule("broadcast_to", schedule_broadcast)
-_reg.register_schedule("broadcast_to_like", schedule_broadcast)
-_reg.register_schedule("expand_dims", schedule_broadcast)
-_reg.register_schedule("squeeze", schedule_injective)
-_reg.register_schedule("reshape", schedule_injective)
-_reg.register_schedule("reshape_like", schedule_injective)
-_reg.register_schedule("full", schedule_injective)
-_reg.register_schedule("full_like", schedule_injective)
-_reg.register_schedule("arange", schedule_injective)
-_reg.register_schedule("reverse", schedule_injective)
-_reg.register_schedule("repeat", schedule_broadcast)
-_reg.register_schedule("tile", schedule_broadcast)
-_reg.register_schedule("cast", schedule_injective)
-_reg.register_schedule("cast_like", schedule_injective)
-_reg.register_schedule("reinterpret", schedule_injective)
-_reg.register_schedule("strided_slice", schedule_injective)
-_reg.register_schedule("strided_set", schedule_injective)
-_reg.register_schedule("slice_like", schedule_injective)
-_reg.register_schedule("split", schedule_injective)
-_reg.register_schedule("take", schedule_injective)
-_reg.register_schedule("transpose", schedule_injective)
-_reg.register_schedule("where", schedule_broadcast)
-_reg.register_schedule("stack", schedule_injective)
-_reg.register_schedule("concatenate", schedule_concatenate)
-_reg.register_schedule("_contrib_reverse_reshape", schedule_injective)
-_reg.register_schedule("gather_nd", schedule_injective)
-_reg.register_schedule("sequence_mask", schedule_injective)
-_reg.register_schedule("one_hot", schedule_injective)
+_reg.register_broadcast_schedule("broadcast_to")
+_reg.register_broadcast_schedule("broadcast_to_like")
+_reg.register_broadcast_schedule("expand_dims")
+_reg.register_broadcast_schedule("repeat")
+_reg.register_broadcast_schedule("tile")
+_reg.register_broadcast_schedule("where")
+_reg.register_injective_schedule("squeeze")
+_reg.register_injective_schedule("reshape")
+_reg.register_injective_schedule("reshape_like")
+_reg.register_injective_schedule("full")
+_reg.register_injective_schedule("full_like")
+_reg.register_injective_schedule("arange")
+_reg.register_injective_schedule("reverse")
+_reg.register_injective_schedule("cast")
+_reg.register_injective_schedule("cast_like")
+_reg.register_injective_schedule("reinterpret")
+_reg.register_injective_schedule("strided_slice")
+_reg.register_injective_schedule("slice_like")
+_reg.register_injective_schedule("split")
+_reg.register_injective_schedule("take")
+_reg.register_injective_schedule("transpose")
+_reg.register_injective_schedule("stack")
+_reg.register_injective_schedule("_contrib_reverse_reshape")
+_reg.register_injective_schedule("gather_nd")
+_reg.register_injective_schedule("sequence_mask")
+_reg.register_injective_schedule("one_hot")
+_reg.register_reduce_schedule("collapse_sum_like")
+
+# concatenate
+_reg.register_schedule("concatenate", strategy.schedule_concatenate)
+
+# strided_set
+@_reg.register_compute("strided_set")
+def compute_strided_set(attrs, inputs, output_type):
+    """Compute definition of strided_set"""
+    return [topi.strided_set(inputs[0], inputs[1], inputs[2], inputs[3], inputs[4])]
 
+_reg.register_injective_schedule("strided_set")
 
 # layout_transform
-_reg.register_schedule("layout_transform", schedule_injective)
+_reg.register_injective_schedule("layout_transform")
 _reg.register_pattern("layout_transform", OpPattern.INJECTIVE)
 
-# shape func
+# argwhere
+@_reg.register_compute("argwhere")
+def compute_argwhere(attrs, inputs, output_type):
+    """Compute definition of argwhere"""
+    output_shape = []
+    for s in output_type.shape:
+        if hasattr(s, "value"):
+            output_shape.append(s)
+        else:
+            # see Any, replace it with a var
+            output_shape.append(tvm.var("any_dim", "int32"))
+    new_output_type = tvm.relay.ty.TensorType(output_shape, "int32")
+    return [topi.argwhere(new_output_type, inputs[0])]
+
+_reg.register_schedule("argwhere", strategy.schedule_argwhere)
+
+#####################
+#  Shape functions  #
+#####################
+
 @script
 def _arange_shape_func(start, stop, step):
     out = output_tensor((1,), "int64")
@@ -284,31 +306,6 @@ def argwhere_shape_func(attrs, inputs, out_ndims):
         return [_argwhere_shape_func_5d(inputs[0])]
     return ValueError("Does not support rank higher than 5 in argwhere")
 
-@_reg.register_schedule("argwhere")
-def schedule_argwhere(_, outs, target):
-    """Schedule definition of argwhere"""
-    with target:
-        return topi.generic.schedule_argwhere(outs)
-
-
-@_reg.register_compute("argwhere")
-def compute_argwhere(attrs, inputs, output_type, _):
-    """Compute definition of argwhere"""
-    output_shape = []
-    for s in output_type.shape:
-        if hasattr(s, "value"):
-            output_shape.append(s)
-        else:
-            # see Any, replace it with a var
-            output_shape.append(tvm.var("any_dim", "int32"))
-    new_output_type = tvm.relay.ty.TensorType(output_shape, "int32")
-    return [topi.argwhere(new_output_type, inputs[0])]
-
-@_reg.register_compute("strided_set")
-def compute_strided_set(attrs, inputs, output_type, _):
-    """Compute definition of strided_set"""
-    return [topi.strided_set(inputs[0], inputs[1], inputs[2], inputs[3], inputs[4])]
-
 @script
 def _layout_transform_shape_func(data_shape,
                                  out_layout_len,
index 586c300..7bd5262 100644 (file)
@@ -19,7 +19,7 @@ from tvm.runtime import ndarray as _nd
 from tvm.runtime import TVMContext as _TVMContext
 
 from . import _make
-from ..op import register_schedule, schedule_injective
+from .. import op as reg
 
 
 def on_device(data, device):
@@ -79,7 +79,7 @@ def checkpoint(data):
     """
     return _make.checkpoint(data)
 
-register_schedule("annotation.checkpoint", schedule_injective)
+reg.register_injective_schedule("annotation.checkpoint")
 
 
 def compiler_begin(data, compiler):
index 4b55880..3927cef 100644 (file)
 """Backend compiler related feature registration"""
 from __future__ import absolute_import
 
-import topi
 from .. import op as reg
-from ..op import schedule_injective, OpPattern
+from .. import strategy
+from ..op import OpPattern
 
 
 # adaptive_max_pool2d
-@reg.register_schedule("contrib.adaptive_max_pool2d")
-def schedule_adaptive_max_pool2d(_, outs, target):
-    """Schedule definition of adaptive_max_pool2d"""
-    with target:
-        return topi.generic.schedule_adaptive_pool(outs)
-
+reg.register_schedule("contrib.adaptive_max_pool2d", strategy.schedule_adaptive_pool)
 reg.register_pattern("contrib.adaptive_max_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)
 
 
 # adaptive_avg_pool2d
-@reg.register_schedule("contrib.adaptive_avg_pool2d")
-def schedule_adaptive_avg_pool2d(_, outs, target):
-    """Schedule definition of adaptive_avg_pool2d"""
-    with target:
-        return topi.generic.schedule_adaptive_pool(outs)
-
+reg.register_schedule("contrib.adaptive_avg_pool2d", strategy.schedule_adaptive_pool)
 reg.register_pattern("contrib.adaptive_avg_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)
 
 # relay.contrib.ndarray_size
-reg.register_schedule("contrib.ndarray_size", schedule_injective)
+reg.register_injective_schedule("contrib.ndarray_size")
index 89fde6d..b98b2bc 100644 (file)
@@ -20,13 +20,10 @@ from __future__ import absolute_import
 
 import topi
 from .. import op as reg
-from ..op import schedule_injective
 
 # resize
-reg.register_schedule("image.resize", schedule_injective)
-
 @reg.register_compute("image.resize")
-def compute_resize(attrs, inputs, out_type, target):
+def compute_resize(attrs, inputs, out_type):
     size = attrs.size
     layout = attrs.layout
     method = attrs.method
@@ -34,12 +31,12 @@ def compute_resize(attrs, inputs, out_type, target):
     out_dtype = attrs.out_dtype
     return [topi.image.resize(inputs[0], size, layout, method, coord_trans, out_dtype)]
 
+reg.register_injective_schedule("image.resize")
 
-# crop and resize
-reg.register_schedule("image.crop_and_resize", schedule_injective)
 
+# crop and resize
 @reg.register_compute("image.crop_and_resize")
-def compute_crop_and_resize(attrs, inputs, out_type, target):
+def compute_crop_and_resize(attrs, inputs, out_type):
     crop_size = attrs.crop_size
     layout = attrs.layout
     method = attrs.method
@@ -48,3 +45,5 @@ def compute_crop_and_resize(attrs, inputs, out_type, target):
     return [topi.image.crop_and_resize(inputs[0], inputs[1], inputs[2],
                                        crop_size, layout, method,
                                        extrapolation_value, out_dtype)]
+
+reg.register_injective_schedule("image.crop_and_resize")
index 3fdafd5..97a5fa6 100644 (file)
@@ -21,253 +21,79 @@ from __future__ import absolute_import
 import topi
 from topi.util import get_const_tuple
 from .. import op as reg
-from ..op import OpPattern, schedule_injective
+from .. import strategy
+from ..op import OpPattern
 from .._tensor import elemwise_shape_func
 from ....api import convert
 from ....hybrid import script
 
 # relu
-reg.register_schedule("nn.relu", schedule_injective)
+reg.register_broadcast_schedule("nn.relu")
 reg.register_pattern("nn.relu", OpPattern.ELEMWISE)
 
-# softmax
-@reg.register_schedule("nn.softmax")
-def schedule_softmax(_, outputs, target):
-    """Schedule definition of softmax"""
-    with target:
-        return topi.generic.schedule_softmax(outputs)
-
 
+# softmax
+reg.register_schedule("nn.softmax", strategy.schedule_softmax)
 reg.register_pattern("nn.softmax", OpPattern.OPAQUE)
 
-schedule_broadcast = schedule_injective
-
-
-@reg.register_schedule("nn.log_softmax")
-def schedule_log_softmax(_, outputs, target):
-    """Schedule definition of log_softmax"""
-    with target:
-        return topi.generic.schedule_softmax(outputs)
-
 
+# log_softmax
+reg.register_schedule("nn.log_softmax", strategy.schedule_softmax)
 reg.register_pattern("nn.log_softmax", OpPattern.OPAQUE)
 
 
 # dense
-@reg.register_compute("nn.dense")
-def compute_dense(attrs, inputs, out_type, target):
-    """Compute definition of dense"""
-    out_dtype = attrs.out_dtype
-    out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype
-    return [topi.nn.dense(inputs[0], inputs[1], None, out_dtype)]
-
-
-@reg.register_schedule("nn.dense")
-def schedule_dense(attrs, outputs, target):
-    """Schedule definition of dense"""
-    with target:
-        return topi.generic.schedule_dense(outputs)
-
-
+reg.register_strategy("nn.dense", strategy.dense_strategy)
 reg.register_pattern("nn.dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE)
 
 
+# fifo_buffer
 @reg.register_compute('nn.fifo_buffer')
-def compute_fifo_buffer(attrs, inputs, out_type, target):
+def compute_fifo_buffer(attrs, inputs, out_type):
     return [topi.nn.fifo_buffer(inputs[0], inputs[1], axis=attrs.get_int('axis'))]
 
-
-@reg.register_schedule('nn.fifo_buffer')
-def schedule_fifo_buffer(attrs, outputs, target):
-    with target:
-        return topi.generic.schedule_injective(outputs)
-
-
+reg.register_injective_schedule("nn.fifo_buffer")
 reg.register_pattern("nn.fifo_buffer", OpPattern.OPAQUE)
 
 
 # batch_matmul
-@reg.register_compute("nn.batch_matmul")
-def compute_batch_matmul(attrs, inputs, out_type, target):
-    """Compute definition of batch_matmul"""
-    with target:
-        return [topi.nn.batch_matmul(inputs[0], inputs[1])]
-
-
-@reg.register_schedule("nn.batch_matmul")
-def schedule_batch_matmul(attrs, outputs, target):
-    """Schedule definition of batch_matmul"""
-    with target:
-        return topi.generic.schedule_batch_matmul(outputs)
-
-
+reg.register_strategy("nn.batch_matmul", strategy.batch_matmul_strategy)
 reg.register_pattern("nn.batch_matmul", reg.OpPattern.OUT_ELEMWISE_FUSABLE)
 
+
 # sparse_dense
 @reg.register_compute("nn.sparse_dense")
-def compute_sparse_dense(attrs, inputs, out_type, target):
+def compute_sparse_dense(attrs, inputs, out_type):
     """Compute definition of sparse_dense"""
     return [topi.nn.sparse_dense(inputs[0], inputs[1], inputs[2], inputs[3])]
 
-@reg.register_schedule("nn.sparse_dense")
-def schedule_sparse_dense(attrs, outputs, target):
-    """Schedule definition of batch_matmul"""
-    with target:
-        return topi.generic.schedule_sparse_dense(outputs)
-
+reg.register_schedule("nn.sparse_dense", strategy.schedule_sparse_dense)
 reg.register_pattern("nn.sparse_dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE)
 
+
 # sparse_transpose
 @reg.register_compute("nn.sparse_transpose")
-def compute_sparse_transpose(attrs, inputs, out_type, target):
+def compute_sparse_transpose(attrs, inputs, out_type):
     """Compute definition of sparse_transpose"""
     return topi.nn.sparse_transpose(inputs[0], inputs[1], inputs[2])
 
-@reg.register_schedule("nn.sparse_transpose")
-def schedule_sparse_transpose(attrs, outputs, target):
-    """Schedule definition of batch_matmul"""
-    with target:
-        return topi.generic.schedule_sparse_transpose(outputs)
-
+reg.register_schedule("nn.sparse_transpose", strategy.schedule_sparse_transpose)
 reg.register_pattern("nn.sparse_transpose", reg.OpPattern.OUT_ELEMWISE_FUSABLE)
 
 
-# Conv1D
-@reg.register_compute("nn.conv1d")
-def compute_conv1d(attrs, inputs, out_type, target):
-    """Compute definition of conv1d"""
-    strides = get_const_tuple(attrs.strides)
-    padding = get_const_tuple(attrs.padding)
-    dilation = get_const_tuple(attrs.dilation)
-    layout = attrs.data_layout
-    out_dtype = attrs.out_dtype
-    out_dtype = (inputs[0].dtype if out_dtype in ("same", "")
-                 else out_dtype)
-
-    assert layout in ["NCW", "NWC"]
-    if dilation[0] < 1:
-        raise ValueError("dilation should be a positive value")
-
-    return [topi.nn.conv1d(inputs[0], inputs[1], strides, padding, dilation, layout, out_dtype)]
-
-
-@reg.register_schedule("nn.conv1d")
-def schedule_conv1d(attrs, outs, target):
-    """Schedule definition of conv1d"""
-    layout = attrs.data_layout
-
-    with target:
-        if layout == "NCW":
-            return topi.generic.schedule_conv1d_ncw(outs)
-        elif layout == "NCW":
-            return topi.generic.schedule_conv1d_nwc(outs)
-    raise ValueError("No compatible schedule")
-
-
+# conv1d
+reg.register_strategy("nn.conv1d", strategy.conv1d_strategy)
 reg.register_pattern("nn.conv1d", OpPattern.OUT_ELEMWISE_FUSABLE)
 
 
 # conv2d
-def _find_conv2d_op(op):
-    """Find the op with conv2d in its tag by traversing."""
-    if 'conv2d' in op.tag:
-        return op
-    for tensor in op.input_tensors:
-        op_ = _find_conv2d_op(tensor.op)
-        if op_ is not None:
-            return op_
-    return None
-
-@reg.register_compute("nn.conv2d")
-def compute_conv2d(attrs, inputs, out_type, target):
-    """Compute definition of conv2d"""
-    padding = get_const_tuple(attrs.padding)
-    strides = get_const_tuple(attrs.strides)
-    dilation = get_const_tuple(attrs.dilation)
-    groups = attrs.groups
-    layout = attrs.data_layout
-    kernel_layout = attrs.kernel_layout
-    out_dtype = attrs.out_dtype
-    out_dtype = (inputs[0].dtype if out_dtype in ("same", "")
-                 else out_dtype)
-
-    assert layout in ["NCHW", "NHWC", "NCHW4c", "HWCN"]
-    (dilation_h, dilation_w) = dilation
-    if dilation_h < 1 or dilation_w < 1:
-        raise ValueError("dilation should be positive value")
-
-    def _get_out_depth():
-        weight_shape = get_const_tuple(inputs[1].shape)
-        # NHWC layout
-        if kernel_layout.startswith("HW"):
-            return weight_shape[2] * weight_shape[3]
-        # NCHW layout.
-        # in ARM CPU contrib_spatial_pack schedule, we will prepack weight layout
-        if len(weight_shape) == 4:
-            return weight_shape[0] * weight_shape[1]
-        else:
-            assert len(weight_shape) == 5
-            C, M, _, _, VC = weight_shape
-            return C * VC * M
-
-    if groups == 1:
-        out = topi.nn.conv2d(
-            inputs[0], inputs[1], strides, padding,
-            dilation, layout, out_dtype)
-    elif layout == "NCHW" and _get_out_depth() == groups:
-        out = topi.nn.depthwise_conv2d_nchw(
-            inputs[0], inputs[1], strides, padding, dilation, out_dtype)
-    elif layout == "NHWC" and kernel_layout == "HWOI" and _get_out_depth() == groups:
-        out = topi.nn.depthwise_conv2d_nhwc(
-            inputs[0], inputs[1], strides, padding, dilation, out_dtype)
-    elif layout in ['NCHW', 'NCHW4c']:
-        out = topi.nn.group_conv2d_nchw(inputs[0], inputs[1], strides, padding, dilation, groups,
-                                        out_dtype)
-    else:
-        raise ValueError("not support arbitrary group number for now")
-    return [out]
-
-
-@reg.register_schedule("nn.conv2d")
-def schedule_conv2d(attrs, outs, target):
-    """Schedule definition of conv2d"""
-    groups = attrs.groups
-    layout = attrs.data_layout
-    kernel_layout = attrs.kernel_layout
-
-    with target:
-        if groups == 1 and layout == "NCHW":
-            return topi.generic.schedule_conv2d_nchw(outs)
-        elif groups == 1 and layout == "NCHW4c":
-            return topi.generic.schedule_conv2d_nchw(outs)
-        elif groups == 1 and layout == "NHWC":
-            return topi.generic.schedule_conv2d_nhwc(outs)
-        elif groups == 1 and layout == "HWCN":
-            return topi.generic.schedule_conv2d_hwcn(outs)
-        elif groups != 1:
-            # collect in_channels to distinguish depthwise and group conv2d
-            op = _find_conv2d_op(outs[0].op)
-            assert op is not None
-
-            is_depthwise = 'depthwise' in op.tag
-            if is_depthwise:
-                if layout == "NCHW":
-                    # TODO(leyuan, merrymercy, Huyuwei): fold depthwise topi into conv2d.
-                    return topi.generic.schedule_depthwise_conv2d_nchw(outs)
-                if layout == "NHWC" and kernel_layout == "HWOI":
-                    return topi.generic.schedule_depthwise_conv2d_nhwc(outs)
-            else:
-                if layout in ["NCHW", "NCHW4c"]:
-                    return topi.generic.schedule_group_conv2d_nchw(outs)
-    raise ValueError("No compatible schedule")
-
+reg.register_strategy("nn.conv2d", strategy.conv2d_strategy)
+reg.register_pattern("nn.conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)
 
 @reg.register_alter_op_layout("nn.conv2d")
-def alter_op_layout_conv2d(attrs, inputs, tinfos):
+def alter_op_layout_conv2d(attrs, inputs, tinfos, out_type):
     """Alternate the layout of conv2d"""
-    # pylint: disable=import-outside-toplevel
-    from ... import op
-    return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, op)
+    return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, out_type)
 
 @reg.register_legalize("nn.conv2d")
 def legalize_conv2d(attrs, inputs, types):
@@ -289,7 +115,6 @@ def legalize_conv2d(attrs, inputs, types):
     """
     return topi.nn.conv2d_legalize(attrs, inputs, types)
 
-
 @reg.register_convert_op_layout("nn.conv2d")
 def convert_conv2d(attrs, inputs, tinfos, desired_layout):
     """Convert Layout pass registration for conv2d op.
@@ -330,82 +155,10 @@ def convert_conv2d(attrs, inputs, tinfos, desired_layout):
             return relay.nn.conv2d(data, weight, **new_attrs)
     return None
 
-reg.register_pattern("nn.conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)
-
 
 # conv2d_transpose
-@reg.register_compute("nn.conv2d_transpose")
-def compute_conv2d_transpose(attrs, inputs, out_dtype, target):
-    """Compute definition of conv2d_transpose"""
-    padding = get_const_tuple(attrs.padding)
-    strides = get_const_tuple(attrs.strides)
-    dilation = get_const_tuple(attrs.dilation)
-    groups = attrs.groups
-    layout = attrs.data_layout
-    out_dtype = attrs.out_dtype
-    out_dtype = (inputs[0].dtype if out_dtype in ("same", "")
-                 else out_dtype)
-    assert layout == "NCHW", "only support nchw for now"
-    assert dilation == (1, 1), "not support dilate now"
-    assert groups == 1, "only support groups == 1 for now"
-    out = topi.nn.conv2d_transpose_nchw(
-        inputs[0], inputs[1], strides, padding, out_dtype)
-    output_padding = get_const_tuple(attrs.output_padding)
-    out = topi.nn.pad(out,
-                      [0, 0, 0, 0], [0, 0, output_padding[0], output_padding[1]])
-    return [out]
-
-
-@reg.register_compute("nn.conv3d")
-def compute_conv3d(attrs, inputs, out_type, target):
-    """Compute definition of conv3d"""
-    padding = get_const_tuple(attrs.padding)
-    strides = get_const_tuple(attrs.strides)
-    dilation = get_const_tuple(attrs.dilation)
-    groups = attrs.groups
-    layout = attrs.data_layout
-    out_dtype = attrs.out_dtype
-    out_dtype = (inputs[0].dtype if out_dtype in ("same", "")
-                 else out_dtype)
-
-    assert layout in ["NCDHW", "NDHWC"]
-    (dilation_d, dilation_h, dilation_w) = dilation
-    if dilation_d < 1 or dilation_h < 1 or dilation_w < 1:
-        raise ValueError("dilation should be positive value")
-
-    if groups == 1:
-        out = topi.nn.conv3d(
-            inputs[0], inputs[1], strides, padding,
-            dilation, layout, out_dtype)
-    else:
-        raise ValueError("not support arbitrary group number for now")
-    return [out]
-
-
-@reg.register_schedule("nn.conv3d")
-def schedule_conv3d(attrs, outs, target):
-    """Schedule definition of conv3d"""
-    groups = attrs.groups
-    layout = attrs.data_layout
-
-    with target:
-        if groups == 1 and layout == "NCDHW":
-            return topi.generic.schedule_conv3d_ncdhw(outs)
-        elif groups == 1 and layout == "NDHWC":
-            return topi.generic.schedule_conv3d_ndhwc(outs)
-
-    raise ValueError("No compatible schedule")
-
-
-reg.register_pattern("nn.conv3d", OpPattern.OUT_ELEMWISE_FUSABLE)
-
-
-@reg.register_schedule("nn.conv2d_transpose")
-def schedule_conv2d_transpose(attrs, outs, target):
-    """Schedule definition of conv2d_transpose"""
-    with target:
-        return topi.generic.schedule_conv2d_transpose_nchw(outs)
-
+reg.register_strategy("nn.conv2d_transpose", strategy.conv2d_transpose_strategy)
+reg.register_pattern("nn.conv2d_transpose", OpPattern.OUT_ELEMWISE_FUSABLE)
 
 @reg.register_legalize("nn.conv2d_transpose")
 def legalize_conv2d_transpose(attrs, inputs, types):
@@ -427,202 +180,102 @@ def legalize_conv2d_transpose(attrs, inputs, types):
     """
     return topi.nn.conv2d_transpose_legalize(attrs, inputs, types)
 
-reg.register_pattern("nn.conv2d_transpose", OpPattern.OUT_ELEMWISE_FUSABLE)
-
-# conv1d_transpose
-@reg.register_compute("nn.conv1d_transpose")
-def compute_conv1d_transpose(attrs, inputs, out_dtype, target):
-    """Compute definition of conv1d_transpose"""
-    padding = get_const_tuple(attrs.padding)
-    strides = get_const_tuple(attrs.strides)
-    dilation = get_const_tuple(attrs.dilation)
-    groups = attrs.groups
-    layout = attrs.data_layout
-    out_dtype = attrs.out_dtype
-    out_dtype = (inputs[0].dtype if out_dtype in ("same", "")
-                 else out_dtype)
-    assert layout == "NCW", "conv1d_transpose ncw only supported"
-    assert dilation == (1,), "conv1d_transpose dilation is not supported"
-    assert groups == 1, "conv1d_transpose groups == 1 only supported"
-    out = topi.nn.conv1d_transpose_ncw(
-        inputs[0], inputs[1], strides, padding, out_dtype)
-    output_padding = get_const_tuple(attrs.output_padding)
-    out = topi.nn.pad(out,
-                      [0, 0, 0], [0, 0, output_padding[0]])
-    return [out]
 
+# conv3d
+reg.register_strategy("nn.conv3d", strategy.conv3d_strategy)
+reg.register_pattern("nn.conv3d", OpPattern.OUT_ELEMWISE_FUSABLE)
 
-@reg.register_schedule("nn.conv1d_transpose")
-def schedule_conv1d_transpose(attrs, outs, target):
-    """Schedule definition of conv1d_transpose"""
-    with target:
-        return topi.generic.schedule_conv1d_transpose_ncw(outs)
 
+# conv1d_transpose
+reg.register_strategy("nn.conv1d_transpose", strategy.conv1d_transpose_strategy)
 reg.register_pattern("nn.conv1d_transpose", OpPattern.OUT_ELEMWISE_FUSABLE)
 
+
 # bias_add
-reg.register_schedule("nn.bias_add", schedule_injective)
+reg.register_injective_schedule("nn.bias_add")
 reg.register_pattern("nn.bias_add", OpPattern.BROADCAST)
 
 
 # max_pool1d
-@reg.register_schedule("nn.max_pool1d")
-def schedule_max_pool1d(attrs, outs, target):
-    """Schedule definition of max_pool1d"""
-    layout = attrs.layout
-    with target:
-        return topi.generic.schedule_pool(outs, layout)
-
-
+reg.register_schedule("nn.max_pool1d", strategy.schedule_pool)
 reg.register_pattern("nn.max_pool1d", OpPattern.OUT_ELEMWISE_FUSABLE)
 
 
 # max_pool2d
-@reg.register_schedule("nn.max_pool2d")
-def schedule_max_pool2d(attrs, outs, target):
-    """Schedule definition of max_pool2d"""
-    layout = attrs.layout
-    with target:
-        return topi.generic.schedule_pool(outs, layout)
-
-
+reg.register_schedule("nn.max_pool2d", strategy.schedule_pool)
 reg.register_pattern("nn.max_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)
 
 
 # max_pool3d
-@reg.register_schedule("nn.max_pool3d")
-def schedule_max_pool3d(attrs, outs, target):
-    """Schedule definition of max_pool3d"""
-    layout = attrs.layout
-    with target:
-        return topi.generic.schedule_pool(outs, layout)
-
-
+reg.register_schedule("nn.max_pool3d", strategy.schedule_pool)
 reg.register_pattern("nn.max_pool3d", OpPattern.OUT_ELEMWISE_FUSABLE)
 
 
 # avg_pool1d
-@reg.register_schedule("nn.avg_pool1d")
-def schedule_avg_pool1d(attrs, outs, target):
-    """Schedule definition of avg_pool1d"""
-    layout = attrs.layout
-    with target:
-        return topi.generic.schedule_pool(outs, layout)
-
-
+reg.register_schedule("nn.avg_pool1d", strategy.schedule_pool)
 reg.register_pattern("nn.avg_pool1d", OpPattern.OUT_ELEMWISE_FUSABLE)
 
 
 # avg_pool2d
-@reg.register_schedule("nn.avg_pool2d")
-def schedule_avg_pool2d(attrs, outs, target):
-    """Schedule definition of avg_pool2d"""
-    layout = attrs.layout
-    with target:
-        return topi.generic.schedule_pool(outs, layout)
-
+reg.register_schedule("nn.avg_pool2d", strategy.schedule_pool)
 reg.register_pattern("nn.avg_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)
 
 
 # avg_pool3d
-@reg.register_schedule("nn.avg_pool3d")
-def schedule_avg_pool3d(attrs, outs, target):
-    """Schedule definition of avg_pool3d"""
-    layout = attrs.layout
-    with target:
-        return topi.generic.schedule_pool(outs, layout)
-
-
+reg.register_schedule("nn.avg_pool3d", strategy.schedule_pool)
 reg.register_pattern("nn.avg_pool3d", OpPattern.OUT_ELEMWISE_FUSABLE)
 
 
 # max_pool2d_grad
-@reg.register_schedule("nn.max_pool2d_grad")
-def schedule_max_pool2d_grad(attrs, outs, target):
-    """Schedule definition of max_pool2d_grad"""
-    with target:
-        return topi.generic.schedule_pool_grad(outs)
-
-
+reg.register_schedule("nn.max_pool2d_grad", strategy.schedule_pool_grad)
 reg.register_pattern("nn.max_pool2d_grad", OpPattern.OUT_ELEMWISE_FUSABLE)
 
 
 # avg_pool2d_grad
-@reg.register_schedule("nn.avg_pool2d_grad")
-def schedule_avg_pool2d_grad(attrs, outs, target):
-    """Schedule definition of avg_pool2d_grad"""
-    with target:
-        return topi.generic.schedule_pool_grad(outs)
-
-
+reg.register_schedule("nn.avg_pool2d_grad", strategy.schedule_pool_grad)
 reg.register_pattern("nn.avg_pool2d_grad", OpPattern.OUT_ELEMWISE_FUSABLE)
 
 
 # global_max_pool2d
-@reg.register_schedule("nn.global_max_pool2d")
-def schedule_global_max_pool2d(_, outs, target):
-    """Schedule definition of global_max_pool2d"""
-    with target:
-        return topi.generic.schedule_adaptive_pool(outs)
-
-
+reg.register_schedule("nn.global_max_pool2d", strategy.schedule_adaptive_pool)
 reg.register_pattern("nn.global_max_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)
 
 
 # global_avg_pool2d
-@reg.register_schedule("nn.global_avg_pool2d")
-def schedule_global_avg_pool2d(_, outs, target):
-    """Schedule definition of global_avg_pool2d"""
-    with target:
-        return topi.generic.schedule_adaptive_pool(outs)
-
-
+reg.register_schedule("nn.global_avg_pool2d", strategy.schedule_adaptive_pool)
 reg.register_pattern("nn.global_avg_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)
 
 
 # leaky_relu
-reg.register_schedule("nn.leaky_relu", schedule_broadcast)
+reg.register_broadcast_schedule("nn.leaky_relu")
 reg.register_pattern("nn.leaky_relu", OpPattern.ELEMWISE)
 
+
 # prelu
-reg.register_schedule("nn.prelu", schedule_broadcast)
+reg.register_broadcast_schedule("nn.prelu")
 reg.register_pattern("nn.prelu", OpPattern.BROADCAST)
 
+
 # flatten
-reg.register_schedule("nn.batch_flatten", schedule_broadcast)
+reg.register_broadcast_schedule("nn.batch_flatten")
 reg.register_pattern("nn.batch_flatten", OpPattern.INJECTIVE)
 
 
 # lrn
 @reg.register_compute("nn.lrn")
-def compute_lrn(attrs, inputs, out_dtype, target):
+def compute_lrn(attrs, inputs, out_dtype):
     """Compute definition of lrn"""
     assert len(inputs) == 1
     return [topi.nn.lrn(inputs[0], attrs.size, attrs.axis,
                         attrs.alpha, attrs.beta, attrs.bias)]
 
-
-@reg.register_schedule("nn.lrn")
-def schedule_lrn(attrs, outs, target):
-    """Schedule definition of lrn"""
-    with target:
-        return topi.generic.schedule_lrn(outs)
-
-
+reg.register_schedule("nn.lrn", strategy.schedule_lrn)
 reg.register_pattern("nn.lrn", OpPattern.OPAQUE)
 
 
 # upsampling
-reg.register_schedule("nn.upsampling", reg.schedule_injective)
-
-
-def schedule_upsampling(_, outs, target):
-    """Schedule definition of upsampling"""
-    with target:
-        return topi.generic.schedule_injective(outs)
-
 @reg.register_compute("nn.upsampling")
-def compute_upsampling(attrs, inputs, out_dtype, target):
+def compute_upsampling(attrs, inputs, out_dtype):
     scale_h = attrs.scale_h
     scale_w = attrs.scale_w
     layout = attrs.layout
@@ -630,16 +283,12 @@ def compute_upsampling(attrs, inputs, out_dtype, target):
     align_corners = attrs.align_corners
     return [topi.nn.upsampling(inputs[0], scale_h, scale_w, layout, method, align_corners)]
 
-# upsampling3d
-reg.register_schedule("nn.upsampling3d", reg.schedule_injective)
+reg.register_injective_schedule("nn.upsampling")
 
-def schedule_upsampling3d(_, outs, target):
-    """Schedule definition of upsampling3d"""
-    with target:
-        return topi.generic.schedule_injective(outs)
 
+# upsampling3d
 @reg.register_compute("nn.upsampling3d")
-def compute_upsampling3d(attrs, inputs, out_dtype, target):
+def compute_upsampling3d(attrs, inputs, out_dtype):
     scale_d = attrs.scale_d
     scale_h = attrs.scale_h
     scale_w = attrs.scale_w
@@ -649,297 +298,93 @@ def compute_upsampling3d(attrs, inputs, out_dtype, target):
     return [topi.nn.upsampling3d(inputs[0], scale_d, scale_h, scale_w, layout, method,\
         coordinate_transformation_mode)]
 
+reg.register_injective_schedule("nn.upsampling3d")
+
+
 # pad
-reg.register_schedule("nn.pad", schedule_broadcast)
+reg.register_broadcast_schedule("nn.pad")
 
-# mirror_pad
-reg.register_schedule("nn.mirror_pad", schedule_broadcast)
 
+# mirror_pad
 @reg.register_compute("nn.mirror_pad")
-def compute_mirror_pad(attrs, inputs, out_dtype, target):
+def compute_mirror_pad(attrs, inputs, out_dtype):
     pad_before, pad_after = list(zip(*attrs.pad_width))
     mode = attrs.mode
     out = topi.nn.mirror_pad(inputs[0], pad_before=pad_before, pad_after=pad_after, mode=mode)
     return [out]
 
-# winograd related operators
-@reg.register_compute("nn.contrib_conv2d_winograd_without_weight_transform")
-def compute_contrib_conv2d_winograd_without_weight_transform(attrs, inputs, out_dtype, target):
-    """Compute definition of conv2d_winograd_without_weight_transform"""
-    # pylint: disable=assignment-from-no-return
-    padding = attrs.get_int_tuple("padding")
-    strides = attrs.get_int_tuple("strides")
-    dilation = attrs.get_int_tuple("dilation")
-    groups = attrs.get_int("groups")
-    data_layout = attrs.get_str("data_layout")
-    out_dtype = attrs.get_str("out_dtype")
-    tile_size = attrs.get_int("tile_size")
-    out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype
-    assert dilation == (1, 1), "Do not support dilate now"
-    assert groups == 1, "Do not supoort arbitrary group number"
-
-    out = topi.nn.conv2d_winograd_without_weight_transform(
-        inputs[0], inputs[1], strides, padding, dilation, data_layout,
-        out_dtype, tile_size)
-
-    return [out]
-
-
-@reg.register_schedule("nn.contrib_conv2d_winograd_without_weight_transform")
-def schedule_contrib_conv2d_winograd_without_weight_transform(attrs, outs, target):
-    """Schedule definition of conv2d_winograd_without_weight_transform"""
-    with target:
-        return topi.generic.schedule_conv2d_winograd_without_weight_transform(outs)
+reg.register_broadcast_schedule("nn.mirror_pad")
 
 
+# conv2d_winograd related operators
+reg.register_strategy("nn.contrib_conv2d_winograd_without_weight_transform",
+                      strategy.conv2d_winograd_without_weight_transfrom_strategy)
 reg.register_pattern("nn.contrib_conv2d_winograd_without_weight_transform",
                      OpPattern.OUT_ELEMWISE_FUSABLE)
 
 
 @reg.register_compute("nn.contrib_conv2d_winograd_weight_transform")
-def compute_contrib_conv2d_winograd_weight_transform(attrs, inputs, out_dtype, target):
+def compute_contrib_conv2d_winograd_weight_transform(attrs, inputs, out_dtype):
     """Compute definition of contrib_conv2d_winograd_weight_transform"""
     out = topi.nn.conv2d_winograd_weight_transform(
         inputs[0], attrs.get_int('tile_size'))
     return [out]
 
-
-@reg.register_schedule("nn.contrib_conv2d_winograd_weight_transform")
-def schedule_contrib_conv2d_winograd_weight_transform(attrs, outs, target):
-    """Schedule definition of contrib_conv2d_winograd_weight_transform"""
-    with target:
-        return topi.generic.schedule_conv2d_winograd_weight_transform(outs)
-
-
+reg.register_schedule("nn.contrib_conv2d_winograd_weight_transform",
+                      strategy.schedule_conv2d_winograd_weight_transform)
 reg.register_pattern("nn.contrib_conv2d_winograd_weight_transform",
                      OpPattern.OUT_ELEMWISE_FUSABLE)
 
-
-# winograd nnpack related operators
-@reg.register_compute("nn.contrib_conv2d_winograd_nnpack_without_weight_transform")
-def compute_contrib_conv2d_winograd_nnpack_without_weight_transform(
-        attrs, inputs, out_dtype, target):
-    """Compute definition of conv2d_winograd_nnpack_without_weight_transform"""
-    # pylint: disable=assignment-from-no-return
-    padding = attrs.get_int_tuple("padding")
-    strides = attrs.get_int_tuple("strides")
-    dilation = attrs.get_int_tuple("dilation")
-    groups = attrs.get_int("groups")
-    data_layout = attrs.get_str("data_layout")
-    out_dtype = attrs.get_str("out_dtype")
-    out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype
-    assert dilation == (1, 1), "Do not support dilate now"
-    assert groups == 1, "Do not supoort arbitrary group number"
-
-    # No bias
-    out = topi.nn.conv2d_winograd_nnpack_without_weight_transform(
-        inputs[0], inputs[1], None, strides, padding, dilation, data_layout,
-        out_dtype)
-
-    return [out]
-
-
-@reg.register_schedule("nn.contrib_conv2d_winograd_nnpack_without_weight_transform")
-def schedule_contrib_conv2d_winograd_nnpack_without_weight_transform(attrs, outs, target):
-    """Schedule definition of conv2d_winograd_nnpack_without_weight_transform"""
-    with target:
-        return topi.generic.schedule_conv2d_winograd_nnpack_without_weight_transform(outs)
-
-
-reg.register_pattern("nn.contrib_conv2d_winograd_nnpack_without_weight_transform",
-                     OpPattern.OPAQUE)
-
-
 @reg.register_compute("nn.contrib_conv2d_winograd_nnpack_weight_transform")
-def compute_contrib_conv2d_winograd_nnpack_weight_transform(attrs, inputs, out_dtype, target):
+def compute_contrib_conv2d_winograd_nnpack_weight_transform(attrs, inputs, out_dtype):
     """Compute definition of contrib_conv2d_winograd_nnpack_weight_transform"""
     convolution_algorithm = attrs.get_int('convolution_algorithm')
     out = topi.nn.conv2d_winograd_nnpack_weight_transform(
         inputs[0], convolution_algorithm, out_dtype)
     return [out]
 
-
-@reg.register_schedule("nn.contrib_conv2d_winograd_nnpack_weight_transform")
-def schedule_contrib_conv2d_winograd_nnpack_weight_transform(attrs, outs, target):
-    """Schedule definition of contrib_conv2d_winograd_nnpack_weight_transform"""
-    with target:
-        return topi.generic.schedule_conv2d_winograd_nnpack_weight_transform(outs)
-
-
+reg.register_schedule("nn.contrib_conv2d_winograd_nnpack_weight_transform",
+                      strategy.schedule_conv2d_winograd_nnpack_weight_transform)
 reg.register_pattern("nn.contrib_conv2d_winograd_nnpack_weight_transform",
                      OpPattern.OPAQUE)
 
 
-@reg.register_compute("nn.contrib_conv2d_NCHWc")
-def compute_contrib_conv2d_NCHWc(attrs, inputs, out_dtype, target):
-    """Compute definition of conv2d NCHWc"""
-    # pylint: disable=assignment-from-no-return
-    padding = attrs.get_int_tuple("padding")
-    strides = attrs.get_int_tuple("strides")
-    dilation = attrs.get_int_tuple("dilation")
-    data_layout = attrs.get_str("data_layout")
-    out_layout = attrs.get_str("out_layout")
-    out_dtype = attrs.get_str("out_dtype")
-    out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype
-
-    out = topi.nn.conv2d_NCHWc(inputs[0], inputs[1], strides, padding, dilation,
-                               data_layout, out_layout, out_dtype)
-    return [out]
-
-
-@reg.register_schedule("nn.contrib_conv2d_NCHWc")
-def schedule_contrib_conv2d_NCHWc(attrs, outs, target):
-    """Schedule definition of contrib_conv2d_NCHWc"""
-    with target:
-        return topi.generic.schedule_conv2d_NCHWc(outs)
-
-
+# conv2d_NCHWc
+reg.register_strategy("nn.contrib_conv2d_NCHWc", strategy.conv2d_NCHWc_strategy)
 reg.register_pattern("nn.contrib_conv2d_NCHWc",
                      OpPattern.OUT_ELEMWISE_FUSABLE)
 
-
-@reg.register_compute("nn.contrib_conv2d_NCHWc_int8")
-def compute_contrib_conv2d_NCHWc_int8(attrs, inputs, out_dtype, target):
-    """Compute definition of conv2d NCHWc"""
-    # pylint: disable=assignment-from-no-return
-    padding = attrs.get_int_tuple("padding")
-    strides = attrs.get_int_tuple("strides")
-    dilation = attrs.get_int_tuple("dilation")
-    data_layout = attrs.get_str("data_layout")
-    out_layout = attrs.get_str("out_layout")
-    out_dtype = attrs.get_str("out_dtype")
-    out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype
-
-    out = topi.nn.conv2d_NCHWc_int8(inputs[0], inputs[1], strides, padding, dilation,
-                                    data_layout, out_layout, out_dtype)
-    return [out]
-
-
-@reg.register_schedule("nn.contrib_conv2d_NCHWc_int8")
-def schedule_contrib_conv2d_NCHWc_int8(attrs, outs, target):
-    """Schedule definition of contrib_conv2d_NCHWc_int8"""
-    with target:
-        return topi.generic.schedule_conv2d_NCHWc_int8(outs)
-
-
-reg.register_pattern("nn.contrib_conv2d_NCHWc_int8",
-                     OpPattern.OUT_ELEMWISE_FUSABLE)
-
-
-@reg.register_compute("nn.contrib_depthwise_conv2d_NCHWc")
-def compute_contrib_depthwise_conv2d_NCHWc(attrs, inputs, out_dtype, target):
-    """Compute definition of depthwise conv2d NCHWc"""
-    # pylint: disable=assignment-from-no-return
-    padding = attrs.get_int_tuple("padding")
-    strides = attrs.get_int_tuple("strides")
-    dilation = attrs.get_int_tuple("dilation")
-    data_layout = attrs.get_str("data_layout")
-    out_layout = attrs.get_str("out_layout")
-    out_dtype = attrs.get_str("out_dtype")
-    out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype
-
-    out = topi.nn.depthwise_conv2d_NCHWc(inputs[0], inputs[1], strides, padding, dilation,
-                                         data_layout, out_layout, out_dtype)
-    return [out]
-
-
-@reg.register_schedule("nn.contrib_depthwise_conv2d_NCHWc")
-def schedule_contrib_depthwise_conv2d_NCHWc(attrs, outs, target):
-    """Schedule definition of contrib_conv2d_NCHWc"""
-    with target:
-        return topi.generic.schedule_depthwise_conv2d_NCHWc(outs)
-
-
+# depthwise_conv2d_NCHWc
+reg.register_strategy("nn.contrib_depthwise_conv2d_NCHWc",
+                      strategy.depthwise_conv2d_NCHWc_strategy)
 reg.register_pattern("nn.contrib_depthwise_conv2d_NCHWc",
                      OpPattern.OUT_ELEMWISE_FUSABLE)
 
 
-@reg.register_compute("nn.deformable_conv2d")
-def compute_deformable_conv2d(attrs, inputs, out_dtype, target):
-    """Compute definition of deformable_conv2d"""
-    padding = get_const_tuple(attrs.padding)
-    strides = get_const_tuple(attrs.strides)
-    dilation = get_const_tuple(attrs.dilation)
-    deformable_groups = attrs.deformable_groups
-    groups = attrs.groups
-    out_dtype = attrs.out_dtype
-    out_dtype = inputs[0].dtype if out_dtype in ("same", "") else out_dtype
-    with target:
-        out = topi.nn.deformable_conv2d_nchw(inputs[0], inputs[1], inputs[2], strides, padding,
-                                             dilation, deformable_groups, groups, out_dtype)
-    return [out]
-
-
-@reg.register_schedule("nn.deformable_conv2d")
-def schedule_deformable_conv2d(attrs, outs, target):
-    """Schedule definition of deformable_conv2d"""
-    with target:
-        return topi.generic.schedule_deformable_conv2d_nchw(outs)
-
-
+# deformable_conv2d
+reg.register_strategy("nn.deformable_conv2d", strategy.deformable_conv2d_strategy)
 reg.register_pattern("nn.deformable_conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)
 
 
+# bitpack
 @reg.register_compute("nn.bitpack")
-def compute_bitpack(attrs, inputs, out_dtype, target):
+def compute_bitpack(attrs, inputs, out_dtype):
     """Compute definition for bitpack"""
     bits = attrs.bits
     pack_axis = attrs.pack_axis
     bit_axis = attrs.bit_axis
     pack_type = attrs.pack_type
     name = attrs.name
-    with target:
-        out = topi.nn.bitpack(inputs[0], bits, pack_axis, bit_axis, pack_type,
-                              name)
+    out = topi.nn.bitpack(inputs[0], bits, pack_axis, bit_axis, pack_type, name)
     return [out]
 
-@reg.register_schedule("nn.bitpack")
-def schedule_bitpack(attrs, outs, target):
-    with target:
-        return topi.generic.schedule_bitpack(outs)
-
+reg.register_schedule("nn.bitpack", strategy.schedule_bitpack)
 reg.register_pattern("nn.bitpack", OpPattern.INJECTIVE)
 
 
-@reg.register_compute("nn.bitserial_conv2d")
-def compute_bitserial_conv2d(attrs, inputs, out_dtype, target):
-    """Compute definition for bitserial conv2d."""
-    padding = get_const_tuple(attrs.padding)
-    strides = get_const_tuple(attrs.strides)
-    activation_bits = attrs.activation_bits
-    weight_bits = attrs.weight_bits
-    layout = attrs.data_layout
-    pack_dtype = attrs.pack_dtype
-    out_dtype = attrs.out_dtype
-    unipolar = attrs.unipolar
-    if layout == 'NCHW':
-        with target:
-            out = topi.nn.bitserial_conv2d_nchw(
-                inputs[0], inputs[1], strides, padding, activation_bits,
-                weight_bits, pack_dtype, out_dtype, unipolar)
-    elif layout == 'NHWC':
-        with target:
-            out = topi.nn.bitserial_conv2d_nhwc(
-                inputs[0], inputs[1], strides, padding, activation_bits,
-                weight_bits, pack_dtype, out_dtype, unipolar)
-    else:
-        raise ValueError("Data layout not supported.")
-
-    return [out]
-
-
-@reg.register_schedule("nn.bitserial_conv2d")
-def schedule_bitserial_conv2d(attrs, outs, target):
-    """Schedule definition for bitserial conv2d."""
-    layout = attrs.data_layout
-    if layout == 'NCHW':
-        with target:
-            return topi.generic.schedule_bitserial_conv2d_nchw(outs)
-    elif layout == 'NHWC':
-        with target:
-            return topi.generic.schedule_bitserial_conv2d_nhwc(outs)
-    else:
-        raise ValueError("Data layout not supported.")
+# bitserial_conv2d
+reg.register_strategy("nn.bitserial_conv2d", strategy.bitserial_conv2d_strategy)
+reg.register_pattern("nn.bitserial_conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)
 
 @reg.register_legalize("nn.bitserial_conv2d")
 def legalize_bitserial_conv2d(attrs, inputs, types):
@@ -962,79 +407,58 @@ def legalize_bitserial_conv2d(attrs, inputs, types):
     return topi.nn.bitserial_conv2d_legalize(attrs, inputs, types)
 
 
-reg.register_pattern("nn.bitserial_conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)
-
-
 # bitserial_dense
-@reg.register_compute("nn.bitserial_dense")
-def compute_bitserial_dense(attrs, inputs, out_type, target):
-    """Compute definition of bitserial_dense"""
-    data_bits = attrs.data_bits
-    weight_bits = attrs.weight_bits
-    pack_dtype = attrs.pack_dtype
-    out_dtype = attrs.out_dtype
-    out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype
-    unipolar = attrs.unipolar
-    return [
-        topi.nn.bitserial_dense(
-            inputs[0],
-            inputs[1],
-            data_bits,
-            weight_bits,
-            pack_dtype,
-            out_dtype,
-            unipolar)
-    ]
-
-
-@reg.register_schedule("nn.bitserial_dense")
-def schedule_bitserial_dense(attrs, outputs, target):
-    """Schedule definition of bitserial_dense"""
-    with target:
-        return topi.generic.schedule_bitserial_dense(outputs)
-
-
+reg.register_strategy("nn.bitserial_dense", strategy.bitserial_dense_strategy)
 reg.register_pattern("nn.bitserial_dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE)
 
 
-reg.register_pattern("nn.cross_entropy", OpPattern.OPAQUE)
-
+# cross_entropy
 @reg.register_compute("nn.cross_entropy")
-def compute_cross_entropy(attrs, inputs, out_dtype, target):
+def compute_cross_entropy(attrs, inputs, out_dtype):
     x, y = inputs
     return [-topi.sum(topi.log(x) * y) / x.shape[0]]
 
+reg.register_reduce_schedule("nn.cross_entropy")
+reg.register_pattern("nn.cross_entropy", OpPattern.OPAQUE)
 
-reg.register_pattern("nn.cross_entropy_with_logits", OpPattern.OPAQUE)
 
+# cross_entropy_with_logits
 @reg.register_compute("nn.cross_entropy_with_logits")
-def compute_cross_entropy_with_logits(attrs, inputs, out_dtype, target):
+def compute_cross_entropy_with_logits(attrs, inputs, out_dtype):
     x, y = inputs
     return [-topi.sum(x * y) / x.shape[0]]
 
+reg.register_reduce_schedule("nn.cross_entropy_with_logits")
+reg.register_pattern("nn.cross_entropy_with_logits", OpPattern.OPAQUE)
 
+
+# depth_to_space
 @reg.register_compute("nn.depth_to_space")
-def compute_depth_to_space(attrs, inputs, out_dtype, target):
+def compute_depth_to_space(attrs, inputs, out_dtype):
     block_size = attrs.block_size
     layout = attrs.layout
     mode = attrs.mode
     return [topi.nn.depth_to_space(inputs[0], block_size, layout=layout, mode=mode)]
 
-reg.register_schedule("nn.depth_to_space", schedule_injective)
+reg.register_injective_schedule("nn.depth_to_space")
 reg.register_pattern("nn.depth_to_space", OpPattern.INJECTIVE)
 
 
+# space_to_depth
 @reg.register_compute("nn.space_to_depth")
-def compute_space_to_depth(attrs, inputs, out_dtype, target):
+def compute_space_to_depth(attrs, inputs, out_dtype):
     block_size = attrs.block_size
     layout = attrs.layout
     return [topi.nn.space_to_depth(inputs[0], block_size, layout=layout)]
 
-reg.register_schedule("nn.space_to_depth", schedule_injective)
+reg.register_injective_schedule("nn.space_to_depth")
 reg.register_pattern("nn.space_to_depth", OpPattern.INJECTIVE)
 
 
-# shape func
+#####################
+#  Shape functions  #
+#####################
+
 @script
 def _conv2d_NCHWc_shape_func(dshape, kshape, strides, padding, dilation, oc_bn):
     out = output_tensor((dshape.shape[0],), "int64")
index 9ee4343..9ecb5af 100644 (file)
@@ -204,7 +204,6 @@ def conv2d(data,
     # TODO enforce 4-way padding in topi/nn/conv2d after #4644 merged
     # convert 2-way padding to 4-way padding
     padding = get_pad_tuple2d(padding)
-
     return _make.conv2d(data, weight, strides, padding, dilation,
                         groups, channels, kernel_size, data_layout,
                         kernel_layout, out_layout, out_dtype)
@@ -298,7 +297,6 @@ def conv3d(data,
         dilation = (dilation, dilation, dilation)
     if isinstance(padding, int):
         padding = (padding, padding, padding)
-
     return _make.conv3d(data, weight, strides, padding, dilation,
                         groups, channels, kernel_size, data_layout,
                         kernel_layout, out_layout, out_dtype)
@@ -1772,74 +1770,6 @@ def contrib_conv2d_winograd_without_weight_transform(data,
         kernel_layout, out_layout, out_dtype)
 
 
-def contrib_conv2d_winograd_nnpack_without_weight_transform(data,
-                                                            weight,
-                                                            strides=(1, 1),
-                                                            padding=(0, 0),
-                                                            dilation=(1, 1),
-                                                            groups=1,
-                                                            channels=None,
-                                                            kernel_size=None,
-                                                            data_layout="NCHW",
-                                                            kernel_layout="OIHW",
-                                                            out_layout="",
-                                                            out_dtype=""):
-    r"""2D convolution with the NNPACK implementation of winograd algorithm.
-
-    The basic parameters are the same as the ones in vanilla conv2d.
-    It assumes the weight is pre-transformed by nn.contrib_conv2d_winograd_nnpack_weight_transform
-
-    Parameters
-    ----------
-    data : tvm.relay.Expr
-        The input data to the operator.
-
-    weight : tvm.relay.Expr
-        The weight expressions.
-
-    strides : tuple of int, optional
-        The strides of convolution.
-
-    padding : tuple of int, optional
-        The padding of convolution on both sides of inputs before convolution.
-
-    dilation : tuple of int, optional
-        Specifies the dilation rate to be used for dilated convolution.
-
-    groups : int, optional
-        Number of groups for grouped convolution.
-
-    channels : int, optional
-        Number of output channels of this convolution.
-
-    kernel_size : tuple of int, optional
-        The spatial of the convolution kernel.
-
-    data_layout : str, optional
-        Layout of the input.
-
-    kernel_layout : str, optional
-        Layout of the weight.
-
-    out_layout : str, optional
-        Layout of the output, by default, out_layout is the same as data_layout
-
-    out_dtype : str, optional
-        Specifies the output data type for mixed precision conv2d.
-
-    Returns
-    -------
-    result : tvm.relay.Expr
-        The computed result.
-    """
-    # convert 2-way padding to 4-way padding
-    padding = get_pad_tuple2d(padding)
-    return _make.contrib_conv2d_winograd_nnpack_without_weight_transform(
-        data, weight, strides, padding, dilation,
-        groups, channels, kernel_size, data_layout,
-        kernel_layout, out_layout, out_dtype)
-
-
 def contrib_conv2d_nchwc(data,
                          kernel,
                          strides=(1, 1),
@@ -1974,73 +1904,6 @@ def contrib_depthwise_conv2d_nchwc(data,
                                                 groups, channels, kernel_size, data_layout,
                                                 kernel_layout, out_layout, out_dtype)
 
-def contrib_conv2d_nchwc_int8(data,
-                              kernel,
-                              strides=(1, 1),
-                              padding=(0, 0),
-                              dilation=(1, 1),
-                              groups=1,
-                              channels=None,
-                              kernel_size=None,
-                              data_layout="NCHW8c",
-                              kernel_layout="OIHW",
-                              out_layout="",
-                              out_dtype=""):
-    r"""Variant of 2D convolution. It deals with only int8 inputs.
-
-    This operator takes the weight as the convolution kernel
-    and convolves it with data to produce an output, following a specialized
-    NCHWc data layout.
-
-    Parameters
-    ----------
-    data : tvm.relay.Expr
-        The input data to the operator.
-
-    kernel : tvm.relay.Expr
-        The kernel expressions.
-
-    strides : tuple of int, optional
-        The strides of convolution.
-
-    padding : tuple of int, optional
-        The padding of convolution on both sides of inputs before convolution.
-
-    dilation : tuple of int, optional
-        Specifies the dilation rate to be used for dilated convolution.
-
-    groups : int, optional
-        Number of groups for grouped convolution.
-
-    channels : int, optional
-        Number of output channels of this convolution.
-
-    kernel_size : tuple of int, optional
-        The spatial of the convolution kernel.
-
-    data_layout : str, optional
-        Layout of the input.
-
-    kernel_layout : str, optional
-        Layout of the weight.
-
-    out_layout : str, optional
-        Layout of the output, by default, out_layout is the same as data_layout
-
-    out_dtype : str, optional
-        Specifies the output data type for mixed precision conv2d.
-
-    Returns
-    -------
-    result : tvm.relay.Expr
-        The computed result.
-    """
-    # convert 2-way padding to 4-way padding
-    padding = get_pad_tuple2d(padding)
-    return _make.contrib_conv2d_NCHWc_int8(data, kernel, strides, padding, dilation,
-                                           groups, channels, kernel_size, data_layout,
-                                           kernel_layout, out_layout, out_dtype)
-
 
 def contrib_conv2d_winograd_weight_transform(weight,
                                              tile_size):
index c6d3012..4fd88f4 100644 (file)
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-#pylint: disable=unused-argument
+#pylint: disable=unused-argument,invalid-name
 """The base node types for the Relay language."""
-import topi
 import tvm._ffi
 from tvm.driver import lower, build
 
 from ..base import register_relay_node
 from ..expr import RelayExpr
 from ...api import register_func
+from ...target import get_native_generic_func, GenericFunc
+from ...runtime import Object
 from . import _make
 
 @register_relay_node
@@ -143,39 +144,208 @@ class OpPattern(object):
     OPAQUE = 8
 
 
-def register_schedule(op_name, schedule=None, level=10):
-    """Register schedule function for an op
+@tvm._ffi.register_object("relay.OpImplementation")
+class OpImplementation(Object):
+    """Operator implementation"""
+    def compute(self, attrs, inputs, out_type):
+        """Call compute function.
+
+        Parameters
+        ----------
+        attrs : Attrs
+            Op attributes.
+
+        inputs : list[tvm.tensor.Tensor]
+            The input tensors.
+
+        out_type : relay.Type
+            The output type.
+
+        Returns
+        -------
+        outs : list[tvm.tensor.Tensor]
+            The output tensors.
+        """
+        return _OpImplementationCompute(self, attrs, inputs, out_type)
+
+    def schedule(self, attrs, outs, target):
+        """Call schedule function.
+
+        Parameters
+        ----------
+        attrs : Attrs
+            Op attributes.
+
+        outs : list[tvm.tensor.Tensor]
+            The output tensors.
+
+        target : tvm.target.Target
+            The target to schedule the op.
+
+        Returns
+        -------
+        schedule : tvm.Schedule
+            The schedule.
+        """
+        return _OpImplementationSchedule(self, attrs, outs, target)
+
+
+@tvm._ffi.register_object("relay.OpSpecialization")
+class OpSpecialization(Object):
+    """Operator specialization"""
+
+
+@tvm._ffi.register_object("relay.OpStrategy")
+class OpStrategy(Object):
+    """Operator strategy"""
+    def __init__(self):
+        self.__init_handle_by_constructor__(_make.OpStrategy)
+
+    def add_implementation(self, compute, schedule, name="default", plevel=10):
+        """Add an implementation to the strategy
+
+        Parameters
+        ----------
+        compute : function (attrs: Attrs, inputs: List[Tensor], out_type: Type)
+                           -> List[Tensor]
+            The compute function.
+
+        schedule : function (attrs: Attrs, outs: List[Tensor], target:Target) -> Schedule
+            The schedule function.
+
+        name : str
+            The name of implementation.
+
+        plevel : int
+            The priority level of implementation.
+        """
+        _OpStrategyAddImplementation(self, compute, schedule, name, plevel)
+
+
+def _wrap_default_fstrategy(compute, schedule, name):
+    def _fstrategy(attrs, inputs, out_type, target):
+        strategy = OpStrategy()
+        strategy.add_implementation(compute, schedule, name=name)
+        return strategy
+    return _fstrategy
+
+
+def _create_fstrategy_from_schedule(op_name, schedule):
+    assert hasattr(schedule, "dispatch_dict")
+    compute = get(op_name).get_attr("FTVMCompute")
+    assert compute is not None, "FTVMCompute is not registered for op %s" % op_name
+    fstrategy = get_native_generic_func("{}_strategy".format(op_name))
+    name_pfx = schedule.__name__
+    name_pfx = name_pfx[name_pfx.index('_')+1:]
+    fstrategy.set_default(
+        _wrap_default_fstrategy(compute, schedule.fdefault, "%s.generic" % name_pfx))
+    for key, sch in schedule.dispatch_dict.items():
+        fstrategy.register(
+            _wrap_default_fstrategy(compute, sch, "%s.%s" % (name_pfx, key)), [key])
+    return fstrategy
+
+
+def register_compute(op_name, compute=None, level=10):
+    """Register compute function for an op.
 
     Parameters
     ----------
     op_name : str
         The name of the op.
 
-    schedule : function (attrs: Attrs, outs: List[Tensor], target: Target) -> sch: Schedule
-        The schedule function.
+    compute : function (attrs: Attrs, inputs: List[Tensor], out_type: Type)
+                       -> List[Tensor]
+        The compute function.
 
     level : int
         The priority level
     """
-    return register(op_name, "FTVMSchedule", schedule, level)
+    return register(op_name, "FTVMCompute", compute, level)
 
 
-def register_compute(op_name, compute=None, level=10):
-    """Register compute function for an op.
+def register_strategy(op_name, fstrategy=None, level=10):
+    """Register strategy function for an op.
 
     Parameters
     ----------
     op_name : str
         The name of the op.
 
-    compute : function (attrs: Attrs, inputs: List[Tensor], out_type: Type, target:Target)
-                       -> List[Tensor]
-        The compute function.
+    fstrategy : function (attrs: Attrs, inputs: List[Tensor], out_type: Type,
+                          target:Target) -> OpStrategy
+        The strategy function. Need to be native GenericFunc.
 
     level : int
         The priority level
     """
-    return register(op_name, "FTVMCompute", compute, level)
+    if not isinstance(fstrategy, GenericFunc):
+        assert hasattr(fstrategy, "generic_func_node")
+        fstrategy = fstrategy.generic_func_node
+    return register(op_name, "FTVMStrategy", fstrategy, level)
+
+
+def register_schedule(op_name, schedule, level=10):
+    """Register schedule function for an op.
+
+    This is used when compute function is the same for all targets and only
+    schedule is different. It requires FTVMCompute is already registered to
+    the op.
+
+    Parameters
+    ----------
+    op_name : str
+        The name of the op.
+
+    schedule : function (attrs: Attrs, outs: List[Tensor], target:Target) -> Schedule
+        The schedule function. Need to be target.generic_func.
+
+    level : int
+        The priority level
+    """
+    fstrategy = _create_fstrategy_from_schedule(op_name, schedule)
+    return register_strategy(op_name, fstrategy, level)
+
+
+def register_injective_schedule(op_name, level=10):
+    """Register injective schedule function for an op.
+
+    Parameters
+    ----------
+    op_name : str
+        The name of the op.
+
+    level : int
+        The priority level
+    """
+    return register_schedule(op_name, _schedule_injective, level)
+
+
+def register_broadcast_schedule(op_name, level=10):
+    """Register broadcast schedule function for an op.
+
+    Parameters
+    ----------
+    op_name : str
+        The name of the op.
+
+    level : int
+        The priority level
+    """
+    return register_schedule(op_name, _schedule_injective, level)
+
+
+def register_reduce_schedule(op_name, level=10):
+    """Register reduce schedule function for an op.
+
+    Parameters
+    ----------
+    op_name : str
+        The name of the op.
+
+    level : int
+        The priority level
+    """
+    return register_schedule(op_name, _schedule_reduce, level)
 
 
 def register_alter_op_layout(op_name, alter_layout=None, level=10):
@@ -245,6 +415,7 @@ def register_pattern(op_name, pattern, level=10):
     """
     return register(op_name, "TOpPattern", pattern, level)
 
+
 def register_gradient(op_name, fgradient=None, level=10):
     """Register operator pattern for an op.
 
@@ -261,6 +432,7 @@ def register_gradient(op_name, fgradient=None, level=10):
     """
     return register(op_name, "FPrimalGradient", fgradient, level)
 
+
 def register_shape_func(op_name, data_dependant, shape_func=None, level=10):
     """Register operator shape function for an op.
 
@@ -290,18 +462,8 @@ def _lower(name, schedule, inputs, outputs):
 def _build(lowered_funcs):
     return build(lowered_funcs, target="llvm")
 
-
-def schedule_injective(attrs, outputs, target):
-    """Generic schedule for binary broadcast."""
-    with target:
-        return topi.generic.schedule_injective(outputs)
-
-
-def schedule_concatenate(attrs, outputs, target):
-    """Generic schedule for concatinate."""
-    with target:
-        return topi.generic.schedule_concatenate(outputs)
-
+_schedule_injective = None
+_schedule_reduce = None
 
 __DEBUG_COUNTER__ = 0
 
diff --git a/python/tvm/relay/op/strategy/__init__.py b/python/tvm/relay/op/strategy/__init__.py
new file mode 100644 (file)
index 0000000..59adf82
--- /dev/null
@@ -0,0 +1,31 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# pylint: disable=wildcard-import
+"""Relay op strategies."""
+from __future__ import absolute_import as _abs
+
+from .generic import *
+from . import x86
+from . import arm_cpu
+from . import cuda
+from . import hls
+from . import mali
+from . import bifrost
+from . import opengl
+from . import rocm
+from . import intel_graphics
diff --git a/python/tvm/relay/op/strategy/arm_cpu.py b/python/tvm/relay/op/strategy/arm_cpu.py
new file mode 100644 (file)
index 0000000..0945f51
--- /dev/null
@@ -0,0 +1,231 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Definition of ARM CPU operator strategy."""
+# pylint: disable=invalid-name,unused-argument,wildcard-import,unused-wildcard-import
+import re
+import logging
+
+import topi
+from .generic import *
+from .. import op as _op
+
+logger = logging.getLogger('strategy')
+
+@schedule_injective.register("arm_cpu")
+def schedule_injective_arm_cpu(_, outs, target):
+    """schedule injective ops for arm cpu"""
+    with target:
+        return topi.arm_cpu.schedule_injective(outs)
+
+@schedule_concatenate.register("arm_cpu")
+def schedule_concatenate_arm_cpu(_, outs, target):
+    """schedule concatenate for arm cpu"""
+    with target:
+        return topi.arm_cpu.schedule_concatenate(outs)
+
+@conv2d_strategy.register("arm_cpu")
+def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
+    """conv2d arm cpu strategy"""
+    strategy = _op.OpStrategy()
+    data, kernel = inputs
+    dilation_h, dilation_w = attrs.get_int_tuple("dilation")
+    stride_h, stride_w = attrs.get_int_tuple("strides")
+    padding = attrs.get_int_tuple("padding")
+    groups = attrs.groups
+    layout = attrs.data_layout
+    kernel_layout = attrs.kernel_layout
+    if dilation_h < 1 or dilation_w < 1:
+        raise ValueError("dilation should be positive value")
+
+    if groups == 1:
+        if layout == "NCHW":
+            if kernel_layout == "OIHW":
+                strategy.add_implementation(
+                    wrap_compute_conv2d(topi.arm_cpu.conv2d_nchw_spatial_pack),
+                    wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nchw_spatial_pack),
+                    name="conv2d_nchw_spatial_pack.arm_cpu")
+                # check if winograd algorithm is applicable
+                _, _, kh, kw = get_const_tuple(kernel.shape)
+                pt, pl, pb, pr = topi.nn.get_pad_tuple(padding, (kh, kw))
+                if kh == 3 and kw == 3 and stride_h == 1 and stride_w == 1 and \
+                    dilation_h == 1 and dilation_w == 1:
+                    strategy.add_implementation(
+                        wrap_compute_conv2d(topi.arm_cpu.conv2d_nchw_winograd),
+                        wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nchw_winograd),
+                        name="conv2d_nchw_winograd.arm_cpu",
+                        plevel=15)
+                    if "nnpack" in target.libs and pt == 1 and pb == 1 and pl == 1 and pr == 1:
+                        strategy.add_implementation(
+                            wrap_compute_conv2d(topi.arm_cpu.conv2d_nchw_winograd_nnpack),
+                            wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nchw_winograd_nnpack),
+                            name="conv2d_nchw_winograd_nnpack.arm_cpu",
+                            plevel=13)
+            elif re.match(r"OIHW\d*o", kernel_layout):
+                strategy.add_implementation(
+                    wrap_compute_conv2d(topi.arm_cpu.conv2d_nchw_spatial_pack),
+                    wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nchw_spatial_pack),
+                    name="conv2d_nchw_spatial_pack.arm_cpu")
+            else:
+                raise RuntimeError("Unsupported weight layout {} for conv2d NCHW".
+                                   format(kernel_layout))
+        elif layout == "HWCN":
+            assert kernel_layout == "HWIO"
+            logger.warning("conv2d_hwcn is not optimized for arm cpu.")
+            strategy.add_implementation(
+                wrap_compute_conv2d(topi.nn.conv2d_hwcn),
+                wrap_topi_schedule(topi.generic.schedule_conv2d_hwcn),
+                name="conv2d_hwcn.generic")
+        elif layout == "NHWC":
+            assert kernel_layout == "HWIO"
+            strategy.add_implementation(
+                wrap_compute_conv2d(topi.arm_cpu.conv2d_nhwc_spatial_pack),
+                wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nhwc_spatial_pack),
+                name="conv2d_nhwc_spatial_pack.arm_cpu")
+        else:
+            raise RuntimeError("Unsupported conv2d layout {} for arm cpu".format(layout))
+    elif is_depthwise_conv2d(data.shape, layout, kernel.shape, kernel_layout, groups):
+        if layout == "NCHW":
+            assert kernel_layout == "OIHW" or re.match(r"OIHW\d*o", kernel_layout)
+            if kernel_layout == "OIHW":
+                strategy.add_implementation(
+                    wrap_compute_conv2d(topi.arm_cpu.depthwise_conv2d_nchw),
+                    wrap_topi_schedule(topi.arm_cpu.schedule_depthwise_conv2d_nchw),
+                    name="depthwise_conv2d_nchw.arm_cpu")
+            strategy.add_implementation(
+                wrap_compute_conv2d(topi.arm_cpu.depthwise_conv2d_nchw_spatial_pack),
+                wrap_topi_schedule(topi.arm_cpu.schedule_depthwise_conv2d_nchw_spatial_pack),
+                name="depthwise_conv2d_nchw_spatial_pack.arm_cpu",
+                plevel=15)
+        elif layout == "NHWC":
+            assert kernel_layout == "HWOI"
+            logger.warning("depthwise_conv2d with layout NHWC is not optimized for arm cpu.")
+            strategy.add_implementation(
+                wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc),
+                wrap_topi_schedule(topi.generic.schedule_depthwise_conv2d_nhwc),
+                name="depthwise_conv2d_nhwc.generic")
+        else:
+            raise RuntimeError("Unsupported depthwise_conv2d layout {} for arm cpu".
+                               format(layout))
+    else: # group_conv2d
+        if layout == 'NCHW':
+            assert kernel_layout == "OIHW"
+            logger.warning("group_conv2d with layout NCHW is not optimized for arm cpu.")
+            strategy.add_implementation(
+                wrap_compute_conv2d(topi.nn.group_conv2d_nchw, has_groups=True),
+                wrap_topi_schedule(topi.generic.schedule_group_conv2d_nchw),
+                name="group_conv2d_nchw.generic")
+        else:
+            raise RuntimeError("Unsupported group_conv2d layout {} for arm cpu".
+                               format(layout))
+    return strategy
+
+def wrap_compute_conv2d_winograd_nnpack(topi_compute):
+    """wrap topi compute for conv2d_winograd NNPack"""
+    def _compute_conv2d_nnpack(attrs, inputs, out_type):
+        padding = attrs.get_int_tuple("padding")
+        strides = attrs.get_int_tuple("strides")
+        dilation = attrs.get_int_tuple("dilation")
+        out_dtype = attrs.get_str("out_dtype")
+        out_dtype = inputs[0].dtype if out_dtype in ("same", "") else out_dtype
+        return [topi_compute(inputs[0], inputs[1], None, strides, padding,
+                             dilation, out_dtype)]
+    return _compute_conv2d_nnpack
+
+@conv2d_winograd_without_weight_transfrom_strategy.register("arm_cpu")
+def conv2d_winograd_without_weight_transfrom_strategy_arm_cpu(attrs, inputs, out_type, target):
+    """conv2d_winograd_without_weight_transfrom arm cpu strategy"""
+    dilation = attrs.get_int_tuple("dilation")
+    groups = attrs.get_int("groups")
+    layout = attrs.data_layout
+    strides = attrs.get_int_tuple("strides")
+    kernel = inputs[1]
+    assert dilation == (1, 1), "Do not support dilate now"
+    assert strides == (1, 1), "Do not support strides now"
+    assert groups == 1, "Do not supoort arbitrary group number"
+    strategy = _op.OpStrategy()
+    if layout == "NCHW":
+        if len(kernel.shape) == 5:
+            pad_kh, pad_kw, _, _, _ = get_const_tuple(inputs[1].shape)
+            tile_size = attrs.get_int("tile_size")
+            kh = pad_kh - tile_size + 1
+            kw = pad_kw - tile_size + 1
+            assert kh == 3 and kw == 3 and stride_h == 1 and stride_w == 1
+            strategy.add_implementation(
+                wrap_compute_conv2d(topi.arm_cpu.conv2d_nchw_winograd),
+                wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nchw_winograd),
+                name="conv2d_nchw_winograd.arm_cpu")
+        elif len(kernel.shape) == 4:
+            # kernel must be packed by winograd nnpack
+            assert "nnpack" in target.libs
+            strategy.add_implementation(
+                wrap_compute_conv2d_winograd_nnpack(
+                    topi.arm_cpu.conv2d_nchw_winograd_nnpack_without_weight_transform),
+                wrap_topi_schedule(
+                    topi.arm_cpu.schedule_conv2d_nchw_winograd_nnpack_without_weight_transform),
+                name="conv2d_nchw_winograd_nnpack_withou_weight_transform.arm_cpu",
+                plevel=5)
+        else:
+            raise RuntimeError("Unsupported kernel shape: {}".format(kernel.shape))
+    else:
+        raise RuntimeError("Unsupported conv2d_winograd_without_weight_transfrom layout {}".
+                           format(layout))
+    return strategy
+
+@conv2d_transpose_strategy.register("arm_cpu")
+def conv2d_transpose_strategy_arm_cpu(attrs, inputs, out_type, target):
+    """conv2d_transpose arm cpu strategy"""
+    layout = attrs.data_layout
+    dilation = get_const_tuple(attrs.dilation)
+    groups = attrs.groups
+    assert layout == "NCHW", "only support nchw for now"
+    assert dilation == (1, 1), "not support dilate now"
+    assert groups == 1, "only support groups == 1 for now"
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_conv2d_transpose(topi.arm_cpu.conv2d_transpose_nchw),
+        wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_transpose_nchw),
+        name="conv2d_tranpose_nchw.arm_cpu")
+    return strategy
+
+@bitserial_conv2d_strategy.register("arm_cpu")
+def bitserial_conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
+    """bitserial_conv2d x86 strategy"""
+    strategy = _op.OpStrategy()
+    layout = attrs.data_layout
+    if layout == "NCHW":
+        strategy.add_implementation(
+            wrap_compute_bitserial_conv2d(topi.x86.bitserial_conv2d_nchw),
+            wrap_topi_schedule(topi.x86.schedule_bitserial_conv2d_nchw),
+            name="bitserial_conv2d_nchw.arm_cpu")
+    elif layout == "NHWC":
+        strategy.add_implementation(
+            wrap_compute_bitserial_conv2d(topi.arm_cpu.bitserial_conv2d_nhwc),
+            wrap_topi_schedule(topi.arm_cpu.schedule_bitserial_conv2d_nhwc),
+            name="bitserial_conv2d_nhwc.arm_cpu")
+    else:
+        raise ValueError("Data layout {} not supported.".format(layout))
+    return strategy
+
+@bitserial_dense_strategy.register("arm_cpu")
+def schedule_bitserial_dense_arm_cpu(attrs, inputs, out_type, target):
+    """bitserial_dense arm cpu strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_bitserial_dense(topi.arm_cpu.bitserial_dense),
+        wrap_topi_schedule(topi.arm_cpu.schedule_bitserial_dense),
+        name="bitserial_dense.arm_cpu")
+    return strategy
diff --git a/python/tvm/relay/op/strategy/bifrost.py b/python/tvm/relay/op/strategy/bifrost.py
new file mode 100644 (file)
index 0000000..e8f6298
--- /dev/null
@@ -0,0 +1,104 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Definition of bifrost operator strategy."""
+# pylint: disable=invalid-name,unused-argument,wildcard-import,unused-wildcard-import
+import re
+import topi
+from .generic import *
+from .. import op as _op
+
+
+@conv2d_strategy.register("bifrost")
+def conv2d_strategy_bifrost(attrs, inputs, out_type, target):
+    """conv2d mali(bifrost) strategy"""
+    strategy = _op.OpStrategy()
+    data, kernel = inputs
+    dilation_h, dilation_w = attrs.get_int_tuple("dilation")
+    stride_h, stride_w = attrs.get_int_tuple("strides")
+    groups = attrs.groups
+    layout = attrs.data_layout
+    kernel_layout = attrs.kernel_layout
+    if dilation_h < 1 or dilation_w < 1:
+        raise ValueError("dilation should be positive value")
+
+    if groups == 1:
+        if layout == "NCHW":
+            if kernel_layout == "OIHW":
+                strategy.add_implementation(
+                    wrap_compute_conv2d(topi.bifrost.conv2d_nchw_spatial_pack),
+                    wrap_topi_schedule(topi.bifrost.schedule_conv2d_nchw_spatial_pack),
+                    name="conv2d_nchw_spatial_pack.bifrost")
+
+                _, _, kh, kw = get_const_tuple(kernel.shape)
+                if kh == 3 and kw == 3 and stride_h == 1 and stride_w == 1 and \
+                   dilation_h == 1 and dilation_w == 1:
+                    strategy.add_implementation(
+                        wrap_compute_conv2d(topi.bifrost.conv2d_nchw_winograd),
+                        wrap_topi_schedule(topi.bifrost.schedule_conv2d_nchw_winograd),
+                        name="conv2d_nchw_winograd.bifrost",
+                        plevel=15)
+            elif re.match(r"OIHW\d*o", kernel_layout):
+                strategy.add_implementation(
+                    wrap_compute_conv2d(topi.bifrost.conv2d_nchw_spatial_pack),
+                    wrap_topi_schedule(topi.bifrost.schedule_conv2d_nchw_spatial_pack),
+                    name="conv2d_nchw_spatial_pack.bifrost")
+        else:
+            raise RuntimeError("Unsupported conv2d layout {} for Mali(Bifrost)".
+                               format(layout))
+    elif is_depthwise_conv2d(data.shape, layout, kernel.shape, kernel_layout, groups):
+        if layout == "NCHW":
+            assert kernel_layout == "OIHW"
+            strategy.add_implementation(
+                wrap_compute_conv2d(topi.nn.depthwise_conv2d_nchw),
+                wrap_topi_schedule(topi.bifrost.schedule_depthwise_conv2d_nchw),
+                name="depthwise_conv2d_nchw.bifrost")
+        else:
+            raise RuntimeError("Unsupported depthwise_conv2d layout {} for Mali(Bifrost)".
+                               format(layout))
+    else: # group_conv2d
+        raise RuntimeError("group_conv2d is not supported for Mali(Bifrost)")
+    return strategy
+
+@conv2d_winograd_without_weight_transfrom_strategy.register("bifrost")
+def conv2d_winograd_without_weight_transfrom_strategy_bifrost(attrs, inputs, out_type, target):
+    """conv2d_winograd_without_weight_transfrom mali(bifrost) strategy"""
+    dilation = attrs.get_int_tuple("dilation")
+    groups = attrs.get_int("groups")
+    layout = attrs.data_layout
+    strides = attrs.get_int_tuple("strides")
+    assert dilation == (1, 1), "Do not support dilate now"
+    assert strides == (1, 1), "Do not support strides now"
+    assert groups == 1, "Do not supoort arbitrary group number"
+    strategy = _op.OpStrategy()
+    if layout == "NCHW":
+        strategy.add_implementation(
+            wrap_compute_conv2d(topi.bifrost.conv2d_nchw_winograd),
+            wrap_topi_schedule(topi.bifrost.schedule_conv2d_nchw_winograd),
+            name="conv2d_nchw_winograd.bifrost")
+    else:
+        raise RuntimeError("Unsupported conv2d_winograd_without_weight_transfrom layout {}".
+                           format(layout))
+    return strategy
+
+@dense_strategy.register("bifrost")
+def dense_strategy_bifrost(attrs, inputs, out_type, target):
+    """dense mali(bifrost) strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(wrap_compute_dense(topi.bifrost.dense),
+                                wrap_topi_schedule(topi.bifrost.schedule_dense),
+                                name="dense.bifrost")
+    return strategy
diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py
new file mode 100644 (file)
index 0000000..b2f559f
--- /dev/null
@@ -0,0 +1,398 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Definition of CUDA/GPU operator strategy."""
+# pylint: disable=invalid-name,unused-argument,wildcard-import,unused-wildcard-import
+import topi
+from tvm.te import SpecializedCondition
+from .generic import *
+from .. import op as _op
+
+@schedule_injective.register(["cuda", "gpu"])
+def schedule_injective_cuda(attrs, outs, target):
+    """schedule injective ops for cuda"""
+    with target:
+        return topi.cuda.schedule_injective(outs)
+
+@schedule_reduce.register(["cuda", "gpu"])
+def schedule_reduce_cuda(attrs, outs, target):
+    """schedule reduction ops for cuda"""
+    with target:
+        return topi.cuda.schedule_reduce(outs)
+
+@schedule_concatenate.register(["cuda", "gpu"])
+def schedule_concatenate_cuda(attrs, outs, target):
+    """schedule concatenate for cuda"""
+    with target:
+        return topi.cuda.schedule_injective(outs)
+
+@schedule_pool.register(["cuda", "gpu"])
+def schedule_pool_cuda(attrs, outs, target):
+    """schedule pooling ops for cuda"""
+    with target:
+        return topi.cuda.schedule_pool(outs, attrs.layout)
+
+@schedule_pool_grad.register(["cuda", "gpu"])
+def schedule_pool_grad_cuda(attrs, outs, target):
+    """schedule pooling gradient ops for cuda"""
+    with target:
+        return topi.cuda.schedule_pool_grad(outs)
+
+@schedule_adaptive_pool.register(["cuda", "gpu"])
+def schedule_adaptive_pool_cuda(attrs, outs, target):
+    """schedule adaptive pooling ops for cuda"""
+    with target:
+        return topi.cuda.schedule_adaptive_pool(outs)
+
+@schedule_softmax.register(["cuda", "gpu"])
+def schedule_softmax_cuda(attrs, outs, target):
+    """schedule softmax for cuda"""
+    with target:
+        return topi.cuda.schedule_softmax(outs)
+
+@schedule_lrn.register(["cuda", "gpu"])
+def schedule_lrn_cuda(attrs, outs, target):
+    """schedule LRN for cuda"""
+    with target:
+        return topi.cuda.schedule_lrn(outs)
+
+@conv2d_strategy.register(["cuda", "gpu"])
+def conv2d_strategy_cuda(attrs, inputs, out_type, target):
+    """conv2d cuda strategy"""
+    strategy = _op.OpStrategy()
+    data, kernel = inputs
+    stride_h, stride_w = attrs.get_int_tuple("strides")
+    dilation_h, dilation_w = attrs.get_int_tuple("dilation")
+    padding = attrs.get_int_tuple("padding")
+    groups = attrs.groups
+    layout = attrs.data_layout
+    kernel_layout = attrs.kernel_layout
+    if dilation_h < 1 or dilation_w < 1:
+        raise ValueError("dilation should be positive value")
+
+    if groups == 1:
+        if layout == "NCHW":
+            # TODO(@vinx13, @icemelon9): Use conv2d_NCHWc_int8 when dtype is int8/uint8.
+            assert kernel_layout == "OIHW"
+            strategy.add_implementation(
+                wrap_compute_conv2d(topi.cuda.conv2d_nchw),
+                wrap_topi_schedule(topi.cuda.schedule_conv2d_nchw),
+                name="conv2d_nchw.cuda")
+            _, _, kh, kw = get_const_tuple(kernel.shape)
+            if 2 < kh < 8 and 2 < kw < 8 and kh == kw and stride_h == 1 and stride_w == 1 and \
+                dilation_h == 1 and dilation_w == 1:
+                strategy.add_implementation(
+                    wrap_compute_conv2d(topi.cuda.conv2d_nchw_winograd),
+                    wrap_topi_schedule(topi.cuda.schedule_conv2d_nchw_winograd),
+                    name="conv2d_nchw_winograd.cuda",
+                    plevel=15)
+        elif layout == "HWCN":
+            assert kernel_layout == "HWIO"
+            strategy.add_implementation(
+                wrap_compute_conv2d(topi.cuda.conv2d_hwcn),
+                wrap_topi_schedule(topi.cuda.schedule_conv2d_hwcn),
+                name="conv2d_hwcn.cuda")
+        # TODO(@alexgl-github): Re-enable this after fix the conv2d_nhwc for cuda
+        # elif layout == "NHWC":
+        #     assert kernel_layout == "HWIO"
+        #     strategy.add_implementation(
+        #         wrap_compute_conv2d(topi.cuda.conv2d_nhwc),
+        #         wrap_topi_schedule(topi.cuda.schedule_conv2d_nhwc),
+        #         name="conv2d_nhwc.cuda")
+        elif layout == "NCHW4c" and data.dtype in ["int8", "uint8"]:
+            assert kernel_layout == "OIHW4o4i"
+            strategy.add_implementation(
+                wrap_compute_conv2d(topi.cuda.conv2d_NCHWc_int8, True),
+                wrap_topi_schedule(topi.cuda.schedule_conv2d_NCHWc_int8),
+                name="conv2d_NCHWc_int8.cuda")
+        else:
+            raise RuntimeError("Unsupported conv2d layout {} for CUDA".format(layout))
+        # add cudnn implementation
+        if target.target_name == "cuda" and "cudnn" in target.libs:
+            if layout in ["NCHW", "NHWC"] and padding[0] == padding[2] and \
+                    padding[1] == padding[3]:
+                strategy.add_implementation(
+                    wrap_compute_conv2d(topi.cuda.conv2d_cudnn, True),
+                    wrap_topi_schedule(topi.cuda.schedule_conv2d_cudnn),
+                    name="conv2d_cudnn.cuda",
+                    plevel=5)
+    elif is_depthwise_conv2d(data.shape, layout, kernel.shape, kernel_layout, groups):
+        if layout == "NCHW":
+            assert kernel_layout == "OIHW"
+            strategy.add_implementation(
+                wrap_compute_conv2d(topi.cuda.depthwise_conv2d_nchw),
+                wrap_topi_schedule(topi.cuda.schedule_depthwise_conv2d_nchw),
+                name="dpethwise_nchw.cuda")
+        elif layout == "NHWC":
+            assert kernel_layout == "HWOI"
+            strategy.add_implementation(
+                wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc),
+                wrap_topi_schedule(topi.cuda.schedule_depthwise_conv2d_nhwc),
+                name="depthwise_conv2d_nhwc.cuda")
+        else:
+            raise RuntimeError("Unsupported depthwise_conv2d layout {}".format(layout))
+    else: # group_conv2d
+        if layout == 'NCHW':
+            # TODO(@vinx13, @icemelon9): Use group_conv2d_NCHWc_int8 when dtype is int8/uint8.
+            assert kernel_layout == "OIHW"
+            strategy.add_implementation(
+                wrap_compute_conv2d(topi.cuda.group_conv2d_nchw, has_groups=True),
+                wrap_topi_schedule(topi.cuda.schedule_group_conv2d_nchw),
+                name="group_conv2d_nchw.cuda")
+        elif layout == 'NCHW4c' and data.dtype in ["int8", "uint8"]:
+            assert kernel_layout == "OIHW4o4i"
+            strategy.add_implementation(
+                wrap_compute_conv2d(topi.cuda.group_conv2d_NCHWc_int8, True),
+                wrap_topi_schedule(topi.cuda.schedule_group_conv2d_NCHWc_int8),
+                name="group_conv2d_NCHWc_int8.cuda")
+        else:
+            raise RuntimeError("Unsupported group_conv2d layout {}".format(layout))
+    return strategy
+
+@conv2d_winograd_without_weight_transfrom_strategy.register(["cuda", "gpu"])
+def conv2d_winograd_without_weight_transfrom_strategy_cuda(attrs, inputs, out_type, target):
+    """conv2d_winograd_without_weight_transfrom cuda strategy"""
+    dilation = attrs.get_int_tuple("dilation")
+    groups = attrs.get_int("groups")
+    layout = attrs.data_layout
+    assert dilation == (1, 1), "Do not support dilate now"
+    assert groups == 1, "Do not supoort arbitrary group number"
+    strategy = _op.OpStrategy()
+    if layout == "NCHW":
+        strategy.add_implementation(
+            wrap_compute_conv2d(topi.cuda.conv2d_nchw_winograd_without_weight_transform),
+            wrap_topi_schedule(
+                topi.cuda.schedule_conv2d_nchw_winograd_without_weight_transform),
+            name="conv2d_nchw_winograd_without_weight_transform.cuda")
+    else:
+        raise RuntimeError("Unsupported conv2d_winograd_without_weight_transfrom layout {}".
+                           format(layout))
+    return strategy
+
+@deformable_conv2d_strategy.register(["cuda", "gpu"])
+def deformable_conv2d_strategy_cuda(attrs, inputs, out_type, target):
+    """deformable_conv2d cuda strategy"""
+    layout = attrs.data_layout
+    assert layout == "NCHW"
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_deformable_conv2d(topi.cuda.deformable_conv2d_nchw),
+        wrap_topi_schedule(topi.cuda.schedule_deformable_conv2d_nchw),
+        name="deformable_conv2d_nchw.cuda")
+    return strategy
+
+@conv2d_transpose_strategy.register(["cuda", "gpu"])
+def conv2d_transpose_strategy_cuda(attrs, inputs, out_type, target):
+    """conv2d_transpose cuda strategy"""
+    layout = attrs.data_layout
+    dilation = get_const_tuple(attrs.dilation)
+    groups = attrs.groups
+    assert layout == "NCHW", "only support nchw for now"
+    assert dilation == (1, 1), "not support dilate now"
+    assert groups == 1, "only support groups == 1 for now"
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_conv2d_transpose(topi.cuda.conv2d_transpose_nchw),
+        wrap_topi_schedule(topi.cuda.schedule_conv2d_transpose_nchw),
+        name="conv2d_transpose_nchw.cuda")
+    return strategy
+
+@conv3d_strategy.register(["cuda", "gpu"])
+def conv3d_strategy_cuda(attrs, inputs, out_type, target):
+    """conv3d cuda strategy"""
+    strategy = _op.OpStrategy()
+    layout = attrs.data_layout
+    assert layout in ["NCDHW", "NDHWC"], "Not support this layout {} yet".format(layout)
+    if layout == "NCDHW":
+        strategy.add_implementation(wrap_compute_conv3d(topi.cuda.conv3d_ncdhw),
+                                    wrap_topi_schedule(topi.cuda.schedule_conv3d_ncdhw),
+                                    name="conv3d_ncdhw.cuda",
+                                    plevel=10)
+    else: # layout == "NDHWC":
+        strategy.add_implementation(wrap_compute_conv3d(topi.cuda.conv3d_ndhwc),
+                                    wrap_topi_schedule(topi.cuda.schedule_conv3d_ndhwc),
+                                    name="conv3d_ndhwc.cuda",
+                                    plevel=10)
+    if target.target_name == "cuda" and "cudnn" in target.libs:
+        strategy.add_implementation(wrap_compute_conv3d(topi.cuda.conv3d_cudnn, True),
+                                    wrap_topi_schedule(topi.cuda.schedule_conv3d_cudnn),
+                                    name="conv3d_cudnn.cuda",
+                                    plevel=15)
+    return strategy
+
+@conv1d_strategy.register(["cuda", "gpu"])
+def conv1d_strategy_cuda(attrs, inputs, out_type, target):
+    """conv1d cuda strategy"""
+    layout = attrs.data_layout
+    dilation = get_const_tuple(attrs.dilation)
+    if dilation[0] < 1:
+        raise ValueError("dilation should be a positive value")
+    strategy = _op.OpStrategy()
+    if layout == "NCW":
+        strategy.add_implementation(wrap_compute_conv1d(topi.cuda.conv1d_ncw),
+                                    wrap_topi_schedule(topi.cuda.schedule_conv1d_ncw),
+                                    name="conv1d_ncw.cuda")
+    elif layout == "NWC":
+        strategy.add_implementation(wrap_compute_conv1d(topi.cuda.conv1d_nwc),
+                                    wrap_topi_schedule(topi.cuda.schedule_conv1d_nwc),
+                                    name="conv1d_nwc.cuda")
+    else:
+        raise ValueError("Unsupported conv1d layout {}".format(layout))
+    return strategy
+
+@conv1d_transpose_strategy.register(["cuda", "gpu"])
+def conv1d_transpose_strategy_cuda(attrs, inputs, out_type, target):
+    """conv1d_transpose cuda strategy"""
+    strategy = _op.OpStrategy()
+    layout = attrs.data_layout
+    dilation = get_const_tuple(attrs.dilation)
+    groups = attrs.groups
+    assert layout == "NCW", "conv1d_transpose ncw only supported"
+    assert dilation == (1,), "conv1d_transpose dilation is not supported"
+    assert groups == 1, "conv1d_transpose groups == 1 only supported"
+    strategy.add_implementation(wrap_compute_conv1d_transpose(topi.cuda.conv1d_transpose_ncw),
+                                wrap_topi_schedule(topi.cuda.schedule_conv1d_transpose_ncw),
+                                name="conv1d_transpose_ncw.cuda")
+    return strategy
+
+@dense_strategy.register(["cuda", "gpu"])
+def dense_strategy_cuda(attrs, inputs, out_type, target):
+    """dense cuda strategy"""
+    strategy = _op.OpStrategy()
+    if out_type.dtype == "int8":
+        strategy.add_implementation(
+            wrap_compute_dense(topi.cuda.dense_int8),
+            wrap_topi_schedule(topi.cuda.schedule_dense_int8),
+            name="dense_int8.cuda")
+    else:
+        strategy.add_implementation(
+            wrap_compute_dense(topi.cuda.dense_small_batch),
+            wrap_topi_schedule(topi.cuda.schedule_dense_small_batch),
+            name="dense_small_batch.cuda")
+        b = inputs[0].shape[0]
+        with SpecializedCondition(b >= 32):
+            strategy.add_implementation(
+                wrap_compute_dense(topi.cuda.dense_large_batch),
+                wrap_topi_schedule(topi.cuda.schedule_dense_large_batch),
+                name="dense_large_batch.cuda",
+                plevel=15)
+    if target.target_name == "cuda" and "cublas" in target.libs:
+        strategy.add_implementation(
+            wrap_compute_dense(topi.cuda.dense_cublas),
+            wrap_topi_schedule(topi.cuda.schedule_dense_cublas),
+            name="dense_cublas.cuda",
+            plevel=20)
+    return strategy
+
+@batch_matmul_strategy.register(["cuda", "gpu"])
+def batch_matmul_strategy_cuda(attrs, inputs, out_type, target):
+    """batch_matmul cuda strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_batch_matmul(topi.nn.batch_matmul),
+        wrap_topi_schedule(topi.cuda.schedule_batch_matmul),
+        name="batch_matmul.cuda",
+        plevel=10)
+    if target.target_name == "cuda" and "cublas" in target.libs:
+        strategy.add_implementation(
+            wrap_compute_batch_matmul(topi.cuda.batch_matmul_cublas),
+            wrap_topi_schedule(topi.generic.schedule_extern),
+            name="batch_matmul_cublas.cuda",
+            plevel=15)
+    return strategy
+
+@argsort_strategy.register(["cuda", "gpu"])
+def argsort_strategy_cuda(attrs, inputs, out_type, target):
+    """argsort cuda strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_argsort(topi.cuda.argsort),
+        wrap_topi_schedule(topi.cuda.schedule_argsort),
+        name="argsort.cuda")
+    return strategy
+
+@topk_strategy.register(["cuda", "gpu"])
+def topk_strategy_cuda(attrs, inputs, out_type, target):
+    """topk cuda strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(wrap_compute_topk(topi.cuda.topk),
+                                wrap_topi_schedule(topi.cuda.schedule_topk),
+                                name="topk.cuda")
+    return strategy
+
+@multibox_prior_strategy.register(["cuda", "gpu"])
+def multibox_prior_strategy_cuda(attrs, inputs, out_type, target):
+    """multibox_prior cuda strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_multibox_prior(topi.cuda.multibox_prior),
+        wrap_topi_schedule(topi.cuda.schedule_multibox_prior),
+        name="multibox_prior.cuda")
+    return strategy
+
+@multibox_transform_loc_strategy.register(["cuda", "gpu"])
+def multibox_transform_loc_strategy_cuda(attrs, inputs, out_type, target):
+    """multibox_transform_loc cuda strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_multibox_transform_loc(topi.cuda.multibox_transform_loc),
+        wrap_topi_schedule(topi.cuda.schedule_multibox_transform_loc),
+        name="multibox_transform_loc.cuda")
+    return strategy
+
+@get_valid_counts_strategy.register(["cuda", "gpu"])
+def get_valid_counts_strategy_cuda(attrs, inputs, out_type, target):
+    """get_valid_counts cuda strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_get_valid_counts(topi.cuda.get_valid_counts),
+        wrap_topi_schedule(topi.cuda.schedule_get_valid_counts),
+        name="get_valid_counts.cuda")
+    return strategy
+
+@nms_strategy.register(["cuda", "gpu"])
+def nms_strategy_cuda(attrs, inputs, out_type, target):
+    """nms cuda strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_nms(topi.cuda.non_max_suppression),
+        wrap_topi_schedule(topi.cuda.schedule_nms),
+        name="nms.cuda")
+    return strategy
+
+@roi_align_strategy.register(["cuda", "gpu"])
+def roi_align_strategy_cuda(attrs, inputs, out_type, target):
+    """roi_align cuda strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(wrap_compute_roi_align(topi.vision.rcnn.roi_align_nchw),
+                                wrap_topi_schedule(topi.cuda.schedule_roi_align),
+                                name="roi_align_nchw.cuda")
+    return strategy
+
+@schedule_roi_pool.register(["cuda", "gpu"])
+def schedule_roi_pool_cuda(attrs, outs, target):
+    """schedule roi_pool for cuda"""
+    with target:
+        return topi.cuda.schedule_roi_pool(outs)
+
+@proposal_strategy.register(["cuda", "gpu"])
+def proposal_strategy_cuda(attrs, inputs, out_type, target):
+    """proposal cuda strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(wrap_compute_proposal(topi.cuda.proposal),
+                                wrap_topi_schedule(topi.cuda.schedule_proposal),
+                                name="proposal.cuda")
+    return strategy
diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py
new file mode 100644 (file)
index 0000000..312ce95
--- /dev/null
@@ -0,0 +1,749 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Definition of generic operator strategy."""
+# pylint: disable=invalid-name,unused-argument
+import logging
+
+import re
+import topi
+from topi.util import get_const_int, get_const_float, get_const_tuple, get_float_tuple
+from .. import op as _op
+from ....target import generic_func, override_native_generic_func
+
+logger = logging.getLogger('strategy')
+
+def wrap_topi_schedule(topi_schedule):
+    """Wrap TOPI schedule which doesn't use attrs"""
+    def wrapper(attrs, outs, target):
+        with target:
+            return topi_schedule(outs)
+    return wrapper
+
+def get_conv2d_in_channels(data_shape, data_layout):
+    """Get conv2d input channels"""
+    data_shape = get_const_tuple(data_shape)
+    if len(data_shape) == 4:
+        idx = data_layout.find("C")
+        assert idx >= 0, "Invalid conv2d data layout {}".format(data_layout)
+        return data_shape[idx]
+    if re.match(r"NCHW\d*c", data_layout):
+        # NCHW[8]c
+        return data_shape[1] * data_shape[4]
+    raise ValueError("Unknown conv2d data layout {}".format(data_layout))
+
+def get_conv2d_out_channels(kernel_shape, kernel_layout):
+    """Get conv2d output channels"""
+    kernel_shape = get_const_tuple(kernel_shape)
+    if len(kernel_shape) == 4:
+        idx = kernel_layout.find("O")
+        assert idx >= 0, "Invalid conv2d kernel layout {}".format(kernel_layout)
+        return kernel_shape[idx]
+    if re.match(r"OIHW\d*i\d*o", kernel_layout):
+        return kernel_shape[0] * kernel_shape[5]
+    if re.match(r"OIHW\d*o", kernel_layout):
+        return kernel_shape[0] * kernel_shape[4]
+    raise ValueError("Unknown conv2d kernel layout {}".format(kernel_layout))
+
+def is_depthwise_conv2d(data_shape, data_layout, kernel_shape, kernel_layout, groups):
+    ic = get_conv2d_in_channels(data_shape, data_layout)
+    oc = get_conv2d_out_channels(kernel_shape, kernel_layout)
+    return ic == oc == groups
+
+@generic_func
+def schedule_injective(attrs, outs, target):
+    """Schedule injective ops"""
+    with target:
+        return topi.generic.schedule_injective(outs)
+
+@generic_func
+def schedule_reduce(attrs, outs, target):
+    """Schedule reduction ops"""
+    with target:
+        return topi.generic.schedule_reduce(outs)
+
+_op._schedule_injective = schedule_injective
+_op._schedule_reduce = schedule_reduce
+
+# concatenate
+@generic_func
+def schedule_concatenate(attrs, outs, target):
+    """Schedule concatenate op"""
+    with target:
+        return topi.generic.schedule_injective(outs)
+
+# pool
+@generic_func
+def schedule_pool(attrs, outs, target):
+    """Schedule pooling ops"""
+    with target:
+        return topi.generic.schedule_pool(outs, attrs.layout)
+
+# pool_grad
+@generic_func
+def schedule_pool_grad(attrs, outs, target):
+    """Schedule pooling gradient ops"""
+    with target:
+        return topi.generic.schedule_pool_grad(outs)
+
+# adaptive pool
+@generic_func
+def schedule_adaptive_pool(attrs, outs, target):
+    """Schedule adaptive pooling ops"""
+    with target:
+        return topi.generic.schedule_adaptive_pool(outs)
+
+# softmax
+@generic_func
+def schedule_softmax(attrs, outs, target):
+    """Schedule softmax"""
+    with target:
+        return topi.generic.schedule_softmax(outs)
+
+# lrn
+@generic_func
+def schedule_lrn(attrs, outs, target):
+    """Schedule LRN op"""
+    with target:
+        return topi.generic.schedule_lrn(outs)
+
+# bitpack
+@generic_func
+def schedule_bitpack(attrs, outs, target):
+    """Schedule bitpack"""
+    with target:
+        return topi.generic.schedule_bitpack(outs)
+
+# conv2d
+def wrap_compute_conv2d(topi_compute, need_data_layout=False, need_out_layout=False,
+                        has_groups=False):
+    """Wrap conv2d topi compute"""
+    def _compute_conv2d(attrs, inputs, out_type):
+        padding = get_const_tuple(attrs.padding)
+        strides = get_const_tuple(attrs.strides)
+        dilation = get_const_tuple(attrs.dilation)
+        data_layout = attrs.get_str("data_layout")
+        out_layout = attrs.get_str("out_layout")
+        out_dtype = attrs.out_dtype
+        out_dtype = (inputs[0].dtype if out_dtype in ("same", "")
+                     else out_dtype)
+        args = [inputs[0], inputs[1], strides, padding, dilation]
+        if has_groups:
+            args.append(attrs.groups)
+        if need_data_layout:
+            args.append(data_layout)
+        if need_out_layout:
+            args.append(out_layout)
+        args.append(out_dtype)
+        return [topi_compute(*args)]
+    return _compute_conv2d
+
+@override_native_generic_func("conv2d_strategy")
+def conv2d_strategy(attrs, inputs, out_type, target):
+    """conv2d generic strategy"""
+    logger.warning("conv2d is not optimized for this platform.")
+    strategy = _op.OpStrategy()
+    data, kernel = inputs
+    dilation = get_const_tuple(attrs.dilation)
+    groups = attrs.groups
+    layout = attrs.data_layout
+    kernel_layout = attrs.kernel_layout
+    (dilation_h, dilation_w) = dilation
+    if dilation_h < 1 or dilation_w < 1:
+        raise ValueError("dilation should be positive value")
+
+    if groups == 1:
+        if layout == "NCHW":
+            assert kernel_layout == "OIHW"
+            strategy.add_implementation(
+                wrap_compute_conv2d(topi.nn.conv2d_nchw),
+                wrap_topi_schedule(topi.generic.schedule_conv2d_nchw),
+                name="conv2d_nchw.generic")
+        elif layout == "NHWC":
+            assert kernel_layout == "HWIO"
+            strategy.add_implementation(
+                wrap_compute_conv2d(topi.nn.conv2d_nhwc),
+                wrap_topi_schedule(topi.generic.schedule_conv2d_nhwc),
+                name="conv2d_nhwc.generic")
+        elif layout == "HWCN":
+            assert kernel_layout == "HWIO"
+            strategy.add_implementation(
+                wrap_compute_conv2d(topi.nn.conv2d_hwcn),
+                wrap_topi_schedule(topi.generic.schedule_conv2d_hwcn),
+                name="conv2d_hwcn.generic")
+        else:
+            raise RuntimeError("Unsupported conv2d layout {}".format(layout))
+    elif is_depthwise_conv2d(data.shape, layout, kernel.shape, kernel_layout, groups):
+        if layout == "NCHW":
+            assert kernel_layout == "OIHW"
+            strategy.add_implementation(
+                wrap_compute_conv2d(topi.nn.depthwise_conv2d_nchw),
+                wrap_topi_schedule(topi.generic.schedule_depthwise_conv2d_nchw),
+                name="depthwise_conv2d_nchw.generic")
+        elif layout == "NHWC":
+            assert kernel_layout == "HWOI"
+            strategy.add_implementation(
+                wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc),
+                wrap_topi_schedule(topi.generic.schedule_depthwise_conv2d_nhwc),
+                name="depthwise_conv2d_nhwc.generic")
+        else:
+            raise RuntimeError("Unsupported depthwise_conv2d layout {}".format(layout))
+    else: # group_conv2d
+        if layout == 'NCHW':
+            assert kernel_layout == "OIHW"
+            strategy.add_implementation(
+                wrap_compute_conv2d(topi.nn.group_conv2d_nchw, has_groups=True),
+                wrap_topi_schedule(topi.generic.schedule_group_conv2d_nchw),
+                name="group_conv2d_nchw.generic")
+        else:
+            raise RuntimeError("Unsupported group_conv2d layout {}".format(layout))
+    return strategy
+
+# conv2d_NCHWc
+@override_native_generic_func("conv2d_NCHWc_strategy")
+def conv2d_NCHWc_strategy(attrs, inputs, out_type, target):
+    """conv2d_NCHWc generic strategy"""
+    logger.warning("conv2d_NCHWc is not optimized for this platform.")
+    strategy = _op.OpStrategy()
+    if inputs[0].dtype == "int8" or inputs[0].dtype == "uint8":
+        strategy.add_implementation(
+            wrap_compute_conv2d(topi.nn.conv2d_NCHWc_int8, True, True),
+            wrap_topi_schedule(topi.generic.schedule_conv2d_NCHWc_int8),
+            name="conv2d_NCHWc_int8.generic")
+    else:
+        strategy.add_implementation(
+            wrap_compute_conv2d(topi.nn.conv2d_NCHWc, True, True),
+            wrap_topi_schedule(topi.generic.schedule_conv2d_NCHWc),
+            name="conv2d_NCHWc.generic")
+    return strategy
+
+# depthwise_conv2d_NCHWc
+@override_native_generic_func("depthwise_conv2d_NCHWc_strategy")
+def depthwise_conv2d_NCHWc_strategy(attrs, inputs, out_type, target):
+    """depthwise_conv2d generic strategy"""
+    logger.warning("depthwise_conv2d_NCHWc is not optimized for this platform.")
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_conv2d(topi.nn.depthwise_conv2d_NCHWc, True, True),
+        wrap_topi_schedule(topi.generic.schedule_depthwise_conv2d_NCHWc),
+        name="depthwise_conv2d_NCHWc.generic")
+    return strategy
+
+# conv2d_winograd_without_weight_transform
+@override_native_generic_func("conv2d_winograd_without_weight_transform_strategy")
+def conv2d_winograd_without_weight_transfrom_strategy(attrs, inputs, out_type, target):
+    """conv2d_winograd_without_weight_transfrom generic strategy"""
+    raise ValueError("No generic implemenation for conv2d_winograd_without_weight_transform")
+
+# conv2d_winograd_weight_transform
+@generic_func
+def schedule_conv2d_winograd_weight_transform(attrs, outs, target):
+    """Schedule conv2d_winograd_weight_transform"""
+    with target:
+        return topi.generic.schedule_conv2d_winograd_weight_transform(outs)
+
+# conv2d_winograd_nnpack_weight_transform
+@generic_func
+def schedule_conv2d_winograd_nnpack_weight_transform(attrs, outs, target):
+    """Schedule conv2d_winograd_nnpack_weight_transform"""
+    with target:
+        return topi.generic.schedule_conv2d_winograd_nnpack_weight_transform(outs)
+
+# deformable_conv2d
+def wrap_compute_deformable_conv2d(topi_compute):
+    """wrap deformable_conv2d topi compute"""
+    def _compute_deformable_conv2d(attrs, inputs, out_dtype):
+        assert attrs.data_layout == "NCHW"
+        padding = get_const_tuple(attrs.padding)
+        strides = get_const_tuple(attrs.strides)
+        dilation = get_const_tuple(attrs.dilation)
+        deformable_groups = attrs.deformable_groups
+        groups = attrs.groups
+        out_dtype = attrs.out_dtype
+        out_dtype = inputs[0].dtype if out_dtype in ("same", "") else out_dtype
+        out = topi_compute(inputs[0], inputs[1], inputs[2], strides, padding,
+                           dilation, deformable_groups, groups, out_dtype)
+        return [out]
+    return _compute_deformable_conv2d
+
+@override_native_generic_func("deformable_conv2d_strategy")
+def deformable_conv2d_strategy(attrs, inputs, out_type, target):
+    """deformable_conv2d generic strategy"""
+    logger.warning("deformable_conv2d is not optimized for this platform.")
+    layout = attrs.data_layout
+    assert layout == "NCHW"
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_deformable_conv2d(topi.nn.deformable_conv2d_nchw),
+        wrap_topi_schedule(topi.generic.schedule_deformable_conv2d_nchw),
+        name="deformable_conv2d.generic")
+    return strategy
+
+# conv2d_transpose
+def wrap_compute_conv2d_transpose(topi_compute):
+    """wrap conv2d_transpose topi compute"""
+    def compute_conv2d_transpose(attrs, inputs, out_dtype):
+        """Compute definition of conv2d_transpose"""
+        padding = get_const_tuple(attrs.padding)
+        strides = get_const_tuple(attrs.strides)
+        out_dtype = attrs.out_dtype
+        out_dtype = (inputs[0].dtype if out_dtype in ("same", "")
+                     else out_dtype)
+        out = topi_compute(
+            inputs[0], inputs[1], strides, padding, out_dtype)
+        output_padding = get_const_tuple(attrs.output_padding)
+        out = topi.nn.pad(out, [0, 0, 0, 0],
+                          [0, 0, output_padding[0], output_padding[1]])
+        return [out]
+    return compute_conv2d_transpose
+
+@override_native_generic_func("conv2d_transpose_strategy")
+def conv2d_transpose_strategy(attrs, inputs, out_type, target):
+    """conv2d_transpose generic strategy"""
+    logger.warning("conv2d_transpose is not optimized for this platform.")
+    layout = attrs.data_layout
+    dilation = get_const_tuple(attrs.dilation)
+    groups = attrs.groups
+    assert layout == "NCHW", "only support nchw for now"
+    assert dilation == (1, 1), "not support dilate now"
+    assert groups == 1, "only support groups == 1 for now"
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_conv2d_transpose(topi.nn.conv2d_transpose_nchw),
+        wrap_topi_schedule(topi.generic.schedule_conv2d_transpose_nchw),
+        name="conv2d_transpose_nchw.generic")
+    return strategy
+
+# conv3d
+def wrap_compute_conv3d(topi_compute, need_layout=False):
+    """wrap conv3d topi compute"""
+    def _compute_conv3d(attrs, inputs, out_type):
+        padding = get_const_tuple(attrs.padding)
+        strides = get_const_tuple(attrs.strides)
+        dilation = get_const_tuple(attrs.dilation)
+        groups = attrs.groups
+        layout = attrs.data_layout
+        out_dtype = attrs.out_dtype
+        out_dtype = (inputs[0].dtype if out_dtype in ("same", "")
+                     else out_dtype)
+
+        (dilation_d, dilation_h, dilation_w) = dilation
+        if dilation_d < 1 or dilation_h < 1 or dilation_w < 1:
+            raise ValueError("Dilation should be positive value")
+        if groups != 1:
+            raise ValueError("Not support arbitrary group number for conv3d")
+        if need_layout:
+            out = topi_compute(inputs[0], inputs[1], strides, padding, dilation,
+                               layout, out_dtype)
+        else:
+            out = topi_compute(inputs[0], inputs[1], strides, padding, dilation,
+                               out_dtype)
+        return [out]
+    return _compute_conv3d
+
+@override_native_generic_func("conv3d_strategy")
+def conv3d_strategy(attrs, inputs, out_type, target):
+    """conv3d generic strategy"""
+    logger.warning("conv3d is not optimized for this platform.")
+    strategy = _op.OpStrategy()
+    layout = attrs.data_layout
+    if layout == "NCDHW":
+        strategy.add_implementation(
+            wrap_compute_conv3d(topi.nn.conv3d_ncdhw),
+            wrap_topi_schedule(topi.generic.schedule_conv3d_ncdhw),
+            name="conv3d_ncdhw.generic")
+    elif layout == "NDHWC":
+        strategy.add_implementation(
+            wrap_compute_conv3d(topi.nn.conv3d_ndhwc),
+            wrap_topi_schedule(topi.generic.schedule_conv3d_ndhwc),
+            name="conv3d_ndhwc.generic")
+    else:
+        raise ValueError("Not support this layout {} yet".format(layout))
+    return strategy
+
+# conv1d
+def wrap_compute_conv1d(topi_compute):
+    """wrap conv1d topi compute"""
+    def _compute_conv1d(attrs, inputs, out_type):
+        """Compute definition of conv1d"""
+        strides = get_const_tuple(attrs.strides)
+        padding = get_const_tuple(attrs.padding)
+        dilation = get_const_tuple(attrs.dilation)
+        out_dtype = attrs.out_dtype
+        out_dtype = (inputs[0].dtype if out_dtype in ("same", "")
+                     else out_dtype)
+        return [topi_compute(inputs[0], inputs[1], strides, padding, dilation,
+                             out_dtype)]
+    return _compute_conv1d
+
+@override_native_generic_func("conv1d_strategy")
+def conv1d_strategy(attrs, inputs, out_type, target):
+    """conv1d generic strategy"""
+    logger.warning("conv1d is not optimized for this platform.")
+    layout = attrs.data_layout
+    dilation = get_const_tuple(attrs.dilation)
+    if dilation[0] < 1:
+        raise ValueError("dilation should be a positive value")
+    strategy = _op.OpStrategy()
+    if layout == "NCW":
+        strategy.add_implementation(
+            wrap_compute_conv1d(topi.nn.conv1d_ncw),
+            wrap_topi_schedule(topi.generic.schedule_conv1d_ncw),
+            name="conv1d_ncw.generic")
+    elif layout == "NWC":
+        strategy.add_implementation(
+            wrap_compute_conv1d(topi.nn.conv1d_nwc),
+            wrap_topi_schedule(topi.generic.schedule_conv1d_nwc),
+            name="conv1d_nwc.generic")
+    else:
+        raise ValueError("Unsupported conv1d layout {}".format(layout))
+    return strategy
+
+# conv1d_transpose
+def wrap_compute_conv1d_transpose(topi_compute):
+    """wrap conv1d_transpose topi compute"""
+    def _compute_conv1d_tranpsoe(attrs, inputs, out_type):
+        padding = get_const_tuple(attrs.padding)
+        strides = get_const_tuple(attrs.strides)
+        out_dtype = attrs.out_dtype
+        out_dtype = (inputs[0].dtype if out_dtype in ("same", "") else out_dtype)
+        out = topi_compute(inputs[0], inputs[1], strides, padding, out_dtype)
+        output_padding = get_const_tuple(attrs.output_padding)
+        out = topi.nn.pad(out, [0, 0, 0], [0, 0, output_padding[0]])
+        return [out]
+    return _compute_conv1d_tranpsoe
+
+@override_native_generic_func("conv1d_transpose_strategy")
+def conv1d_transpose_strategy(attrs, inputs, out_type, target):
+    """conv1d_transpose generic strategy"""
+    logger.warning("conv1d_transpose is not optimized for this platform.")
+    strategy = _op.OpStrategy()
+    layout = attrs.data_layout
+    dilation = get_const_tuple(attrs.dilation)
+    groups = attrs.groups
+    assert layout == "NCW", "conv1d_transpose ncw only supported"
+    assert dilation == (1,), "conv1d_transpose dilation is not supported"
+    assert groups == 1, "conv1d_transpose groups == 1 only supported"
+    strategy.add_implementation(wrap_compute_conv1d_transpose(topi.nn.conv1d_transpose_ncw),
+                                wrap_topi_schedule(topi.generic.schedule_conv1d_transpose_ncw),
+                                name="conv1d_transpose_ncw.generic")
+    return strategy
+
+# dense
+def wrap_compute_dense(topi_compute):
+    """wrap dense topi compute"""
+    def _compute_dense(attrs, inputs, out_type):
+        """Compute definition of dense"""
+        out_dtype = attrs.out_dtype
+        out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype
+        return [topi_compute(inputs[0], inputs[1], None, out_dtype)]
+    return _compute_dense
+
+@override_native_generic_func("dense_strategy")
+def dense_strategy(attrs, inputs, out_type, target):
+    """dense generic strategy"""
+    logger.warning("dense is not optimized for this platform.")
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(wrap_compute_dense(topi.nn.dense),
+                                wrap_topi_schedule(topi.generic.schedule_dense),
+                                name="dense.generic")
+    return strategy
+
+# batch_matmul
+def wrap_compute_batch_matmul(topi_compute):
+    """wrap batch_matmul topi compute"""
+    def _compute_batch_matmul(attrs, inputs, out_type):
+        return [topi_compute(inputs[0], inputs[1])]
+    return _compute_batch_matmul
+
+@override_native_generic_func("batch_matmul_strategy")
+def batch_matmul_strategy(attrs, inputs, out_type, target):
+    """batch_matmul generic strategy"""
+    logger.warning("batch_matmul is not optimized for this platform.")
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(wrap_compute_batch_matmul(topi.nn.batch_matmul),
+                                wrap_topi_schedule(topi.generic.schedule_batch_matmul),
+                                name="batch_matmul.generic")
+    return strategy
+
+# sparse_dense
+@generic_func
+def schedule_sparse_dense(attrs, outs, target):
+    """schedule sparse_dense"""
+    with target:
+        return topi.generic.schedule_sparse_dense(outs)
+
+# sparse_transpose
+@generic_func
+def schedule_sparse_transpose(attrs, outs, target):
+    """schedule sparse_transpose"""
+    with target:
+        return topi.generic.schedule_sparse_transpose(outs)
+
+# argsort
+def wrap_compute_argsort(topi_compute):
+    """Wrap argsort topi compute"""
+    def _compute_argsort(attrs, inputs, _):
+        axis = get_const_int(attrs.axis)
+        is_ascend = bool(get_const_int(attrs.is_ascend))
+        dtype = attrs.dtype
+        return [topi_compute(inputs[0], axis=axis, is_ascend=is_ascend, dtype=dtype)]
+    return _compute_argsort
+
+@override_native_generic_func("argsort_strategy")
+def argsort_strategy(attrs, inputs, out_type, target):
+    """argsort generic strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(wrap_compute_argsort(topi.argsort),
+                                wrap_topi_schedule(topi.generic.schedule_argsort),
+                                name="argsort.generic")
+    return strategy
+
+# topk
+def wrap_compute_topk(topi_compute):
+    """Wrap topk compute"""
+    def _compute_topk(attrs, inputs, out_type):
+        k = get_const_int(attrs.k)
+        axis = get_const_int(attrs.axis)
+        ret_type = attrs.ret_type
+        is_ascend = bool(get_const_int(attrs.is_ascend))
+        dtype = attrs.dtype
+        out = topi_compute(inputs[0], k, axis, ret_type, is_ascend, dtype)
+        out = out if isinstance(out, list) else [out]
+        return out
+    return _compute_topk
+
+@override_native_generic_func("topk_strategy")
+def topk_strategy(attrs, inputs, out_type, target):
+    """topk generic strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(wrap_compute_topk(topi.topk),
+                                wrap_topi_schedule(topi.generic.schedule_topk),
+                                name="topk.generic")
+    return strategy
+
+# multibox_prior
+def wrap_compute_multibox_prior(topi_compute):
+    """Wrap multibox_prior compute"""
+    def _compute_multibox_prior(attrs, inputs, _):
+        """Compute definition of multibox_prior"""
+        sizes = get_float_tuple(attrs.sizes)
+        ratios = get_float_tuple(attrs.ratios)
+        steps = get_float_tuple(attrs.steps)
+        offsets = get_float_tuple(attrs.offsets)
+        clip = bool(get_const_int(attrs.clip))
+        return [topi_compute(inputs[0], sizes, ratios, steps, offsets, clip)]
+    return _compute_multibox_prior
+
+@override_native_generic_func("multibox_prior_strategy")
+def multibox_prior_strategy(attrs, inputs, out_type, target):
+    """multibox_prior generic strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(wrap_compute_multibox_prior(topi.vision.ssd.multibox_prior),
+                                wrap_topi_schedule(topi.generic.schedule_multibox_prior),
+                                name="multibox_prior.generic")
+    return strategy
+
+# multibox_transform_loc
+def wrap_compute_multibox_transform_loc(topi_compute):
+    """Wrap multibox_transform_loc compute"""
+    def _compute_multibox_transform_loc(attrs, inputs, _):
+        """Compute definition of multibox_detection"""
+        clip = bool(get_const_int(attrs.clip))
+        threshold = get_const_float(attrs.threshold)
+        variances = get_float_tuple(attrs.variances)
+        return topi_compute(
+            inputs[0], inputs[1], inputs[2], clip, threshold, variances)
+    return _compute_multibox_transform_loc
+
+@override_native_generic_func("multibox_transform_loc_strategy")
+def multibox_transform_loc_strategy(attrs, inputs, out_type, target):
+    """schedule multibox_transform_loc"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_multibox_transform_loc(topi.vision.ssd.multibox_transform_loc),
+        wrap_topi_schedule(topi.generic.schedule_multibox_transform_loc),
+        name="multibox_transform_loc.generic")
+    return strategy
+
+# get_valid_counts
+def wrap_compute_get_valid_counts(topi_compute):
+    """wrap get_valid_counts topi compute"""
+    def _compute_get_valid_counts(attrs, inputs, out_type):
+        score_threshold = get_const_float(attrs.score_threshold)
+        id_index = get_const_int(attrs.id_index)
+        score_index = get_const_int(attrs.score_index)
+        return topi_compute(inputs[0], score_threshold, id_index, score_index)
+    return _compute_get_valid_counts
+
+@override_native_generic_func("get_valid_counts_strategy")
+def get_valid_counts_strategy(attrs, inputs, out_type, target):
+    """get_valid_counts generic strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(wrap_compute_get_valid_counts(topi.vision.get_valid_counts),
+                                wrap_topi_schedule(topi.generic.schedule_get_valid_counts),
+                                name="get_valid_counts.generic")
+    return strategy
+
+# non-maximum suppression
+def wrap_compute_nms(topi_compute):
+    """wrap nms topi compute"""
+    def _compute_nms(attrs, inputs, out_type):
+        return_indices = bool(get_const_int(attrs.return_indices))
+        max_output_size = get_const_int(attrs.max_output_size)
+        iou_threshold = get_const_float(attrs.iou_threshold)
+        force_suppress = bool(get_const_int(attrs.force_suppress))
+        top_k = get_const_int(attrs.top_k)
+        coord_start = get_const_int(attrs.coord_start)
+        score_index = get_const_int(attrs.score_index)
+        id_index = get_const_int(attrs.id_index)
+        invalid_to_bottom = bool(get_const_int(attrs.invalid_to_bottom))
+        return [topi_compute(inputs[0], inputs[1], max_output_size, iou_threshold,
+                             force_suppress, top_k, coord_start, score_index,
+                             id_index, return_indices, invalid_to_bottom)]
+    return _compute_nms
+
+@override_native_generic_func("non_max_suppression_strategy")
+def nms_strategy(attrs, inputs, out_type, target):
+    """nms generic strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(wrap_compute_nms(topi.vision.non_max_suppression),
+                                wrap_topi_schedule(topi.generic.schedule_nms),
+                                name="nms.generic")
+    return strategy
+
+# roi_align
+def wrap_compute_roi_align(topi_compute):
+    """wrap roi_align topi compute"""
+    def _compute_roi_align(attrs, inputs, out_type):
+        assert attrs.layout == "NCHW"
+        pooled_size = get_const_tuple(attrs.pooled_size)
+        return [topi_compute(inputs[0], inputs[1],
+                             pooled_size=pooled_size,
+                             spatial_scale=attrs.spatial_scale,
+                             sample_ratio=attrs.sample_ratio)]
+    return _compute_roi_align
+
+@override_native_generic_func("roi_align_strategy")
+def roi_align_strategy(attrs, inputs, out_type, target):
+    """roi_align generic strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(wrap_compute_roi_align(topi.vision.rcnn.roi_align_nchw),
+                                wrap_topi_schedule(topi.generic.schedule_roi_align),
+                                name="roi_align.generic")
+    return strategy
+
+# roi_pool
+@generic_func
+def schedule_roi_pool(attrs, outs, target):
+    """schedule roi_pool"""
+    with target:
+        return topi.generic.schedule_roi_pool(outs)
+
+# proposal
+def wrap_compute_proposal(topi_compute):
+    """wrap proposal topi compute"""
+    def _compute_proposal(attrs, inputs, out_type):
+        scales = get_float_tuple(attrs.scales)
+        ratios = get_float_tuple(attrs.ratios)
+        feature_stride = attrs.feature_stride
+        threshold = attrs.threshold
+        rpn_pre_nms_top_n = attrs.rpn_pre_nms_top_n
+        rpn_post_nms_top_n = attrs.rpn_post_nms_top_n
+        rpn_min_size = attrs.rpn_min_size
+        iou_loss = bool(get_const_int(attrs.iou_loss))
+        return [topi_compute(inputs[0], inputs[1], inputs[2], scales, ratios,
+                             feature_stride, threshold, rpn_pre_nms_top_n,
+                             rpn_post_nms_top_n, rpn_min_size, iou_loss)]
+    return _compute_proposal
+
+@override_native_generic_func("proposal_strategy")
+def proposal_strategy(attrs, inputs, out_type, target):
+    """proposal generic strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(wrap_compute_proposal(topi.vision.rcnn.proposal),
+                                wrap_topi_schedule(topi.generic.schedule_proposal),
+                                name="proposal.generic")
+    return strategy
+
+# argwhere
+@generic_func
+def schedule_argwhere(attrs, outs, target):
+    """schedule argwhere"""
+    with target:
+        return topi.generic.schedule_argwhere(outs)
+
+# bitserial_conv2d
+def wrap_compute_bitserial_conv2d(topi_compute):
+    """wrap bitserial_conv2d topi compute"""
+    def compute_bitserial_conv2d(attrs, inputs, out_dtype):
+        """Compute definition for bitserial conv2d."""
+        padding = get_const_tuple(attrs.padding)
+        strides = get_const_tuple(attrs.strides)
+        activation_bits = attrs.activation_bits
+        weight_bits = attrs.weight_bits
+        pack_dtype = attrs.pack_dtype
+        out_dtype = attrs.out_dtype
+        unipolar = attrs.unipolar
+        return [topi_compute(inputs[0], inputs[1], strides, padding, activation_bits,
+                             weight_bits, pack_dtype, out_dtype, unipolar)]
+    return compute_bitserial_conv2d
+
+@override_native_generic_func("bitserial_conv2d_strategy")
+def bitserial_conv2d_strategy(attrs, inputs, out_type, target):
+    """bitserial_conv2d generic strategy"""
+    logger.warning("bitserial_conv2d is not optimized for this platform.")
+    strategy = _op.OpStrategy()
+    layout = attrs.data_layout
+    if layout == "NCHW":
+        strategy.add_implementation(
+            wrap_compute_bitserial_conv2d(topi.nn.bitserial_conv2d_nchw),
+            wrap_topi_schedule(topi.generic.schedule_bitserial_conv2d_nchw),
+            name="bitserial_conv2d_nchw.generic")
+    elif layout == "NHWC":
+        strategy.add_implementation(
+            wrap_compute_bitserial_conv2d(topi.nn.bitserial_conv2d_nhwc),
+            wrap_topi_schedule(topi.generic.schedule_bitserial_conv2d_nhwc),
+            name="bitserial_conv2d_nhwc.generic")
+    else:
+        raise ValueError("Data layout {} not supported.".format(layout))
+    return strategy
+
+# bitserial_dense
+def wrap_compute_bitserial_dense(topi_compute):
+    """wrap bitserial_dense topi compute"""
+    def compute_bitserial_dense(attrs, inputs, out_type):
+        """Compute definition of bitserial dense"""
+        data_bits = attrs.data_bits
+        weight_bits = attrs.weight_bits
+        pack_dtype = attrs.pack_dtype
+        out_dtype = attrs.out_dtype
+        out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype
+        unipolar = attrs.unipolar
+        return [topi_compute(inputs[0], inputs[1], data_bits, weight_bits,
+                             pack_dtype, out_dtype, unipolar)]
+    return compute_bitserial_dense
+
+@override_native_generic_func("bitserial_dense_strategy")
+def bitserial_dense_strategy(attrs, inputs, out_type, target):
+    """bitserial_dense generic strategy"""
+    logger.warning("bitserial_dense is not optimized for this platform.")
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_bitserial_dense(topi.nn.bitserial_dense),
+        wrap_topi_schedule(topi.generic.schedule_bitserial_dense),
+        name="bitserial_dense.generic")
+    return strategy
diff --git a/python/tvm/relay/op/strategy/hls.py b/python/tvm/relay/op/strategy/hls.py
new file mode 100644 (file)
index 0000000..514902b
--- /dev/null
@@ -0,0 +1,158 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Definition of HLS operator strategy."""
+# pylint: disable=invalid-name,unused-argument,wildcard-import,unused-wildcard-import
+import topi
+from .generic import *
+from .. import op as _op
+
+@schedule_injective.register("hls")
+def schedule_injective_hls(attrs, outs, target):
+    """schedule injective ops for hls"""
+    with target:
+        return topi.hls.schedule_injective(outs)
+
+@schedule_reduce.register("hls")
+def schedule_reduce_hls(attrs, outs, target):
+    """schedule reduction ops for hls"""
+    with target:
+        return topi.hls.schedule_reduce(outs)
+
+@schedule_concatenate.register("hls")
+def schedule_concatenate_hls(attrs, outs, target):
+    """schedule concatenate for hls"""
+    with target:
+        return topi.hls.schedule_injective(outs)
+
+@schedule_pool.register("hls")
+def schedule_pool_hls(attrs, outs, target):
+    """schedule pooling ops for hls"""
+    with target:
+        return topi.hls.schedule_pool(outs, attrs.layout)
+
+@schedule_adaptive_pool.register("hls")
+def schedule_adaptive_pool_hls(attrs, outs, target):
+    """schedule adaptive pooling ops for hls"""
+    with target:
+        return topi.hls.schedule_adaptive_pool(outs)
+
+@schedule_softmax.register("hls")
+def schedule_softmax_hls(attrs, outs, target):
+    """schedule softmax for hls"""
+    with target:
+        return topi.hls.schedule_softmax(outs)
+
+@override_native_generic_func("conv2d_strategy")
+def conv2d_strategy_hls(attrs, inputs, out_type, target):
+    """conv2d hls strategy"""
+    strategy = _op.OpStrategy()
+    data, kernel = inputs
+    dilation = get_const_tuple(attrs.dilation)
+    groups = attrs.groups
+    layout = attrs.data_layout
+    kernel_layout = attrs.kernel_layout
+    (dilation_h, dilation_w) = dilation
+    if dilation_h < 1 or dilation_w < 1:
+        raise ValueError("dilation should be positive value")
+
+    if groups == 1:
+        if layout == "NCHW":
+            assert kernel_layout == "OIHW"
+            strategy.add_implementation(
+                wrap_compute_conv2d(topi.nn.conv2d_nchw),
+                wrap_topi_schedule(topi.hls.schedule_conv2d_nchw),
+                name="conv2d_nchw.hls")
+        elif layout == "NHWC":
+            assert kernel_layout == "HWIO"
+            strategy.add_implementation(
+                wrap_compute_conv2d(topi.nn.conv2d_nhwc),
+                wrap_topi_schedule(topi.hls.schedule_conv2d_nhwc),
+                name="conv2d_nhwc.hls")
+        else:
+            raise RuntimeError("Unsupported conv2d layout {}".format(layout))
+    elif is_depthwise_conv2d(data.shape, layout, kernel.shape, kernel_layout, groups):
+        if layout == "NCHW":
+            assert kernel_layout == "OIHW"
+            strategy.add_implementation(
+                wrap_compute_conv2d(topi.nn.depthwise_conv2d_nchw),
+                wrap_topi_schedule(topi.hls.schedule_depthwise_conv2d_nchw),
+                name="depthwise_conv2d_nchw.hls")
+        elif layout == "NHWC":
+            assert kernel_layout == "HWOI"
+            strategy.add_implementation(
+                wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc),
+                wrap_topi_schedule(topi.hls.schedule_depthwise_conv2d_nhwc),
+                name="depthwise_nhwc.hls")
+        else:
+            raise RuntimeError("Unsupported depthwise_conv2d layout {}".format(layout))
+    else: # group_conv2d
+        raise RuntimeError("group_conv2d is not supported for hls")
+    return strategy
+
+@override_native_generic_func("conv2d_NCHWc_strategy")
+def conv2d_NCHWc_strategy_hls(attrs, inputs, out_type, target):
+    """conv2d_NCHWc hls strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_conv2d(topi.nn.conv2d_NCHWc, True, True),
+        wrap_topi_schedule(topi.hls.schedule_conv2d_NCHWc),
+        name="conv2d_NCHWc.hls")
+    return strategy
+
+@conv2d_transpose_strategy.register("hls")
+def conv2d_transpose_strategy_hls(attrs, inputs, out_type, target):
+    """conv2d_transpose hls strategy"""
+    layout = attrs.data_layout
+    dilation = get_const_tuple(attrs.dilation)
+    groups = attrs.groups
+    assert layout == "NCHW", "only support nchw for now"
+    assert dilation == (1, 1), "not support dilate now"
+    assert groups == 1, "only support groups == 1 for now"
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_conv2d_transpose(topi.nn.conv2d_transpose_nchw),
+        wrap_topi_schedule(topi.hls.schedule_conv2d_transpose_nchw),
+        name="conv2d_transpose_nchw.hls")
+    return strategy
+
+@dense_strategy.register("hls")
+def dense_strategy_hls(attrs, inputs, out_type, target):
+    """dense hls strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(wrap_compute_dense(topi.nn.dense),
+                                wrap_topi_schedule(topi.hls.schedule_dense),
+                                name="dense.hls")
+    return strategy
+
+@bitserial_conv2d_strategy.register("hls")
+def bitserial_conv2d_strategy_hls(attrs, inputs, out_type, target):
+    """bitserial_conv2d hls strategy"""
+    strategy = _op.OpStrategy()
+    layout = attrs.data_layout
+    if layout == "NCHW":
+        strategy.add_implementation(
+            wrap_compute_bitserial_conv2d(topi.nn.bitserial_conv2d_nchw),
+            wrap_topi_schedule(topi.hls.schedule_bitserial_conv2d_nchw),
+            name="bitserial_conv2d_nchw.hls")
+    elif layout == "NHWC":
+        strategy.add_implementation(
+            wrap_compute_bitserial_conv2d(topi.nn.bitserial_conv2d_nhwc),
+            wrap_topi_schedule(topi.hls.schedule_bitserial_conv2d_nhwc),
+            name="bitserial_conv2d_nhwc.hls")
+    else:
+        raise ValueError("Data layout {} not supported.".format(layout))
+    return strategy
diff --git a/python/tvm/relay/op/strategy/intel_graphics.py b/python/tvm/relay/op/strategy/intel_graphics.py
new file mode 100644 (file)
index 0000000..0ea8d85
--- /dev/null
@@ -0,0 +1,74 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Definition of x86 operator strategy."""
+# pylint: disable=invalid-name,unused-argument,wildcard-import,unused-wildcard-import
+import topi
+from .generic import *
+from .. import op as _op
+
+
+@conv2d_strategy.register("intel_graphics")
+def conv2d_strategy_intel_graphics(attrs, inputs, out_type, target):
+    """conv2d intel graphics strategy"""
+    strategy = _op.OpStrategy()
+    data, kernel = inputs
+    dilation_h, dilation_w = get_const_tuple(attrs.dilation)
+    groups = attrs.groups
+    layout = attrs.data_layout
+    kernel_layout = attrs.kernel_layout
+    if dilation_h < 1 or dilation_w < 1:
+        raise ValueError("dilation should be positive value")
+
+    if groups == 1:
+        if layout == "NCHW":
+            assert kernel_layout == "OIHW"
+            strategy.add_implementation(
+                wrap_compute_conv2d(topi.intel_graphics.conv2d_nchw),
+                wrap_topi_schedule(topi.intel_graphics.schedule_conv2d_nchw),
+                name="conv2d_nchw.intel_graphics")
+            # conv2d_NCHWc won't work without alter op layout pass
+            # TODO(@Laurawly): fix this
+            strategy.add_implementation(
+                wrap_compute_conv2d(topi.intel_graphics.conv2d_NCHWc, True, True),
+                wrap_topi_schedule(topi.intel_graphics.schedule_conv2d_NCHWc),
+                name="conv2d_NCHWc.intel_graphics",
+                plevel=5)
+        else:
+            raise RuntimeError("Unsupported conv2d layout {} for intel graphics".
+                               format(layout))
+    elif is_depthwise_conv2d(data.shape, layout, kernel.shape, kernel_layout, groups):
+        if layout == "NCHW":
+            assert kernel_layout == "OIHW"
+            strategy.add_implementation(
+                wrap_compute_conv2d(topi.intel_graphics.depthwise_conv2d_nchw),
+                wrap_topi_schedule(topi.intel_graphics.schedule_depthwise_conv2d_nchw),
+                name="depthwise_conv2d_nchw.intel_graphics")
+        else:
+            raise RuntimeError("Unsupported depthwise_conv2d layout {}".format(layout))
+    else: # group_conv2d
+        raise RuntimeError("group_conv2d is not supported for intel graphics")
+    return strategy
+
+@conv2d_NCHWc_strategy.register("intel_graphics")
+def conv2d_NCHWc_strategy_intel_graphics(attrs, inputs, out_type, target):
+    """conv2d_NCHWc intel_graphics strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_conv2d(topi.intel_graphics.conv2d_NCHWc, True, True),
+        wrap_topi_schedule(topi.intel_graphics.schedule_conv2d_NCHWc),
+        name="conv2d_NCHWc.intel_graphics")
+    return strategy
diff --git a/python/tvm/relay/op/strategy/mali.py b/python/tvm/relay/op/strategy/mali.py
new file mode 100644 (file)
index 0000000..8f1fa29
--- /dev/null
@@ -0,0 +1,106 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Definition of mali operator strategy."""
+# pylint: disable=invalid-name,unused-argument,wildcard-import,unused-wildcard-import
+import re
+import topi
+from .generic import *
+from .. import op as _op
+
+@conv2d_strategy.register("mali")
+def conv2d_strategy_mali(attrs, inputs, out_type, target):
+    """conv2d mali strategy"""
+    strategy = _op.OpStrategy()
+    data, kernel = inputs
+    dilation_h, dilation_w = attrs.get_int_tuple("dilation")
+    stride_h, stride_w = attrs.get_int_tuple("strides")
+    groups = attrs.groups
+    layout = attrs.data_layout
+    kernel_layout = attrs.kernel_layout
+    if dilation_h < 1 or dilation_w < 1:
+        raise ValueError("dilation should be positive value")
+
+    if groups == 1:
+        if layout == "NCHW":
+            if kernel_layout == "OIHW":
+                strategy.add_implementation(
+                    wrap_compute_conv2d(topi.mali.conv2d_nchw_spatial_pack),
+                    wrap_topi_schedule(topi.mali.schedule_conv2d_nchw_spatial_pack),
+                    name="conv2d_nchw_spatial_pack.mali")
+                # check if winograd algorithm is applicable
+                _, _, kh, kw = get_const_tuple(kernel.shape)
+                if kh == 3 and kw == 3 and stride_h == 1 and stride_w == 1 and \
+                   dilation_h == 1 and dilation_w == 1:
+                    strategy.add_implementation(
+                        wrap_compute_conv2d(topi.mali.conv2d_nchw_winograd),
+                        wrap_topi_schedule(topi.mali.schedule_conv2d_nchw_winograd),
+                        name="conv2d_nchw_winograd.mali",
+                        plevel=15)
+            elif re.match(r"OIHW\d*o", kernel_layout):
+                strategy.add_implementation(
+                    wrap_compute_conv2d(topi.mali.conv2d_nchw_spatial_pack),
+                    wrap_topi_schedule(topi.mali.schedule_conv2d_nchw_spatial_pack),
+                    name="conv2d_nchw_spatial_pack.mali")
+            else:
+                raise RuntimeError("Unsupported weight layout {} for conv2d NCHW".
+                                   format(kernel_layout))
+        else:
+            raise RuntimeError("Unsupported conv2d layout {} for mali".format(layout))
+    elif is_depthwise_conv2d(data.shape, layout, kernel.shape, kernel_layout, groups):
+        if layout == "NCHW":
+            assert kernel_layout == "OIHW"
+            strategy.add_implementation(
+                wrap_compute_conv2d(topi.mali.depthwise_conv2d_nchw),
+                wrap_topi_schedule(topi.mali.schedule_depthwise_conv2d_nchw),
+                name="depthwise_conv2d_nchw.mali")
+        else:
+            raise RuntimeError("Unsupported depthwise_conv2d layout {} for mali".format(layout))
+    else: # group_conv2d
+        raise RuntimeError("group_conv2d is not supported for mali")
+    return strategy
+
+@conv2d_winograd_without_weight_transfrom_strategy.register("mali")
+def conv2d_winograd_without_weight_transfrom_strategy_mali(attrs, inputs, out_type, target):
+    """conv2d_winograd_without_weight_transfrom mali strategy"""
+    dilation = attrs.get_int_tuple("dilation")
+    groups = attrs.get_int("groups")
+    layout = attrs.data_layout
+    strides = attrs.get_int_tuple("strides")
+    kernel = inputs[1]
+    assert dilation == (1, 1), "Do not support dilate now"
+    assert strides == (1, 1), "Do not support strides now"
+    assert groups == 1, "Do not supoort arbitrary group number"
+    strategy = _op.OpStrategy()
+    if layout == "NCHW":
+        assert len(kernel.shape) == 5, "Kernel must be packed into 5-dim"
+        strategy.add_implementation(
+            wrap_compute_conv2d(topi.mali.conv2d_nchw_winograd),
+            wrap_topi_schedule(topi.mali.schedule_conv2d_nchw_winograd),
+            name="conv2d_nchw_winograd.mali")
+    else:
+        raise RuntimeError("Unsupported conv2d_winograd_without_weight_transfrom layout {}".
+                           format(layout))
+    return strategy
+
+@dense_strategy.register("mali")
+def dense_strategy_mali(attrs, inputs, out_type, target):
+    """dense mali strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(wrap_compute_dense(topi.mali.dense),
+                                wrap_topi_schedule(topi.mali.schedule_dense),
+                                name="dense.mali")
+    return strategy
diff --git a/python/tvm/relay/op/strategy/opengl.py b/python/tvm/relay/op/strategy/opengl.py
new file mode 100644 (file)
index 0000000..45e290c
--- /dev/null
@@ -0,0 +1,73 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Definition of OpenGL operator strategy."""
+# pylint: disable=invalid-name,unused-argument,wildcard-import,unused-wildcard-import
+import topi
+from .generic import *
+from .. import op as _op
+
+@schedule_injective.register("opengl")
+def schedule_injective_opengl(attrs, outs, target):
+    """schedule injective ops for opengl"""
+    with target:
+        return topi.opengl.schedule_injective(outs)
+
+@schedule_concatenate.register("opengl")
+def schedule_concatenate_opengl(attrs, outs, target):
+    """schedule concatenate for opengl"""
+    with target:
+        return topi.opengl.schedule_injective(outs)
+
+@schedule_pool.register("opengl")
+def schedule_pool_opengl(attrs, outs, target):
+    """schedule pooling ops for opengl"""
+    with target:
+        return topi.opengl.schedule_pool(outs, attrs.layout)
+
+@schedule_adaptive_pool.register("opengl")
+def schedule_adaptive_pool_opengl(attrs, outs, target):
+    """schedule adative pooling ops for opengl"""
+    with target:
+        return topi.opengl.schedule_adaptive_pool(outs)
+
+@schedule_softmax.register("opengl")
+def schedule_softmax_opengl(attrs, outs, target):
+    """schedule softmax for opengl"""
+    with target:
+        return topi.opengl.schedule_softmax(outs)
+
+@conv2d_strategy.register("opengl")
+def conv2d_strategy_opengl(attrs, inputs, out_type, target):
+    """conv2d opengl strategy"""
+    strategy = _op.OpStrategy()
+    groups = attrs.groups
+    layout = attrs.data_layout
+    assert groups == 1, "Don't support group conv2d on OpenGL"
+    assert layout == "NCHW", "Only support conv2d layout NCHW for OpenGL"
+    strategy.add_implementation(wrap_compute_conv2d(topi.nn.conv2d),
+                                wrap_topi_schedule(topi.opengl.schedule_conv2d_nchw),
+                                name="conv2d_nchw.opengl")
+    return strategy
+
+@dense_strategy.register("opengl")
+def dense_strategy_opengl(attrs, inputs, out_type, target):
+    """dense opengl strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(wrap_compute_dense(topi.nn.dense),
+                                wrap_topi_schedule(topi.opengl.schedule_dense),
+                                name="dense.opengl")
+    return strategy
diff --git a/python/tvm/relay/op/strategy/rocm.py b/python/tvm/relay/op/strategy/rocm.py
new file mode 100644 (file)
index 0000000..e11a688
--- /dev/null
@@ -0,0 +1,136 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Definition of ROCm operator strategy."""
+# pylint: disable=invalid-name,unused-argument,unused-wildcard-import,wildcard-import
+import topi
+from .generic import *
+from .. import op as _op
+
+@schedule_lrn.register("rocm")
+def schedule_lrn_rocm(attrs, outs, target):
+    """schedule LRN for rocm"""
+    with target:
+        return topi.rocm.schedule_lrn(outs)
+
+@conv2d_strategy.register("rocm")
+def conv2d_strategy_rocm(attrs, inputs, out_type, target):
+    """conv2d rocm strategy"""
+    strategy = _op.OpStrategy()
+    data, kernel = inputs
+    dilation_h, dilation_w = attrs.get_int_tuple("dilation")
+    groups = attrs.groups
+    layout = attrs.data_layout
+    stride_h, stride_w = attrs.get_int_tuple("strides")
+    kernel_layout = attrs.kernel_layout
+    if dilation_h < 1 or dilation_w < 1:
+        raise ValueError("dilation should be positive value")
+
+    if groups == 1:
+        if layout == "NCHW":
+            # TODO(@vinx13, @icemelon9): Use conv2d_NCHWc_int8 when dtype is int8/uint8.
+            assert kernel_layout == "OIHW"
+            strategy.add_implementation(
+                wrap_compute_conv2d(topi.cuda.conv2d_nchw),
+                wrap_topi_schedule(topi.cuda.schedule_conv2d_nchw),
+                name="conv2d_nchw.cuda")
+            _, _, kh, kw = get_const_tuple(kernel.shape)
+            if kh <= 7 and kw <= 7 and kh == kw and stride_h == 1 and stride_w == 1:
+                strategy.add_implementation(
+                    wrap_compute_conv2d(topi.cuda.conv2d_nchw_winograd),
+                    wrap_topi_schedule(topi.cuda.schedule_conv2d_nchw_winograd),
+                    name="conv2d_nchw_winograd.cuda",
+                    plevel=15)
+        elif layout == "HWCN":
+            assert kernel_layout == "HWIO"
+            strategy.add_implementation(
+                wrap_compute_conv2d(topi.cuda.conv2d_hwcn),
+                wrap_topi_schedule(topi.cuda.schedule_conv2d_hwcn),
+                name="conv2d_hwcn.cuda")
+        # TODO(@alexgl-github): Re-enable this after fix the conv2d_nhwc for cuda
+        # elif layout == "NHWC":
+        #     assert kernel_layout == "HWIO"
+        #     strategy.add_implementation(
+        #         wrap_compute_conv2d(topi.cuda.conv2d_nhwc),
+        #         wrap_topi_schedule(topi.cuda.schedule_conv2d_nhwc),
+        #         name="conv2d_nhwc.cuda")
+        elif layout == "NCHW4c" and data.dtype in ["int8", "uint8"]:
+            assert kernel_layout == "OIHW4o4i"
+            strategy.add_implementation(
+                wrap_compute_conv2d(topi.cuda.conv2d_NCHWc_int8, True),
+                wrap_topi_schedule(topi.cuda.schedule_conv2d_NCHWc_int8),
+                name="conv2d_NCHWc_int8.cuda")
+        else:
+            raise RuntimeError("Unsupported conv2d layout {} for CUDA".format(layout))
+        # add miopen implementation
+        if "miopen" in target.libs:
+            if layout == "NCHW":
+                strategy.add_implementation(
+                    wrap_compute_conv2d(topi.rocm.conv2d_nchw_miopen, True),
+                    wrap_topi_schedule(topi.rocm.schedule_conv2d_nchw_miopen),
+                    name="conv2d_nchw_miopen.rocm",
+                    plevel=15)
+    elif is_depthwise_conv2d(data.shape, layout, kernel.shape, kernel_layout, groups):
+        if layout == "NCHW":
+            assert kernel_layout == "OIHW"
+            strategy.add_implementation(
+                wrap_compute_conv2d(topi.cuda.depthwise_conv2d_nchw),
+                wrap_topi_schedule(topi.cuda.schedule_depthwise_conv2d_nchw),
+                name="depthwise_conv2d_nchw.cuda")
+        elif layout == "NHWC":
+            assert kernel_layout == "HWOI"
+            strategy.add_implementation(
+                wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc),
+                wrap_topi_schedule(topi.cuda.schedule_depthwise_conv2d_nhwc),
+                name="depthwise_conv2d_nhwc.cuda")
+        else:
+            raise RuntimeError("Unsupported depthwise_conv2d layout {}".format(layout))
+    else: # group_conv2d
+        if layout == 'NCHW':
+            # TODO(@vinx13, @icemelon9): Use group_conv2d_NCHWc_int8 when dtype is int8/uint8.
+            assert kernel_layout == "OIHW"
+            strategy.add_implementation(
+                wrap_compute_conv2d(topi.cuda.group_conv2d_nchw, has_groups=True),
+                wrap_topi_schedule(topi.cuda.schedule_group_conv2d_nchw),
+                name="group_conv2d_nchw.cuda")
+        elif layout == 'NCHW4c' and data.dtype in ["int8", "uint8"]:
+            assert kernel_layout == "OIHW4o4i"
+            strategy.add_implementation(
+                wrap_compute_conv2d(topi.cuda.group_conv2d_NCHWc_int8, True),
+                wrap_topi_schedule(topi.cuda.schedule_group_conv2d_NCHWc_int8),
+                name="group_conv2d_NCHWc_int8.cuda")
+        else:
+            raise RuntimeError("Unsupported group_conv2d layout {}".format(layout))
+    return strategy
+
+@dense_strategy.register("rocm")
+def dense_strategy_rocm(attrs, inputs, out_type, target):
+    """Dense strategy for ROCM"""
+    strategy = _op.OpStrategy()
+    assert len(inputs[0].shape) == 2 and len(inputs[1].shape) == 2, "Only support 2-dim dense"
+
+    strategy.add_implementation(
+        wrap_compute_dense(topi.rocm.dense),
+        wrap_topi_schedule(topi.rocm.schedule_dense),
+        name="dense.rocm")
+    if target.target_name == "rocm" and "rocblas" in target.libs:
+        assert out_type.dtype == inputs[0].dtype, "Mixed precision not supported."
+        strategy.add_implementation(
+            wrap_compute_dense(topi.rocm.dense_rocblas),
+            wrap_topi_schedule(topi.rocm.dense_rocblas),
+            name="dense_rocblas.rocm",
+            plevel=5)
+    return strategy
diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py
new file mode 100644 (file)
index 0000000..2fadb7f
--- /dev/null
@@ -0,0 +1,302 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Definition of x86 operator strategy."""
+# pylint: disable=invalid-name,unused-argument,wildcard-import,unused-wildcard-import
+import logging
+
+import topi
+from tvm.te import SpecializedCondition
+from .generic import *
+from .. import op as _op
+
+logger = logging.getLogger('strategy')
+
+@schedule_injective.register("cpu")
+def schedule_injective_cpu(attrs, outs, target):
+    """schedule injective ops for x86"""
+    with target:
+        return topi.x86.schedule_injective(outs)
+
+@schedule_reduce.register("cpu")
+def schedule_reduce_cpu(attrs, outs, target):
+    """schedule reduction ops for x86"""
+    with target:
+        return topi.x86.schedule_reduce(outs)
+
+@schedule_concatenate.register("cpu")
+def schedule_concatenate_cpu(attrs, outs, target):
+    """schedule concatenate op for x86"""
+    with target:
+        return topi.x86.schedule_concatenate(outs)
+
+@schedule_pool.register("cpu")
+def schedule_pool_cpu(attrs, outs, target):
+    """schedule pooling ops for x86"""
+    with target:
+        return topi.x86.schedule_pool(outs, attrs.layout)
+
+@schedule_adaptive_pool.register("cpu")
+def schedule_adaptive_pool_cpu(attrs, outs, target):
+    """schedule adaptive pooling ops for x86"""
+    with target:
+        return topi.x86.schedule_adaptive_pool(outs)
+
+@schedule_softmax.register("cpu")
+def schedule_softmax_cpu(attrs, outs, target):
+    """schedule softmax for x86"""
+    with target:
+        return topi.x86.schedule_softmax(outs)
+
+@conv2d_strategy.register("cpu")
+def conv2d_strategy_cpu(attrs, inputs, out_type, target):
+    """conv2d x86 strategy"""
+    strategy = _op.OpStrategy()
+    data, kernel = inputs
+    dilation_h, dilation_w = get_const_tuple(attrs.dilation)
+    groups = attrs.groups
+    layout = attrs.data_layout
+    kernel_layout = attrs.kernel_layout
+    if dilation_h < 1 or dilation_w < 1:
+        raise ValueError("dilation should be positive value")
+
+    if groups == 1:
+        if layout == "NCHW":
+            assert kernel_layout == "OIHW"
+            if topi.x86.is_int8_hw_support(data.dtype, kernel.dtype):
+                strategy.add_implementation(
+                    wrap_compute_conv2d(topi.x86.conv2d_nchw_int8),
+                    wrap_topi_schedule(topi.x86.schedule_conv2d_nchw_int8),
+                    name="conv2d_nchw_int8.x86")
+            else:
+                strategy.add_implementation(
+                    wrap_compute_conv2d(topi.x86.conv2d_nchw),
+                    wrap_topi_schedule(topi.x86.schedule_conv2d_nchw),
+                    name="conv2d_nchw.x86")
+        elif layout == "NHWC":
+            assert kernel_layout == "HWIO"
+            logger.warning("For x86 target, NCHW layout is recommended for conv2d.")
+            strategy.add_implementation(
+                wrap_compute_conv2d(topi.nn.conv2d_nhwc),
+                wrap_topi_schedule(topi.x86.schedule_conv2d_nhwc),
+                name="conv2d_nhwc.x86")
+        elif layout == "HWCN":
+            assert kernel_layout == "HWIO"
+            logger.warning("conv2d HWCN layout is not optimized for x86.")
+            strategy.add_implementation(
+                wrap_compute_conv2d(topi.nn.conv2d_hwcn),
+                wrap_topi_schedule(topi.generic.schedule_conv2d_hwcn),
+                name="conv2d_hwcn.generic")
+        else:
+            raise RuntimeError("Unsupported conv2d layout {} for x86".format(layout))
+    elif is_depthwise_conv2d(data.shape, layout, kernel.shape, kernel_layout, groups):
+        if layout == "NCHW":
+            assert kernel_layout == "OIHW"
+            channel_multiplier = get_const_tuple(inputs[1].shape)[1]
+            if channel_multiplier == 1 and dilation_h == 1 and dilation_w == 1:
+                strategy.add_implementation(
+                    wrap_compute_conv2d(topi.x86.depthwise_conv2d_nchw),
+                    wrap_topi_schedule(topi.x86.schedule_depthwise_conv2d_nchw),
+                    name="depthwise_conv2d_nchw.x86")
+            else:
+                logger.warning("For x86 target, depthwise_conv2d with channel "
+                               "multiplier greater than 1 is not optimized")
+                strategy.add_implementation(
+                    wrap_compute_conv2d(topi.nn.depthwise_conv2d_nchw),
+                    wrap_topi_schedule(topi.generic.schedule_depthwise_conv2d_nchw),
+                    name="depthwise_conv2d_nchw.generic")
+        elif layout == "NHWC":
+            assert kernel_layout == "HWOI"
+            logger.warning("depthwise_conv2d NHWC layout is not optimized for x86.")
+            strategy.add_implementation(
+                wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc),
+                wrap_topi_schedule(topi.generic.schedule_depthwise_conv2d_nhwc),
+                name="depthwise_conv2d_nhwc.generic")
+        else:
+            raise RuntimeError("Unsupported depthwise_conv2d layout {}".format(layout))
+    else: # group_conv2d
+        if layout == 'NCHW':
+            assert kernel_layout == "OIHW"
+            logger.warning("group_conv2d is not optimized for x86.")
+            strategy.add_implementation(
+                wrap_compute_conv2d(topi.nn.group_conv2d_nchw, has_groups=True),
+                wrap_topi_schedule(topi.generic.schedule_group_conv2d_nchw),
+                name="group_conv2d_nchw.generic")
+        else:
+            raise RuntimeError("Unsupported group_conv2d layout {}".format(layout))
+    return strategy
+
+@conv2d_NCHWc_strategy.register("cpu")
+def conv2d_NCHWc_strategy_cpu(attrs, inputs, out_type, target):
+    """conv2d_NCHWc x86 strategy"""
+    strategy = _op.OpStrategy()
+    data, kernel = inputs
+    if topi.x86.is_int8_hw_support(data.dtype, kernel.dtype):
+        strategy.add_implementation(
+            wrap_compute_conv2d(topi.x86.conv2d_NCHWc_int8, True, True),
+            wrap_topi_schedule(topi.x86.schedule_conv2d_NCHWc_int8),
+            name="conv2d_NCHWc_int8.x86")
+    else:
+        strategy.add_implementation(
+            wrap_compute_conv2d(topi.x86.conv2d_NCHWc, True, True),
+            wrap_topi_schedule(topi.x86.schedule_conv2d_NCHWc),
+            name="conv2d_NCHWc.x86")
+    return strategy
+
+@depthwise_conv2d_NCHWc_strategy.register("cpu")
+def depthwise_conv2d_NCHWc_strategy_cpu(attrs, inputs, out_type, target):
+    """depthwise_conv2d x86 strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_conv2d(topi.x86.depthwise_conv2d_NCHWc, True, True),
+        wrap_topi_schedule(topi.x86.schedule_depthwise_conv2d_NCHWc),
+        name="depthwise_conv2d_NCHWc.x86")
+    return strategy
+
+@conv2d_transpose_strategy.register("cpu")
+def conv2d_transpose_strategy_cpu(attrs, inputs, out_type, target):
+    """conv2d_transpose x86 strategy"""
+    layout = attrs.data_layout
+    dilation = get_const_tuple(attrs.dilation)
+    groups = attrs.groups
+    assert layout == "NCHW", "only support nchw for now"
+    assert dilation == (1, 1), "not support dilate now"
+    assert groups == 1, "only support groups == 1 for now"
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_conv2d_transpose(topi.x86.conv2d_transpose_nchw),
+        wrap_topi_schedule(topi.x86.schedule_conv2d_transpose_nchw),
+        name="conv2d_transpose_nchw.x86")
+    return strategy
+
+@conv3d_strategy.register("cpu")
+def conv3d_strategy_cpu(attrs, inputs, out_type, target):
+    """conv3d generic strategy"""
+    strategy = _op.OpStrategy()
+    layout = attrs.data_layout
+    if layout == "NCDHW":
+        logger.warning("conv3d with layout NCDHW is not optimized for x86.")
+        strategy.add_implementation(wrap_compute_conv3d(topi.nn.conv3d_ncdhw),
+                                    wrap_topi_schedule(topi.generic.schedule_conv3d_ncdhw),
+                                    name="conv3d_ncdhw.generic")
+    elif layout == "NDHWC":
+        strategy.add_implementation(wrap_compute_conv3d(topi.x86.conv3d_ndhwc),
+                                    wrap_topi_schedule(topi.x86.schedule_conv3d_ndhwc),
+                                    name="conv3d_ndhwc.x86")
+    else:
+        raise ValueError("Not support this layout {} yet".format(layout))
+    return strategy
+
+@conv1d_strategy.register("cpu")
+def conv1d_strategy_cpu(attrs, inputs, out_type, target):
+    """conv1d x86 strategy"""
+    layout = attrs.data_layout
+    dilation = get_const_tuple(attrs.dilation)
+    if dilation[0] < 1:
+        raise ValueError("dilation should be a positive value")
+    strategy = _op.OpStrategy()
+    if layout == "NCW":
+        strategy.add_implementation(wrap_compute_conv1d(topi.nn.conv1d_ncw),
+                                    wrap_topi_schedule(topi.x86.schedule_conv1d_ncw),
+                                    name="conv1d_ncw.x86")
+    elif layout == "NWC":
+        strategy.add_implementation(wrap_compute_conv1d(topi.nn.conv1d_nwc),
+                                    wrap_topi_schedule(topi.x86.schedule_conv1d_nwc),
+                                    name="conv1d_nwc.x86")
+    else:
+        raise ValueError("Unsupported conv1d layout {}".format(layout))
+    return strategy
+
+@dense_strategy.register("cpu")
+def dense_strategy_cpu(attrs, inputs, out_type, target):
+    """dense x86 strategy"""
+    strategy = _op.OpStrategy()
+    m, _ = inputs[0].shape
+    strategy.add_implementation(wrap_compute_dense(topi.x86.dense_nopack),
+                                wrap_topi_schedule(topi.x86.schedule_dense_nopack),
+                                name="dense_nopack.x86",
+                                plevel=10)
+    if "cblas" in target.libs:
+        strategy.add_implementation(wrap_compute_dense(topi.x86.dense_cblas),
+                                    wrap_topi_schedule(topi.x86.schedule_dense_cblas),
+                                    name="dense_cblas.x86",
+                                    plevel=5)
+    with SpecializedCondition(m >= 16):
+        # this implementation may not be well-optimized, so use plevel=8 for now.
+        strategy.add_implementation(wrap_compute_dense(topi.x86.dense_pack),
+                                    wrap_topi_schedule(topi.x86.schedule_dense_pack),
+                                    name="dense_pack.x86",
+                                    plevel=8)
+    return strategy
+
+@batch_matmul_strategy.register("cpu")
+def batch_matmul_strategy_cpu(attrs, inputs, out_type, target):
+    """batch_matmul x86 strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(wrap_compute_batch_matmul(topi.x86.batch_matmul),
+                                wrap_topi_schedule(topi.x86.schedule_batch_matmul),
+                                name="batch_matmul.x86",
+                                plevel=10)
+    if "cblas" in target.libs:
+        strategy.add_implementation(wrap_compute_batch_matmul(topi.x86.batch_matmul_cblas),
+                                    wrap_topi_schedule(topi.x86.schedule_batch_matmul_cblas),
+                                    name="batch_matmul_cblas.x86",
+                                    plevel=5)
+    return strategy
+
+@schedule_sparse_dense.register("cpu")
+def schedule_sparse_dense_cpu(attrs, outs, target):
+    """schedule sparse_dense for x86"""
+    with target:
+        return topi.x86.schedule_sparse_dense(outs)
+
+@roi_align_strategy.register("cpu")
+def roi_align_strategy_cpu(attrs, inputs, out_type, target):
+    """roi_align x86 strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(wrap_compute_roi_align(topi.x86.roi_align_nchw),
+                                wrap_topi_schedule(topi.generic.schedule_roi_align),
+                                name="roi_align.x86")
+    return strategy
+
+@bitserial_conv2d_strategy.register("cpu")
+def bitserial_conv2d_strategy_cpu(attrs, inputs, out_type, target):
+    """bitserial_conv2d x86 strategy"""
+    strategy = _op.OpStrategy()
+    layout = attrs.data_layout
+    if layout == "NCHW":
+        strategy.add_implementation(
+            wrap_compute_bitserial_conv2d(topi.x86.bitserial_conv2d_nchw),
+            wrap_topi_schedule(topi.x86.schedule_bitserial_conv2d_nchw),
+            name="bitserial_conv2d_nchw.x86")
+    elif layout == "NHWC":
+        strategy.add_implementation(
+            wrap_compute_bitserial_conv2d(topi.x86.bitserial_conv2d_nhwc),
+            wrap_topi_schedule(topi.x86.schedule_bitserial_conv2d_nhwc),
+            name="bitserial_conv2d_nhwc.x86")
+    else:
+        raise ValueError("Data layout {} not supported.".format(layout))
+    return strategy
+
+@bitserial_dense_strategy.register("cpu")
+def bitserial_dense_strategy_cpu(attrs, inputs, out_type, target):
+    """bitserial_dense x86 strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_bitserial_dense(topi.x86.bitserial_dense),
+        wrap_topi_schedule(topi.x86.schedule_bitserial_dense),
+        name="bitserial_dense.x86")
+    return strategy
index f352839..16468e5 100644 (file)
 # pylint: disable=invalid-name, unused-argument
 """Faster R-CNN and Mask R-CNN operations."""
 import topi
-from topi.util import get_const_tuple, get_float_tuple, get_const_int
+from topi.util import get_const_tuple
 from .. import op as reg
+from .. import strategy
 from ..op import OpPattern
 
-
-@reg.register_compute("vision.roi_align")
-def compute_roi_align(attrs, inputs, _, target):
-    """Compute definition of roi_align"""
-    assert attrs.layout == "NCHW"
-    return [topi.vision.rcnn.roi_align_nchw(
-        inputs[0], inputs[1], pooled_size=get_const_tuple(attrs.pooled_size),
-        spatial_scale=attrs.spatial_scale, sample_ratio=attrs.sample_ratio)]
-
-@reg.register_schedule("vision.roi_align")
-def schedule_roi_align(_, outs, target):
-    """Schedule definition of roi_align"""
-    with target:
-        return topi.generic.vision.schedule_roi_align(outs)
-
+# roi_align
+reg.register_strategy("vision.roi_align", strategy.roi_align_strategy)
 reg.register_pattern("vision.roi_align", OpPattern.OUT_ELEMWISE_FUSABLE)
 
+# roi_pool
 @reg.register_compute("vision.roi_pool")
-def compute_roi_pool(attrs, inputs, _, target):
+def compute_roi_pool(attrs, inputs, _):
     """Compute definition of roi_pool"""
     assert attrs.layout == "NCHW"
     return [topi.vision.rcnn.roi_pool_nchw(
         inputs[0], inputs[1], pooled_size=get_const_tuple(attrs.pooled_size),
         spatial_scale=attrs.spatial_scale)]
 
-@reg.register_schedule("vision.roi_pool")
-def schedule_roi_pool(_, outs, target):
-    """Schedule definition of roi_pool"""
-    with target:
-        return topi.generic.vision.schedule_roi_pool(outs)
-
+reg.register_schedule("vision.roi_pool", strategy.schedule_roi_pool)
 reg.register_pattern("vision.roi_pool", OpPattern.OUT_ELEMWISE_FUSABLE)
 
-@reg.register_compute("vision.proposal")
-def compute_proposal(attrs, inputs, _, target):
-    """Compute definition of proposal"""
-    scales = get_float_tuple(attrs.scales)
-    ratios = get_float_tuple(attrs.ratios)
-    feature_stride = attrs.feature_stride
-    threshold = attrs.threshold
-    rpn_pre_nms_top_n = attrs.rpn_pre_nms_top_n
-    rpn_post_nms_top_n = attrs.rpn_post_nms_top_n
-    rpn_min_size = attrs.rpn_min_size
-    iou_loss = bool(get_const_int(attrs.iou_loss))
-    with target:
-        return [
-            topi.vision.rcnn.proposal(inputs[0], inputs[1], inputs[2], scales, ratios,
-                                      feature_stride, threshold, rpn_pre_nms_top_n,
-                                      rpn_post_nms_top_n, rpn_min_size, iou_loss)
-        ]
-
-@reg.register_schedule("vision.proposal")
-def schedule_proposal(_, outs, target):
-    """Schedule definition of proposal"""
-    with target:
-        return topi.generic.schedule_proposal(outs)
-
+# proposal
+reg.register_strategy("vision.proposal", strategy.proposal_strategy)
 reg.register_pattern("vision.proposal", OpPattern.OPAQUE)
index 7de1180..6e2008a 100644 (file)
 """Definition of vision ops"""
 from __future__ import absolute_import
 
-import topi
-from topi.util import get_const_int, get_const_float, get_float_tuple
 from .. import op as reg
+from .. import strategy
 from ..op import OpPattern
 
-
-@reg.register_schedule("vision.multibox_prior")
-def schedule_multibox_prior(_, outs, target):
-    """Schedule definition of multibox_prior"""
-    with target:
-        return topi.generic.schedule_multibox_prior(outs)
-
-
-@reg.register_compute("vision.multibox_prior")
-def compute_multibox_prior(attrs, inputs, _, target):
-    """Compute definition of multibox_prior"""
-    sizes = get_float_tuple(attrs.sizes)
-    ratios = get_float_tuple(attrs.ratios)
-    steps = get_float_tuple(attrs.steps)
-    offsets = get_float_tuple(attrs.offsets)
-    clip = bool(get_const_int(attrs.clip))
-    return [
-        topi.vision.ssd.multibox_prior(inputs[0], sizes, ratios, steps,
-                                       offsets, clip)
-    ]
-
-
+# multibox_prior
+reg.register_strategy("vision.multibox_prior", strategy.multibox_prior_strategy)
 reg.register_pattern("vision.multibox_prior", OpPattern.OPAQUE)
 
 
 # multibox_transform_loc
-@reg.register_schedule("vision.multibox_transform_loc")
-def schedule_multibox_transform_loc(_, outs, target):
-    """Schedule definition of multibox_detection"""
-    with target:
-        return topi.generic.schedule_multibox_transform_loc(outs)
-
-
-@reg.register_compute("vision.multibox_transform_loc")
-def compute_multibox_transform_loc(attrs, inputs, _, target):
-    """Compute definition of multibox_detection"""
-    clip = bool(get_const_int(attrs.clip))
-    threshold = get_const_float(attrs.threshold)
-    variances = get_float_tuple(attrs.variances)
-    return topi.vision.ssd.multibox_transform_loc(
-        inputs[0], inputs[1], inputs[2], clip, threshold, variances)
-
-
+reg.register_strategy("vision.multibox_transform_loc", strategy.multibox_transform_loc_strategy)
 reg.register_pattern("vision.multibox_transform_loc", OpPattern.OPAQUE)
-reg.register_pattern("vision.multibox_detection", OpPattern.OPAQUE)
 
 
 # Get counts of valid boxes
-@reg.register_schedule("vision.get_valid_counts")
-def schedule_get_valid_counts(_, outs, target):
-    """Schedule definition of get_valid_counts"""
-    with target:
-        return topi.generic.schedule_get_valid_counts(outs)
-
-
-@reg.register_compute("vision.get_valid_counts")
-def compute_get_valid_counts(attrs, inputs, _, target):
-    """Compute definition of get_valid_counts"""
-    score_threshold = get_const_float(attrs.score_threshold)
-    id_index = get_const_int(attrs.id_index)
-    score_index = get_const_int(attrs.score_index)
-    return topi.vision.get_valid_counts(inputs[0], score_threshold,
-                                        id_index, score_index)
-
+reg.register_strategy("vision.get_valid_counts", strategy.get_valid_counts_strategy)
 reg.register_pattern("vision.get_valid_counts", OpPattern.OPAQUE)
 
 
 # non-maximum suppression
-@reg.register_schedule("vision.non_max_suppression")
-def schedule_nms(_, outs, target):
-    """Schedule definition of nms"""
-    with target:
-        return topi.generic.schedule_nms(outs)
-
-
-@reg.register_compute("vision.non_max_suppression")
-def compute_nms(attrs, inputs, _, target):
-    """Compute definition of nms"""
-    return_indices = bool(get_const_int(attrs.return_indices))
-    max_output_size = get_const_int(attrs.max_output_size)
-    iou_threshold = get_const_float(attrs.iou_threshold)
-    force_suppress = bool(get_const_int(attrs.force_suppress))
-    top_k = get_const_int(attrs.top_k)
-    coord_start = get_const_int(attrs.coord_start)
-    score_index = get_const_int(attrs.score_index)
-    id_index = get_const_int(attrs.id_index)
-    invalid_to_bottom = bool(get_const_int(attrs.invalid_to_bottom))
-    return [
-        topi.vision.non_max_suppression(inputs[0], inputs[1], max_output_size,
-                                        iou_threshold, force_suppress, top_k,
-                                        coord_start, score_index, id_index,
-                                        return_indices, invalid_to_bottom)
-    ]
-
-
+reg.register_strategy("vision.non_max_suppression", strategy.nms_strategy)
 reg.register_pattern("vision.non_max_suppression", OpPattern.OPAQUE)
index 32fc62d..c58a7a3 100644 (file)
@@ -17,9 +17,9 @@
 #pylint: disable=invalid-name, unused-argument
 """Backend compiler related feature registration"""
 from __future__ import absolute_import
-from ..op import  register_schedule, register_pattern
-from ..op import schedule_injective, OpPattern
+from ..op import register_pattern, OpPattern
+from ..op import register_injective_schedule
 
 # reorg
 register_pattern("vision.yolo_reorg", OpPattern.INJECTIVE)
-register_schedule("vision.yolo_reorg", schedule_injective)
+register_injective_schedule("vision.yolo_reorg")
index ba100d8..b77516d 100644 (file)
@@ -31,7 +31,7 @@ from .quantize import _forward_op
 
 
 @_reg.register_compute("relay.op.annotation.simulated_quantize")
-def simulated_quantize_compute(attrs, inputs, out_type, target):
+def simulated_quantize_compute(attrs, inputs, out_type):
     """Compiler for simulated_quantize."""
     assert len(inputs) == 4
     assert attrs.sign
@@ -52,11 +52,10 @@ def simulated_quantize_compute(attrs, inputs, out_type, target):
     return [rdata]
 
 
-_reg.register_schedule("relay.op.annotation.simulated_quantize",
-                       _reg.schedule_injective)
+_reg.register_injective_schedule("relay.op.annotation.simulated_quantize")
 _reg.register_pattern("relay.op.annotation.simulated_quantize",
                       _reg.OpPattern.ELEMWISE)
-_reg.register_schedule("annotation.cast_hint", _reg.schedule_injective)
+_reg.register_injective_schedule("annotation.cast_hint")
 
 
 @register_relay_node
index 9aaefdf..d5a4d5f 100644 (file)
@@ -44,15 +44,18 @@ def conv_block(data, name, channels, kernel_size=(3, 3), strides=(1, 1),
 
 def separable_conv_block(data, name, depthwise_channels, pointwise_channels,
                          kernel_size=(3, 3), downsample=False, padding=(1, 1),
-                         epsilon=1e-5, layout='NCHW'):
+                         epsilon=1e-5, layout='NCHW', dtype="float32"):
     """Helper function to get a separable conv block"""
     if downsample:
         strides = (2, 2)
     else:
         strides = (1, 1)
     # depthwise convolution + bn + relu
+    wshape = (depthwise_channels, 1) + kernel_size
+    weight = relay.var(name + "_weight", shape=wshape, dtype=dtype)
     conv1 = layers.conv2d(
         data=data,
+        weight=weight,
         channels=depthwise_channels,
         groups=depthwise_channels,
         kernel_size=kernel_size,
@@ -85,38 +88,41 @@ def mobile_net(num_classes=1000, data_shape=(1, 3, 224, 224),
     body = conv_block(data, 'conv_block_1', int(32*alpha), strides=(2, 2),
                       layout=layout)
     body = separable_conv_block(body, 'separable_conv_block_1',
-                                int(32*alpha), int(64*alpha), layout=layout)
+                                int(32*alpha), int(64*alpha), layout=layout,
+                                dtype=dtype)
     body = separable_conv_block(body, 'separable_conv_block_2',
                                 int(64*alpha), int(128*alpha), downsample=True,
-                                layout=layout)
+                                layout=layout, dtype=dtype)
     body = separable_conv_block(body, 'separable_conv_block_3',
-                                int(128*alpha), int(128*alpha), layout=layout)
+                                int(128*alpha), int(128*alpha), layout=layout,
+                                dtype=dtype)
     body = separable_conv_block(body, 'separable_conv_block_4',
                                 int(128*alpha), int(256*alpha), downsample=True,
-                                layout=layout)
+                                layout=layout, dtype=dtype)
     body = separable_conv_block(body, 'separable_conv_block_5',
-                                int(256*alpha), int(256*alpha), layout=layout)
+                                int(256*alpha), int(256*alpha), layout=layout,
+                                dtype=dtype)
     body = separable_conv_block(body, 'separable_conv_block_6',
                                 int(256*alpha), int(512*alpha), downsample=True,
-                                layout=layout)
+                                layout=layout, dtype=dtype)
     if is_shallow:
         body = separable_conv_block(body, 'separable_conv_block_7',
                                     int(512*alpha), int(1024*alpha),
-                                    downsample=True, layout=layout)
+                                    downsample=True, layout=layout, dtype=dtype)
         body = separable_conv_block(body, 'separable_conv_block_8',
                                     int(1024*alpha), int(1024*alpha),
-                                    downsample=True, layout=layout)
+                                    downsample=True, layout=layout, dtype=dtype)
     else:
         for i in range(7, 12):
             body = separable_conv_block(body, 'separable_conv_block_%d' % i,
                                         int(512*alpha), int(512*alpha),
-                                        layout=layout)
+                                        layout=layout, dtype=dtype)
         body = separable_conv_block(body, 'separable_conv_block_12',
                                     int(512*alpha), int(1024*alpha),
-                                    downsample=True, layout=layout)
+                                    downsample=True, layout=layout, dtype=dtype)
         body = separable_conv_block(body, 'separable_conv_block_13',
                                     int(1024*alpha), int(1024*alpha),
-                                    layout=layout)
+                                    layout=layout, dtype=dtype)
     pool = relay.nn.global_avg_pool2d(data=body, layout=layout)
     flatten = relay.nn.batch_flatten(data=pool)
     weight = relay.var('fc_weight')
index 13f280a..1936ff1 100644 (file)
@@ -184,6 +184,7 @@ def override_native_generic_func(func_name):
         fresult = decorate(fdefault, dispatch_func)
         fresult.fdefault = fdefault
         fresult.register = register
+        fresult.generic_func_node = generic_func_node
         return fresult
     return fdecorate
 
@@ -268,4 +269,5 @@ def generic_func(fdefault):
     fdecorate = decorate(fdefault, dispatch_func)
     fdecorate.register = register
     fdecorate.fdefault = fdefault
+    fdecorate.dispatch_dict = dispatch_dict
     return fdecorate
index 1580da3..5970315 100644 (file)
@@ -23,8 +23,8 @@ from tvm.tir import trunc, abs, round, nearbyint, isnan, power, popcount, fmod,
 from tvm.tir import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod
 from tvm.tir import comm_reducer, min, max, sum
 
-from .schedule import Schedule, create_schedule
-from .tensor import Tensor
+from .schedule import Schedule, create_schedule, SpecializedCondition
+from .tensor import TensorSlice, Tensor
 from .tensor_intrin import decl_tensor_intrin
 from .tag import tag_scope
 from .operation import placeholder, compute, scan, extern, var, size_var
index d160f78..f8bbe09 100644 (file)
@@ -517,4 +517,39 @@ class Stage(Object):
         _ffi_api.StageOpenGL(self)
 
 
+@tvm._ffi.register_object
+class SpecializedCondition(Object):
+    """Specialized condition to enable op specialization."""
+    def __init__(self, conditions):
+        """Create a specialized condition.
+
+        .. note::
+            Conditions are represented in conjunctive joint form (CNF).
+            Each condition should be a simple expression, e.g., n > 16,
+            m % 8 == 0, etc., where n, m are tvm.Var that represents a
+            dimension in the tensor shape.
+
+        Parameters
+        ----------
+        conditions : List of tvm.Expr
+            List of conditions in conjunctive joint form (CNF).
+        """
+        if not isinstance(conditions, (list, _container.Array)):
+            conditions = [conditions]
+        self.__init_handle_by_constructor__(
+            _ffi_api.CreateSpecializedCondition, conditions)
+
+    @staticmethod
+    def current():
+        """Returns the current specialized condition"""
+        return _ffi_api.GetCurrentSpecialization()
+
+    def __enter__(self):
+        _ffi_api.EnterSpecializationScope(self)
+        return self
+
+    def __exit__(self, ptype, value, trace):
+        _ffi_api.ExitSpecializationScope(self)
+
+
 tvm._ffi._init_api("schedule", __name__)
index ca11ffc..acf5f51 100644 (file)
@@ -964,3 +964,11 @@ class Let(PrimExprWithOp):
     def __init__(self, var, value, body):
         self.__init_handle_by_constructor__(
             _ffi_api.Let, var, value, body)
+
+
+@tvm._ffi.register_object
+class Any(PrimExpr):
+    """Any node.
+    """
+    def __init__(self):
+        self.__init_handle_by_constructor__(_ffi_api.Any)
index e5629e8..d0a7da9 100644 (file)
 namespace tvm {
 namespace relay {
 
+TVM_REGISTER_NODE_TYPE(LoweredOutputNode);
 TVM_REGISTER_NODE_TYPE(CachedFuncNode);
 TVM_REGISTER_NODE_TYPE(CCacheKeyNode);
 TVM_REGISTER_NODE_TYPE(CCacheValueNode);
 TVM_REGISTER_OBJECT_TYPE(CompileEngineNode);
 
+LoweredOutput::LoweredOutput(tvm::Array<te::Tensor> outputs, OpImplementation impl) {
+  auto n = make_object<LoweredOutputNode>();
+  n->outputs = std::move(outputs);
+  n->implementation = std::move(impl);
+  data_ = std::move(n);
+}
+
 CCacheKey CCacheKeyNode::make(Function source_func, Target target) {
   auto n = make_object<CCacheKeyNode>();
   n->source_func = std::move(source_func);
@@ -108,9 +116,7 @@ class ScheduleGetter :
   explicit ScheduleGetter(Target target)
       : target_(target), device_copy_op_(Op::Get("device_copy")) {}
 
-  std::pair<te::Schedule, CachedFunc> Create(const Function& prim_func) {
-    static auto fschedule =
-        Op::GetAttr<FTVMSchedule>("FTVMSchedule");
+  CachedFunc Create(const Function& prim_func) {
     auto cache_node = make_object<CachedFuncNode>();
     cache_node->target = target_;
     for (Var param : prim_func->params) {
@@ -147,7 +153,6 @@ class ScheduleGetter :
     }
     cache_node->func_name = candidate_name;
 
-    CachedFunc cfunc(cache_node);
     CHECK(master_op_.defined());
     // Fusion over tupled results may leave identity relationships
     // between inputs and outputs, and those should not be scheduled.
@@ -161,15 +166,16 @@ class ScheduleGetter :
     te::Schedule schedule;
     // No need to register schedule for device copy op.
     if (master_attrs_.as<DeviceCopyAttrs>() == nullptr) {
-      schedule =
-          fschedule[master_op_](master_attrs_, tensor_outs, target_);
+      CHECK(master_implementation_.defined());
+      schedule = master_implementation_.Schedule(master_attrs_, tensor_outs, target_);
       for (const auto& scalar : scalars_) {
         if (schedule->Contain(scalar)) {
           schedule[scalar].compute_inline();
         }
       }
     }
-    return std::make_pair(schedule, cfunc);
+    cache_node->schedule = std::move(schedule);
+    return CachedFunc(cache_node);
   }
 
   Array<te::Tensor> VisitExpr(const Expr& expr) {
@@ -208,16 +214,16 @@ class ScheduleGetter :
           LOG(FATAL) << "not handled";
           return tvm::PrimExpr();
         }
-      }, "compile_engine_const", topi::kBroadcast);
+    }, "compile_engine_const", topi::kBroadcast);
     scalars_.push_back(value->op);
     return {value};
   }
 
   Array<te::Tensor> VisitExpr_(const CallNode* call_node) final {
-    static auto fcompute =
-        Op::GetAttr<FTVMCompute>("FTVMCompute");
     static auto fpattern =
         Op::GetAttr<TOpPattern>("TOpPattern");
+    static auto flower_call = tvm::runtime::Registry::Get("relay.backend.lower_call");
+    CHECK(flower_call) << "relay.backend.lower_call is not registered.";
 
     Array<te::Tensor> inputs;
     int count_tuple = 0;
@@ -231,51 +237,37 @@ class ScheduleGetter :
     }
     if (count_tuple) {
       CHECK_EQ(call_node->args.size(), 1U)
-          << "Only allow function with a single tuple input";
-    }
-
-    // Prepare the call_node->checked_type(). For the call node inputs, we ensure that the shape is
-    // Int32. Following code ensures the same for the output as well.
-    // TODO(@icemelon): Support recursive tuple
-    Type call_node_type = call_node->checked_type();
-    if (const auto* tt = call_node->checked_type().as<TensorTypeNode>()) {
-      call_node_type = TensorType(GetShape(tt->shape), tt->dtype);
-    } else if (const auto* tuple_t = call_node->checked_type().as<TupleTypeNode>()) {
-      std::vector<Type> new_fields;
-      for (auto field : tuple_t->fields) {
-        if (const auto* tt = field.as<TensorTypeNode>()) {
-          new_fields.push_back(TensorType(GetShape(tt->shape), tt->dtype));
-        } else {
-          new_fields.push_back(field);
-        }
-      }
-      call_node_type = TupleType(new_fields);
+        << "Only allow function with a single tuple input";
     }
 
     CHECK(call_node->op.as<OpNode>())
-        << "Primitive function only allows call into primitive ops";
+      << "Primitive function only allows call into primitive ops";
     Op op = Downcast<Op>(call_node->op);
+
     Array<te::Tensor> outputs;
+    OpImplementation impl;
     // Skip fcompute for device copy operators as it is not registered.
     if (op == device_copy_op_) {
       const auto* copy_input = inputs[0].operator->();
       outputs.push_back(te::TensorNode::make(copy_input->shape, copy_input->dtype,
                                          te::Operation(), 0));
     } else {
-      outputs = fcompute[op](call_node->attrs, inputs,
-                             call_node_type, target_);
+      LoweredOutput lowered_out = (*flower_call)(GetRef<Call>(call_node), inputs, target_);
+      outputs = lowered_out->outputs;
+      impl = lowered_out->implementation;
     }
 
     int op_pattern = fpattern[op];
     if (op_pattern >= kCommReduce) {
       CHECK(!master_op_.defined() || master_op_pattern_ < kCommReduce)
-          << "Two complicated op in a primitive function "
-          << " master=" << master_op_ << " current=" << op;
+        << "Two complicated op in a primitive function "
+        << " master=" << master_op_ << " current=" << op;
     }
     if (op_pattern >= master_op_pattern_) {
       master_op_ = op;
       master_attrs_ = call_node->attrs;
       master_op_pattern_ = op_pattern;
+      master_implementation_ = impl;
     }
     if (outputs.size() != 1) {
       const auto* tuple_type =
@@ -332,6 +324,7 @@ class ScheduleGetter :
   Op master_op_;
   Attrs master_attrs_;
   int master_op_pattern_{0};
+  OpImplementation master_implementation_;
   std::ostringstream readable_name_stream_;
   std::unordered_map<Expr, Array<te::Tensor>, ObjectHash, ObjectEqual> memo_;
   Array<te::Operation> scalars_;
@@ -677,8 +670,7 @@ class CompileEngineImpl : public CompileEngineNode {
    * \return Pair of schedule and cache.
    *  The funcs field in cache is not yet populated.
    */
-  std::pair<te::Schedule, CachedFunc> CreateSchedule(
-      const Function& source_func, const Target& target) {
+  CachedFunc CreateSchedule(const Function& source_func, const Target& target) {
     return ScheduleGetter(target).Create(source_func);
   }
 
@@ -713,9 +705,9 @@ class CompileEngineImpl : public CompileEngineNode {
     With<Target> target_scope(key->target);
 
     CHECK(!value->cached_func.defined());
-    auto spair = CreateSchedule(key->source_func, key->target);
+    auto cfunc = CreateSchedule(key->source_func, key->target);
     auto cache_node = make_object<CachedFuncNode>(
-        *(spair.second.operator->()));
+        *(cfunc.operator->()));
 
     // Skip lowering for device copy node.
     const Expr body = (key->source_func)->body;
@@ -735,11 +727,12 @@ class CompileEngineImpl : public CompileEngineNode {
     // lower the function
     if (const auto* f = runtime::Registry::Get("relay.backend.lower")) {
       cache_node->funcs = (*f)(
-          spair.first, all_args, cache_node->func_name, key->source_func);
+          cfunc->schedule, all_args, cache_node->func_name, key->source_func);
     } else {
       tvm::BuildConfig bcfg = BuildConfig::Create();
       std::unordered_map<te::Tensor, tir::Buffer> binds;
-      cache_node->funcs = tvm::lower(spair.first, all_args, cache_node->func_name, binds, bcfg);
+      cache_node->funcs = tvm::lower(cfunc->schedule, all_args, cache_node->func_name,
+                                     binds, bcfg);
     }
     value->cached_func = CachedFunc(cache_node);
     return value;
@@ -820,6 +813,11 @@ const CompileEngine& CompileEngine::Global() {
   return *inst;
 }
 
+TVM_REGISTER_GLOBAL("relay.backend._make_LoweredOutput")
+.set_body_typed([](tvm::Array<te::Tensor> outputs, OpImplementation impl) {
+  return LoweredOutput(outputs, impl);
+});
+
 TVM_REGISTER_GLOBAL("relay.backend._make_CCacheKey")
 .set_body_typed(CCacheKeyNode::make);
 
index 15ec2d6..2dbacf6 100644 (file)
@@ -30,6 +30,7 @@
 #include <tvm/relay/analysis.h>
 #include <tvm/relay/expr.h>
 #include <tvm/relay/transform.h>
+#include <tvm/relay/op_strategy.h>
 #include <string>
 #include <functional>
 
@@ -44,6 +45,28 @@ enum ShapeFuncParamState {
   kNeedBoth = 3,
 };
 
+struct LoweredOutputNode : public Object {
+  /*! \brief The outputs to the function */
+  tvm::Array<te::Tensor> outputs;
+  /*! \brief The implementation used to compute the output */
+  OpImplementation implementation;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("outputs", &outputs);
+    v->Visit("implementation", &implementation);
+  }
+
+  static constexpr const char* _type_key = "relay.LoweredOutput";
+  TVM_DECLARE_FINAL_OBJECT_INFO(LoweredOutputNode, Object);
+};
+
+class LoweredOutput : public ObjectRef {
+ public:
+  TVM_DLL LoweredOutput(tvm::Array<te::Tensor> outputs, OpImplementation impl);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(LoweredOutput, ObjectRef, LoweredOutputNode);
+};
+
 /*! \brief Node container to represent a cached function. */
 struct CachedFuncNode : public Object {
   /* \brief compiled target */
@@ -54,6 +77,8 @@ struct CachedFuncNode : public Object {
   tvm::Array<te::Tensor> inputs;
   /* \brief The outputs to the function */
   tvm::Array<te::Tensor> outputs;
+  /*! \brief The schedule to the function */
+  te::Schedule schedule;
   /*! \brief The lowered functions to support the function. */
   tvm::Array<tir::LoweredFunc> funcs;
   /*! \brief Parameter usage states in the shape function. */
@@ -64,6 +89,7 @@ struct CachedFuncNode : public Object {
     v->Visit("func_name", &func_name);
     v->Visit("inputs", &inputs);
     v->Visit("outputs", &outputs);
+    v->Visit("schedule", &schedule);
     v->Visit("funcs", &funcs);
     v->Visit("shape_func_param_states", &shape_func_param_states);
   }
diff --git a/src/relay/ir/op_strategy.cc b/src/relay/ir/op_strategy.cc
new file mode 100644 (file)
index 0000000..5ce6091
--- /dev/null
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/tvm/relay/ir/op_strategy.cc
+ * \brief The Relay operator Strategy and related data structure.
+ */
+
+#include <tvm/relay/op_strategy.h>
+
+namespace tvm {
+namespace relay {
+
+TVM_REGISTER_NODE_TYPE(OpImplementationNode);
+TVM_REGISTER_NODE_TYPE(OpSpecializationNode);
+TVM_REGISTER_NODE_TYPE(OpStrategyNode);
+
+Array<te::Tensor> OpImplementation::Compute(const Attrs& attrs,
+                                            const Array<te::Tensor>& inputs,
+                                            const Type& out_type) {
+  return (*this)->fcompute(attrs, inputs, out_type);
+}
+
+te::Schedule OpImplementation::Schedule(const Attrs& attrs,
+                                        const Array<te::Tensor> &outs,
+                                        const Target& target) {
+  return (*this)->fschedule(attrs, outs, target);
+}
+
+void OpSpecialization::AddImplementation(tvm::relay::FTVMCompute fcompute,
+                                         tvm::relay::FTVMSchedule fschedule,
+                                         std::string name,
+                                         int plevel) {
+  auto n = make_object<OpImplementationNode>();
+  n->fcompute = fcompute;
+  n->fschedule = fschedule;
+  n->name = std::move(name);
+  n->plevel = plevel;
+  (*this)->implementations.push_back(OpImplementation(n));
+}
+
+void OpStrategy::AddImplementation(FTVMCompute fcompute,
+                                   FTVMSchedule fschedule,
+                                   std::string name,
+                                   int plevel) {
+  auto curr_cond = te::SpecializedCondition::Current();
+  auto self = this->operator->();
+  Array<OpSpecialization> specializations = self->specializations;
+  OpSpecialization op_spec;
+  for (OpSpecialization op_spec : specializations) {
+    if (op_spec->condition == curr_cond) {
+      op_spec.AddImplementation(fcompute, fschedule, std::move(name), plevel);
+      return;
+    }
+  }
+  ObjectPtr<OpSpecializationNode> n = make_object<OpSpecializationNode>();
+  n->condition = curr_cond;
+  op_spec = OpSpecialization(n);
+  op_spec.AddImplementation(fcompute, fschedule, std::move(name), plevel);
+  self->specializations.push_back(op_spec);
+}
+
+TVM_REGISTER_GLOBAL("relay.op._OpImplementationCompute")
+.set_body([](TVMArgs args, TVMRetValue* rv) {
+    OpImplementation imp = args[0];
+    Attrs attrs = args[1];
+    Array<te::Tensor> inputs = args[2];
+    Type out_type = args[3];
+    *rv = imp.Compute(attrs, inputs, out_type);
+});
+
+TVM_REGISTER_GLOBAL("relay.op._OpImplementationSchedule")
+.set_body([](TVMArgs args, TVMRetValue* rv) {
+    OpImplementation imp = args[0];
+    Attrs attrs = args[1];
+    Array<te::Tensor> outs = args[2];
+    Target target = args[3];
+    *rv = imp.Schedule(attrs, outs, target);
+});
+
+TVM_REGISTER_GLOBAL("relay.op._make.OpStrategy")
+.set_body([](TVMArgs args, TVMRetValue* rv) {
+    ObjectPtr<OpStrategyNode> n = make_object<OpStrategyNode>();
+    *rv = OpStrategy(n);
+});
+
+TVM_REGISTER_GLOBAL("relay.op._OpStrategyAddImplementation")
+.set_body([](TVMArgs args, TVMRetValue* rv) {
+    OpStrategy strategy = args[0];
+    FTVMCompute compute = args[1];
+    FTVMSchedule schedule = args[2];
+    std::string name = args[3];
+    int plevel = args[4];
+    strategy.AddImplementation(compute, schedule, name, plevel);
+});
+
+}  // namespace relay
+}  // namespace tvm
index 6106b07..36f5923 100644 (file)
@@ -79,7 +79,7 @@ TVM_ADD_FILELINE)
 .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
 .set_attr<FTVMCompute>("FTVMCompute",
                        [](const Attrs& attrs, const Array<te::Tensor>& inputs,
-                          const Type& out_dtype, const Target& target) -> Array<te::Tensor> {
+                          const Type& out_dtype) -> Array<te::Tensor> {
                          return {topi::identity(inputs[0])};
                        });
 
@@ -105,7 +105,7 @@ TVM_ADD_FILELINE)
 .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
 .set_attr<FTVMCompute>("FTVMCompute",
                        [](const Attrs& attrs, const Array<te::Tensor>& inputs,
-                          const Type& out_dtype, const Target& target) -> Array<te::Tensor> {
+                          const Type& out_dtype) -> Array<te::Tensor> {
                          return {topi::identity(inputs[0])};
                        });
 
@@ -123,7 +123,7 @@ Mark the start of bitpacking.
                                ElemwiseArbitraryLayout)
 .set_attr<FTVMCompute>("FTVMCompute",
                        [](const Attrs& attrs, const Array<te::Tensor>& inputs,
-                          const Type& out_dtype, const Target& target) -> Array<te::Tensor> {
+                          const Type& out_dtype) -> Array<te::Tensor> {
                          return {topi::identity(inputs[0])};
                        });
 
@@ -140,7 +140,7 @@ Mark the end of bitpacking.
                                ElemwiseArbitraryLayout)
 .set_attr<FTVMCompute>("FTVMCompute",
                        [](const Attrs& attrs, const Array<te::Tensor>& inputs,
-                          const Type& out_dtype, const Target& target) -> Array<te::Tensor> {
+                          const Type& out_dtype) -> Array<te::Tensor> {
                          return {topi::identity(inputs[0])};
                        });
 
@@ -163,7 +163,7 @@ Mark a checkpoint for checkpointing memory optimization.
                                ElemwiseArbitraryLayout)
 .set_attr<FTVMCompute>("FTVMCompute",
                        [](const Attrs& attrs, const Array<te::Tensor>& inputs,
-                          const Type& out_dtype, const Target& target) -> Array<te::Tensor> {
+                          const Type& out_dtype) -> Array<te::Tensor> {
                          Array<te::Tensor> outputs;
                          for (size_t i = 0; i < inputs.size(); ++i) {
                            outputs.push_back(topi::identity(inputs[i]));
@@ -184,7 +184,7 @@ Beginning of a region that is handled by a given compiler.
                                ElemwiseArbitraryLayout)
 .set_attr<FTVMCompute>("FTVMCompute",
                        [](const Attrs& attrs, const Array<te::Tensor>& inputs,
-                          const Type& out_dtype, const Target& target) -> Array<te::Tensor> {
+                          const Type& out_dtype) -> Array<te::Tensor> {
                          return {topi::identity(inputs[0])};
                        });
 
@@ -209,7 +209,7 @@ End of a region that is handled by a given compiler.
                                ElemwiseArbitraryLayout)
 .set_attr<FTVMCompute>("FTVMCompute",
                        [](const Attrs& attrs, const Array<te::Tensor>& inputs,
-                          const Type& out_dtype, const Target& target) -> Array<te::Tensor> {
+                          const Type& out_dtype) -> Array<te::Tensor> {
                          return {topi::identity(inputs[0])};
                        });
 
index 14c0a01..a0f7fbf 100644 (file)
@@ -36,9 +36,8 @@ namespace relay {
 TVM_REGISTER_NODE_TYPE(DebugAttrs);
 
 Array<te::Tensor> DebugCompute(const Attrs& attrs,
-                           const Array<te::Tensor>& inputs,
-                           const Type& out_type,
-                           const Target& target) {
+                               const Array<te::Tensor>& inputs,
+                               const Type& out_type) {
   return Array<te::Tensor>{ topi::identity(inputs[0]) };
 }
 
index 076e3fc..d15099b 100644 (file)
@@ -83,7 +83,7 @@ RELAY_REGISTER_OP("memory.alloc_storage")
     .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
     .set_attr<FTVMCompute>("FTVMCompute",
                            [](const Attrs& attrs, const Array<te::Tensor>& inputs,
-                              const Type& out_dtype, const Target& target) -> Array<te::Tensor> {
+                              const Type& out_dtype) -> Array<te::Tensor> {
                              return {topi::identity(inputs[0])};
                            });
 
@@ -179,7 +179,7 @@ RELAY_REGISTER_OP("memory.alloc_tensor")
     .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
     .set_attr<FTVMCompute>("FTVMCompute",
                            [](const Attrs& attrs, const Array<te::Tensor>& inputs,
-                              const Type& out_dtype, const Target& target) -> Array<te::Tensor> {
+                              const Type& out_dtype) -> Array<te::Tensor> {
                              return {topi::identity(inputs[0])};
                            });
 
@@ -228,7 +228,7 @@ RELAY_REGISTER_OP("memory.invoke_tvm_op")
     .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
     .set_attr<FTVMCompute>("FTVMCompute",
                            [](const Attrs& attrs, const Array<te::Tensor>& inputs,
-                              const Type& out_dtype, const Target& target) -> Array<te::Tensor> {
+                              const Type& out_dtype) -> Array<te::Tensor> {
                              return {topi::identity(inputs[0])};
                            });
 
@@ -252,7 +252,7 @@ RELAY_REGISTER_OP("memory.kill")
     .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
     .set_attr<FTVMCompute>("FTVMCompute",
                            [](const Attrs& attrs, const Array<te::Tensor>& inputs,
-                              const Type& out_dtype, const Target& target) -> Array<te::Tensor> {
+                              const Type& out_dtype) -> Array<te::Tensor> {
                              return {topi::identity(inputs[0])};
                            });
 
@@ -340,7 +340,7 @@ RELAY_REGISTER_OP("memory.shape_func")
     .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
     .set_attr<FTVMCompute>("FTVMCompute",
                            [](const Attrs& attrs, const Array<te::Tensor>& inputs,
-                              const Type& out_dtype, const Target& target) -> Array<te::Tensor> {
+                              const Type& out_dtype) -> Array<te::Tensor> {
                              return {topi::identity(inputs[0])};
                            });
 
index 6977ac9..cd9b5dd 100644 (file)
@@ -735,58 +735,6 @@ weight transformation in advance.
 .add_type_rel("Conv2DWinogradWeightTransform", Conv2DWinogradWeightTransformRel);
 
 
-// Positional relay function to create conv2d winograd nnpack operator
-// used by frontend FFI.
-Expr MakeConv2DWinogradNNPACK(Expr data,
-                              Expr weight,
-                              Array<IndexExpr> strides,
-                              Array<IndexExpr> padding,
-                              Array<IndexExpr> dilation,
-                              int groups,
-                              IndexExpr channels,
-                              Array<IndexExpr> kernel_size,
-                              std::string data_layout,
-                              std::string kernel_layout,
-                              std::string out_layout,
-                              DataType out_dtype) {
-  auto attrs = make_object<Conv2DAttrs>();
-  attrs->strides = std::move(strides);
-  attrs->padding = std::move(padding);
-  attrs->dilation = std::move(dilation);
-  attrs->groups = groups;
-  attrs->channels = channels;
-  attrs->kernel_size = std::move(kernel_size);
-  attrs->data_layout = std::move(data_layout);
-  attrs->kernel_layout = std::move(kernel_layout);
-  attrs->out_layout = std::move(out_layout);
-  attrs->out_dtype = std::move(out_dtype);
-  static const Op& op = Op::Get("nn.contrib_conv2d_winograd_nnpack_without_weight_transform");
-  return CallNode::make(op, {data, weight}, Attrs(attrs), {});
-}
-
-TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv2d_winograd_nnpack_without_weight_transform")
-.set_body_typed(MakeConv2DWinogradNNPACK);
-
-RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_nnpack_without_weight_transform")
-.describe(R"code(Compute conv2d with winograd nnpack. Only supports NCHW layout.
-              This operator assumes the weight tensor is already pre-transformed by
-              nn.contrib_conv2d_winograd_nnpack_weight_transform.
-
-- **data**: Input is 4D array of shape  (batch_size, in_channels, height, width)
-- **weight**: Any shape
-            We do not check the shape for this input tensor. Since different backend
-            has different layout strategy.
-
-- **out**:  Output is 4D array of shape (batch_size, channels, out_height, out_width)
-)code" TVM_ADD_FILELINE)
-.set_attrs_type<Conv2DAttrs>()
-.set_num_inputs(2)
-.add_argument("data", "Tensor", "The input tensor.")
-.add_argument("weight", "Tensor", "The weight tensor.")
-.set_support_level(10)
-.add_type_rel("Conv2DWinogradNNPACKRel", Conv2DWinogradRel<Conv2DAttrs>)
-.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ConvInferCorrectLayout<Conv2DAttrs>);
-
 // relay.nn.contrib_conv2d_winograd_nnpack_weight_transform
 TVM_REGISTER_NODE_TYPE(Conv2DWinogradNNPACKWeightTransformAttrs);
 
@@ -850,55 +798,6 @@ weight transformation in advance.
 
 // Positional relay function to create conv2d NCHWc operator
 // used by frontend FFI.
-Expr MakeConv2DNCHWcInt8(Expr data,
-                         Expr kernel,
-                         Array<IndexExpr> strides,
-                         Array<IndexExpr> padding,
-                         Array<IndexExpr> dilation,
-                         int groups,
-                         IndexExpr channels,
-                         Array<IndexExpr> kernel_size,
-                         std::string data_layout,
-                         std::string kernel_layout,
-                         std::string out_layout,
-                         DataType out_dtype) {
-  auto attrs = make_object<Conv2DAttrs>();
-  attrs->strides = std::move(strides);
-  attrs->padding = std::move(padding);
-  attrs->dilation = std::move(dilation);
-  attrs->groups = groups;
-  attrs->channels = channels;
-  attrs->kernel_size = std::move(kernel_size);
-  attrs->data_layout = std::move(data_layout);
-  attrs->kernel_layout = std::move(kernel_layout);
-  attrs->out_layout = std::move(out_layout);
-  attrs->out_dtype = std::move(out_dtype);
-  static const Op& op = Op::Get("nn.contrib_conv2d_NCHWc_int8");
-  return CallNode::make(op, {data, kernel}, Attrs(attrs), {});
-}
-
-TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv2d_NCHWc_int8")
-.set_body_typed(MakeConv2DNCHWcInt8);
-
-
-RELAY_REGISTER_OP("nn.contrib_conv2d_NCHWc_int8")
-.describe(R"code(Compute conv2d with NCHWc data layout with int8 inputs.
-- **data**: Input is 5D packed tensor.
-- **weight**: 7D packed tensor.
-
-- **out**:  Output is 5D packed tensor
-)code" TVM_ADD_FILELINE)
-.set_attrs_type<Conv2DAttrs>()
-.set_num_inputs(2)
-.add_argument("data", "Tensor", "The input tensor.")
-.add_argument("weight", "Tensor", "The weight tensor.")
-.set_support_level(10)
-.add_type_rel("Conv2DNCHWcInt8", Conv2DWinogradRel<Conv2DAttrs>)
-.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
-        ConvInferCorrectLayout<Conv2DAttrs>);
-
-// Positional relay function to create conv2d NCHWc operator
-// used by frontend FFI.
 Expr MakeConv2DNCHWc(Expr data,
                      Expr kernel,
                      Array<IndexExpr> strides,
index 4061909..9ee84a0 100644 (file)
@@ -153,6 +153,16 @@ bool Conv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
       << " But got " << out_layout;
 
   Array<IndexExpr> dshape_nchw = trans_in_layout.ForwardShape(data->shape);
+  bool is_depthwise = false;
+  if (param->groups > 1) {
+    CHECK(weight && weight->shape.defined()) <<
+        "Weight shape must be specified when groups is greater than 1.";
+    Array<IndexExpr> wshape_oihw = trans_kernel_layout.ForwardShape(weight->shape);
+    if (tvm::tir::Equal(param->groups, dshape_nchw[1]) &&
+        tvm::tir::Equal(param->groups, wshape_oihw[0])) {
+      is_depthwise = true;
+    }
+  }
 
   IndexExpr channels, dilated_ksize_y, dilated_ksize_x;
   // infer weight if the kernel_size and channels are defined
@@ -161,9 +171,9 @@ bool Conv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
     CHECK_EQ(param->dilation.size(), 2);
     Array<IndexExpr> wshape;
 
-    if (tvm::tir::Equal(param->channels, param->groups) && !tvm::tir::Equal(param->channels, 1)) {
+    if (is_depthwise) {
       // infer weight's shape for depthwise convolution
-      wshape = {{dshape_nchw[1], indexdiv(param->groups, dshape_nchw[1]), param->kernel_size[0],
+      wshape = {{dshape_nchw[1], indexdiv(param->channels, dshape_nchw[1]), param->kernel_size[0],
                  param->kernel_size[1]}};
     } else {
       wshape = {{param->channels, indexdiv(dshape_nchw[1], param->groups), param->kernel_size[0],
index ee4471a..10fd4d9 100644 (file)
@@ -93,8 +93,9 @@ RELAY_REGISTER_OP("nn.bias_add")
 .add_argument("bias", "1D Tensor", "Bias.")
 .set_support_level(1)
 .add_type_rel("BiasAdd", BiasAddRel)
-.set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs, const Array<te::Tensor>& inputs,
-                                        const Type& out_type, const Target& target) {
+.set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs,
+                                         const Array<te::Tensor>& inputs,
+                                         const Type& out_type) {
     const auto* param = attrs.as<BiasAddAttrs>();
     return tvm::Array<tvm::te::Tensor>{topi::nn::bias_add(inputs[0], inputs[1], param->axis)};
 });
@@ -234,8 +235,7 @@ RELAY_REGISTER_OP("nn.leaky_relu")
 .set_attr<FTVMCompute>(
   "FTVMCompute", [](const Attrs& attrs,
                     const Array<te::Tensor>& inputs,
-                    const Type& out_type,
-                    const Target& target) {
+                    const Type& out_type) {
     const auto* param = attrs.as<LeakyReluAttrs>();
     return Array<te::Tensor>{ topi::leaky_relu(inputs[0], param->alpha) };
 });
@@ -315,8 +315,7 @@ where :math:`*` is an channelwise multiplication for each sample in the batch.
 .set_attr<FTVMCompute>(
   "FTVMCompute", [](const Attrs& attrs,
                     const Array<te::Tensor>& inputs,
-                    const Type& out_type,
-                    const Target& target) {
+                    const Type& out_type) {
     const auto* param = attrs.as<PReluAttrs>();
     return Array<te::Tensor>{ topi::prelu(inputs[0], inputs[1], param->axis)};
 });
@@ -351,8 +350,7 @@ RELAY_REGISTER_OP("nn.softmax")
 .add_type_rel("Identity", IdentityRel)
 .set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs,
                                          const Array<te::Tensor>& inputs,
-                                         const Type& out_type,
-                                         const Target& target) {
+                                         const Type& out_type) {
   const auto* param = attrs.as<SoftmaxAttrs>();
   CHECK(param != nullptr);
   return Array<te::Tensor>{ topi::nn::softmax(inputs[0], param->axis) };
@@ -385,8 +383,7 @@ RELAY_REGISTER_OP("nn.log_softmax")
 .add_type_rel("Identity", IdentityRel)
 .set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs,
                                          const Array<te::Tensor>& inputs,
-                                         const Type& out_type,
-                                         const Target& target) {
+                                         const Type& out_type) {
   const auto* param = attrs.as<SoftmaxAttrs>();
   CHECK(param != nullptr);
   CHECK(param->axis == -1 || param->axis == static_cast<int32_t>(inputs[0].ndim()) - 1)
@@ -462,8 +459,7 @@ Example::
 .set_attr<FTVMCompute>(
   "FTVMCompute", [](const Attrs& attrs,
                     const Array<te::Tensor>& inputs,
-                    const Type& out_type,
-                    const Target& target) {
+                    const Type& out_type) {
     return Array<te::Tensor>{ topi::nn::flatten(inputs[0]) };
 });
 
@@ -489,8 +485,7 @@ RELAY_REGISTER_OP("nn.relu")
 .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
 .set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs,
                                          const Array<te::Tensor>& inputs,
-                                         const Type& out_type,
-                                         const Target& target) {
+                                         const Type& out_type) {
   return Array<te::Tensor>{ topi::relu(inputs[0], 0.0f) };
 });
 
index 94602ec..84a4940 100644 (file)
@@ -161,9 +161,8 @@ bool PadRel(const Array<Type>& types,
 }
 
 Array<te::Tensor> PadCompute(const Attrs& attrs,
-                         const Array<te::Tensor>& inputs,
-                         const Type& out_type,
-                         const Target& target) {
+                             const Array<te::Tensor>& inputs,
+                             const Type& out_type) {
   const auto* param = attrs.as<PadAttrs>();
   CHECK(param != nullptr);
 
index 6775b09..77baae5 100644 (file)
@@ -164,9 +164,8 @@ bool Pool2DRel(const Array<Type>& types,
 
 template<typename AttrType, topi::nn::PoolType mode>
 Array<te::Tensor> Pool2DCompute(const Attrs& attrs,
-                            const Array<te::Tensor>& inputs,
-                            const Type& out_type,
-                            const Target& target) {
+                                const Array<te::Tensor>& inputs,
+                                const Type& out_type) {
   static const Layout kNCHW("NCHW");
   const auto* param = attrs.as<AttrType>();
   CHECK(param != nullptr);
@@ -331,9 +330,8 @@ bool GlobalPool2DRel(const Array<Type>& types,
 
 template<topi::nn::PoolType mode>
 Array<te::Tensor> GlobalPool2DCompute(const Attrs& attrs,
-                                  const Array<te::Tensor>& inputs,
-                                  const Type& out_type,
-                                  const Target& target) {
+                                      const Array<te::Tensor>& inputs,
+                                      const Type& out_type) {
   static const Layout kNCHW("NCHW");
   const auto* param = attrs.as<GlobalPool2DAttrs>();
   CHECK(param != nullptr);
@@ -465,9 +463,8 @@ bool AdaptivePool2DRel(const Array<Type>& types,
 
 template<topi::nn::PoolType mode>
 Array<te::Tensor> AdaptivePool2DCompute(const Attrs& attrs,
-                                    const Array<te::Tensor>& inputs,
-                                    const Type& out_type,
-                                    const Target& target) {
+                                        const Array<te::Tensor>& inputs,
+                                        const Type& out_type) {
   static const Layout kNCHW("NCHW");
   const auto* param = attrs.as<AdaptivePool2DAttrs>();
   CHECK(param != nullptr);
@@ -593,8 +590,9 @@ bool Pool2DGradRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
 }
 
 template <typename AttrType, topi::nn::PoolType mode>
-Array<te::Tensor> Pool2DGradCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
-                                const Type& out_type, const Target& target) {
+Array<te::Tensor> Pool2DGradCompute(const Attrs& attrs,
+                                    const Array<te::Tensor>& inputs,
+                                    const Type& out_type) {
   static const Layout kNCHW("NCHW");
   const auto* param = attrs.as<AttrType>();
   CHECK(param != nullptr);
@@ -793,9 +791,8 @@ bool Pool1DRel(const Array<Type>& types,
 
 template<typename AttrType, topi::nn::PoolType mode>
 Array<te::Tensor> Pool1DCompute(const Attrs& attrs,
-                            const Array<te::Tensor>& inputs,
-                            const Type& out_type,
-                            const Target& target) {
+                                const Array<te::Tensor>& inputs,
+                                const Type& out_type) {
   static const Layout kNCW("NCW");
   const auto* param = attrs.as<AttrType>();
   CHECK(param != nullptr);
@@ -985,9 +982,8 @@ bool Pool3DRel(const Array<Type>& types,
 
 template<typename AttrType, topi::nn::PoolType mode>
 Array<te::Tensor> Pool3DCompute(const Attrs& attrs,
-                            const Array<te::Tensor>& inputs,
-                            const Type& out_type,
-                            const Target& target) {
+                                const Array<te::Tensor>& inputs,
+                                const Type& out_type) {
   static const Layout kNCDHW("NCDHW");
   const auto* param = attrs.as<AttrType>();
   CHECK(param != nullptr);
index d1b915c..58221ae 100644 (file)
@@ -32,9 +32,8 @@ namespace relay {
 
 #define RELAY_BINARY_COMPUTE(FTOPI)                        \
   [] (const Attrs& attrs,                                  \
-      const Array<te::Tensor>& inputs,                         \
-      const Type& out_type,                                \
-      const Target& target) -> Array<te::Tensor> {             \
+      const Array<te::Tensor>& inputs,                     \
+      const Type& out_type) -> Array<te::Tensor> {         \
     CHECK_EQ(inputs.size(), 2U);                           \
     return {FTOPI(inputs[0], inputs[1])};                  \
   }                                                        \
index acbde0d..5e0795e 100644 (file)
@@ -176,7 +176,6 @@ template<typename F>
 Array<te::Tensor> ReduceCompute(const Attrs& attrs,
                             const Array<te::Tensor>& inputs,
                             const Type& out_type,
-                            const Target& target,
                             F f) {
   const ReduceAttrs* param = attrs.as<ReduceAttrs>();
   CHECK(param != nullptr);
@@ -321,10 +320,9 @@ bool ReduceRel(const Array<Type>& types,
 
 
 Array<te::Tensor> ArgMaxCompute(const Attrs& attrs,
-                            const Array<te::Tensor>& inputs,
-                            const Type& out_type,
-                            const Target& target) {
-  return ReduceCompute(attrs, inputs, out_type, target, topi::argmax);
+                                const Array<te::Tensor>& inputs,
+                                const Type& out_type) {
+  return ReduceCompute(attrs, inputs, out_type, topi::argmax);
 }
 
 
@@ -341,10 +339,9 @@ values over a given axis.
 
 
 Array<te::Tensor> ArgMinCompute(const Attrs& attrs,
-                            const Array<te::Tensor>& inputs,
-                            const Type& out_type,
-                            const Target& target) {
-  return ReduceCompute(attrs, inputs, out_type, target, topi::argmin);
+                                const Array<te::Tensor>& inputs,
+                                const Type& out_type) {
+  return ReduceCompute(attrs, inputs, out_type, topi::argmin);
 }
 
 RELAY_REGISTER_REDUCE_OP("argmin")
@@ -359,10 +356,9 @@ values over a given axis.
 .set_attr<TOpPattern>("TOpPattern", kCommReduce);
 
 Array<te::Tensor> SumCompute(const Attrs& attrs,
-                         const Array<te::Tensor>& inputs,
-                         const Type& out_type,
-                         const Target& target) {
-  return ReduceCompute(attrs, inputs, out_type, target, topi::sum);
+                             const Array<te::Tensor>& inputs,
+                             const Type& out_type) {
+  return ReduceCompute(attrs, inputs, out_type, topi::sum);
 }
 
 
@@ -393,10 +389,9 @@ Example::
 
 
 Array<te::Tensor> AllCompute(const Attrs& attrs,
-                         const Array<te::Tensor>& inputs,
-                         const Type& out_type,
-                         const Target& target) {
-  return ReduceCompute(attrs, inputs, out_type, target, topi::all);
+                             const Array<te::Tensor>& inputs,
+                             const Type& out_type) {
+  return ReduceCompute(attrs, inputs, out_type, topi::all);
 }
 
 
@@ -430,10 +425,9 @@ Example::
 
 
 Array<te::Tensor> AnyCompute(const Attrs& attrs,
-                         const Array<te::Tensor>& inputs,
-                         const Type& out_type,
-                         const Target& target) {
-  return ReduceCompute(attrs, inputs, out_type, target, topi::any);
+                             const Array<te::Tensor>& inputs,
+                             const Type& out_type) {
+  return ReduceCompute(attrs, inputs, out_type, topi::any);
 }
 
 
@@ -467,10 +461,9 @@ Example::
 
 
 Array<te::Tensor> MaxCompute(const Attrs& attrs,
-                         const Array<te::Tensor>& inputs,
-                         const Type& out_type,
-                         const Target& target) {
-  return ReduceCompute(attrs, inputs, out_type, target, topi::max);
+                             const Array<te::Tensor>& inputs,
+                             const Type& out_type) {
+  return ReduceCompute(attrs, inputs, out_type, topi::max);
 }
 
 RELAY_REGISTER_REDUCE_OP("max")
@@ -485,10 +478,9 @@ RELAY_REGISTER_REDUCE_OP("max")
 
 
 Array<te::Tensor> MinCompute(const Attrs& attrs,
-                         const Array<te::Tensor>& inputs,
-                         const Type& out_type,
-                         const Target& target) {
-  return ReduceCompute(attrs, inputs, out_type, target, topi::min);
+                             const Array<te::Tensor>& inputs,
+                             const Type& out_type) {
+  return ReduceCompute(attrs, inputs, out_type, topi::min);
 }
 
 
@@ -504,10 +496,9 @@ RELAY_REGISTER_REDUCE_OP("min")
 
 
 Array<te::Tensor> ProdCompute(const Attrs& attrs,
-                          const Array<te::Tensor>& inputs,
-                          const Type& out_type,
-                          const Target& target) {
-  return ReduceCompute(attrs, inputs, out_type, target, topi::prod);
+                              const Array<te::Tensor>& inputs,
+                              const Type& out_type) {
+  return ReduceCompute(attrs, inputs, out_type, topi::prod);
 }
 
 RELAY_REGISTER_REDUCE_OP("prod")
@@ -534,9 +525,8 @@ Example::
 
 
 Array<te::Tensor> MeanCompute(const Attrs& attrs,
-                          const Array<te::Tensor>& inputs,
-                          const Type& out_type,
-                          const Target& target) {
+                               const Array<te::Tensor>& inputs,
+                               const Type& out_type) {
   IndexExpr count = tir::make_const(inputs[0]->dtype, 1);
   const ReduceAttrs* param = attrs.as<ReduceAttrs>();
   CHECK(param != nullptr);
@@ -546,7 +536,7 @@ Array<te::Tensor> MeanCompute(const Attrs& attrs,
                                  param->exclude)) {
     count *= inputs[0]->shape[i];
   }
-  auto res = ReduceCompute(attrs, inputs, out_type, target, topi::sum);
+  auto res = ReduceCompute(attrs, inputs, out_type, topi::sum);
   return {topi::divide(res[0], count)};
 }
 
@@ -599,9 +589,8 @@ bool VarianceRel(const Array<Type>& types,
 }
 
 Array<te::Tensor> VarianceCompute(const Attrs& attrs,
-                              const Array<te::Tensor>& inputs,
-                              const Type& out_type,
-                              const Target& target) {
+                                  const Array<te::Tensor>& inputs,
+                                  const Type& out_type) {
   IndexExpr count = tir::make_const(inputs[0]->dtype, 1);
   const ReduceAttrs* param = attrs.as<ReduceAttrs>();
   CHECK(param != nullptr);
@@ -615,7 +604,7 @@ Array<te::Tensor> VarianceCompute(const Attrs& attrs,
   }
   std::vector<Integer> expand_shape;
   auto sq_diff = topi::power(topi::subtract(data, mean), 2);
-  auto var = topi::divide(ReduceCompute(attrs, {sq_diff}, out_type, target, topi::sum)[0], count);
+  auto var = topi::divide(ReduceCompute(attrs, {sq_diff}, out_type, topi::sum)[0], count);
 
   return {var};
 }
index 969912f..fa96d7f 100644 (file)
@@ -66,9 +66,8 @@ bool CastRel(const Array<Type>& types,
 }
 
 Array<te::Tensor> CastCompute(const Attrs& attrs,
-                          const Array<te::Tensor>& inputs,
-                          const Type& out_type,
-                          const Target& target) {
+                              const Array<te::Tensor>& inputs,
+                              const Type& out_type) {
   const CastAttrs *param = attrs.as<CastAttrs>();
   CHECK(param != nullptr);
   DataType dtype = param->dtype;
@@ -126,9 +125,8 @@ bool CastLikeRel(const Array<Type>& types,
 
 
 Array<te::Tensor> CastLikeCompute(const Attrs& attrs,
-                              const Array<te::Tensor>& inputs,
-                              const Type& out_type,
-                              const Target& target) {
+                                  const Array<te::Tensor>& inputs,
+                                  const Type& out_type) {
   return { topi::cast(inputs[0], inputs[1]->dtype) };
 }
 
@@ -156,8 +154,9 @@ RELAY_REGISTER_OP("cast_like")
 .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout);
 
 
-Array<te::Tensor> ReinterpretCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
-                                 const Type& out_type, const Target& target) {
+Array<te::Tensor> ReinterpretCompute(const Attrs& attrs,
+                                     const Array<te::Tensor>& inputs,
+                                     const Type& out_type) {
   const CastAttrs* param = attrs.as<CastAttrs>();
   CHECK(param != nullptr);
   DataType dtype = param->dtype;
@@ -231,9 +230,8 @@ bool ExpandDimsRel(const Array<Type>& types,
 }
 
 Array<te::Tensor> ExpandDimsCompute(const Attrs& attrs,
-                                const Array<te::Tensor>& inputs,
-                                const Type& out_type,
-                                const Target& target) {
+                                    const Array<te::Tensor>& inputs,
+                                    const Type& out_type) {
   const ExpandDimsAttrs *param = attrs.as<ExpandDimsAttrs>();
   CHECK(param != nullptr);
   return { topi::expand_dims(inputs[0], param->axis, param->num_newaxis) };
@@ -270,9 +268,8 @@ RELAY_REGISTER_OP("expand_dims")
 TVM_REGISTER_NODE_TYPE(ConcatenateAttrs);
 
 Array<te::Tensor> ConcatenateCompute(const Attrs& attrs,
-                          const Array<te::Tensor>& inputs,
-                          const Type& out_type,
-                          const Target& target) {
+                                     const Array<te::Tensor>& inputs,
+                                     const Type& out_type) {
   const ConcatenateAttrs *param = attrs.as<ConcatenateAttrs>();
   CHECK(param != nullptr);
   return { topi::concatenate(inputs, param->axis) };
@@ -413,9 +410,8 @@ bool StackRel(const Array<Type>& types,
 }
 
 Array<te::Tensor> StackCompute(const Attrs& attrs,
-                           const Array<te::Tensor>& inputs,
-                           const Type& out_type,
-                           const Target& target) {
+                               const Array<te::Tensor>& inputs,
+                               const Type& out_type) {
   const StackAttrs *param = attrs.as<StackAttrs>();
   CHECK(param != nullptr);
   return { topi::stack(inputs, param->axis) };
@@ -505,9 +501,8 @@ bool TransposeRel(const Array<Type>& types,
 }
 
 Array<te::Tensor> TransposeCompute(const Attrs& attrs,
-                               const Array<te::Tensor>& inputs,
-                               const Type& out_type,
-                               const Target& target) {
+                                   const Array<te::Tensor>& inputs,
+                                   const Type& out_type) {
   const auto* param = attrs.as<TransposeAttrs>();
   CHECK(param != nullptr);
   return Array<te::Tensor>{ topi::transpose(inputs[0], param->axes) };
@@ -688,9 +683,8 @@ bool ReshapeRel(const Array<Type>& types,
 }
 
 Array<te::Tensor> ReshapeCompute(const Attrs& attrs,
-                             const Array<te::Tensor>& inputs,
-                             const Type& out_type,
-                             const Target& target) {
+                                 const Array<te::Tensor>& inputs,
+                                 const Type& out_type) {
   const auto* out_ttype = out_type.as<TensorTypeNode>();
   CHECK(out_ttype != nullptr);
   Array<IndexExpr> newshape;
@@ -923,9 +917,8 @@ bool TakeRel(const Array<Type>& types,
 }
 
 Array<te::Tensor> TakeCompute(const Attrs& attrs,
-                          const Array<te::Tensor>& inputs,
-                          const Type& out_type,
-                          const Target& target) {
+                              const Array<te::Tensor>& inputs,
+                              const Type& out_type) {
   const auto* param = attrs.as<TakeAttrs>();
   CHECK(param != nullptr);
   if (!param->axis.defined()) {
@@ -1010,9 +1003,8 @@ bool FullRel(const Array<Type>& types,
 }
 
 Array<te::Tensor> FullCompute(const Attrs& attrs,
-                          const Array<te::Tensor>& inputs,
-                          const Type& out_type,
-                          const Target& target) {
+                              const Array<te::Tensor>& inputs,
+                              const Type& out_type) {
   const auto* out_ttype = out_type.as<TensorTypeNode>();
   return { topi::full(out_ttype->shape, out_ttype->dtype, inputs[0]()) };
 }
@@ -1118,9 +1110,8 @@ bool FullLikeRel(const Array<Type>& types,
 }
 
 Array<te::Tensor> FullLikeCompute(const Attrs& attrs,
-                              const Array<te::Tensor>& inputs,
-                              const Type& out_type,
-                              const Target& target) {
+                                  const Array<te::Tensor>& inputs,
+                                  const Type& out_type) {
   return { topi::full_like(inputs[0], inputs[1]()) };
 }
 
@@ -1230,9 +1221,8 @@ inline te::Tensor DynamicArange(const te::Tensor& start,
 }
 
 Array<te::Tensor> ArangeCompute(const Attrs& attrs,
-                                 const Array<te::Tensor>& inputs,
-                                 const Type& out_type,
-                                 const Target& target) {
+                                const Array<te::Tensor>& inputs,
+                                const Type& out_type) {
   const ArangeAttrs* param = attrs.as<ArangeAttrs>();
   te::Tensor start = inputs[0];
   te::Tensor stop =  inputs[1];
@@ -1325,9 +1315,8 @@ bool RepeatRel(const Array<Type>& types,
 }
 
 Array<te::Tensor> RepeatCompute(const Attrs& attrs,
-                            const Array<te::Tensor>& inputs,
-                            const Type& out_type,
-                            const Target& target) {
+                                const Array<te::Tensor>& inputs,
+                                const Type& out_type) {
   const RepeatAttrs *param = attrs.as<RepeatAttrs>();
   CHECK(param != nullptr);
   return { topi::repeat(inputs[0], param->repeats, param->axis) };
@@ -1436,9 +1425,8 @@ bool TileRel(const Array<Type>& types,
 }
 
 Array<te::Tensor> TileCompute(const Attrs& attrs,
-                          const Array<te::Tensor>& inputs,
-                          const Type& out_type,
-                          const Target& target) {
+                              const Array<te::Tensor>& inputs,
+                              const Type& out_type) {
   const TileAttrs *param = attrs.as<TileAttrs>();
   CHECK(param != nullptr);
   return { topi::tile(inputs[0], param->reps) };
@@ -1497,9 +1485,8 @@ bool ReverseRel(const Array<Type>& types,
 }
 
 Array<te::Tensor> ReverseCompute(const Attrs& attrs,
-                             const Array<te::Tensor>& inputs,
-                             const Type& out_type,
-                             const Target& target) {
+                                 const Array<te::Tensor>& inputs,
+                                 const Type& out_type) {
   const ReverseAttrs *param = attrs.as<ReverseAttrs>();
   CHECK(param != nullptr);
   return { topi::flip(inputs[0], param->axis) };
@@ -1571,9 +1558,8 @@ Expr MakeWhere(const Expr& condition, const Expr& x, const Expr& y) {
 }
 
 Array<te::Tensor> WhereCompute(const Attrs& attrs,
-                           const Array<te::Tensor>& inputs,
-                           const Type& out_type,
-                           const Target& target) {
+                               const Array<te::Tensor>& inputs,
+                               const Type& out_type) {
   return { topi::where(inputs[0], inputs[1], inputs[2]) };
 }
 
@@ -1688,9 +1674,8 @@ bool SqueezeRel(const Array<Type>& types,
 }
 
 Array<te::Tensor> SqueezeCompute(const Attrs& attrs,
-                             const Array<te::Tensor>& inputs,
-                             const Type& out_type,
-                             const Target& target) {
+                                 const Array<te::Tensor>& inputs,
+                                 const Type& out_type) {
   const SqueezeAttrs *param = attrs.as<SqueezeAttrs>();
   CHECK(param != nullptr);
   return { topi::squeeze(inputs[0], param->axis) };
@@ -1729,9 +1714,8 @@ Expr MakeCollapseSumLike(Expr data,
 }
 
 Array<te::Tensor> CollapseSumLikeCompute(const Attrs& attrs,
-                                     const Array<te::Tensor>& inputs,
-                                     const Type& out_type,
-                                     const Target& target) {
+                                         const Array<te::Tensor>& inputs,
+                                         const Type& out_type) {
   const auto* out_ttype = out_type.as<TensorTypeNode>();
   CHECK(out_ttype != nullptr);
   return { topi::collapse_sum(inputs[0], out_ttype->shape) };
@@ -1774,9 +1758,8 @@ Expr MakeBroadCastTo(Expr data, Array<IndexExpr> shape) {
 }
 
 Array<te::Tensor> BroadCastToCompute(const Attrs& attrs,
-                                 const Array<te::Tensor>& inputs,
-                                 const Type& out_type,
-                                 const Target& target) {
+                                     const Array<te::Tensor>& inputs,
+                                     const Type& out_type) {
   auto ioattrs = attrs.as<InitOpAttrs>();
   CHECK(ioattrs != nullptr);
   return { topi::broadcast_to(inputs[0], ioattrs->shape) };
@@ -1812,9 +1795,8 @@ Expr MakeBroadCastToLike(Expr data,
 }
 
 Array<te::Tensor> BroadCastToLikeCompute(const Attrs& attrs,
-                                     const Array<te::Tensor>& inputs,
-                                     const Type& out_type,
-                                     const Target& target) {
+                                         const Array<te::Tensor>& inputs,
+                                         const Type& out_type) {
   const auto* out_ttype = out_type.as<TensorTypeNode>();
   CHECK(out_ttype != nullptr);
   return { topi::broadcast_to(inputs[0], out_ttype->shape) };
@@ -2019,9 +2001,8 @@ Expr MakeStridedSlice(Expr data,
 }
 
 Array<te::Tensor> StridedSliceCompute(const Attrs& attrs,
-                                  const Array<te::Tensor>& inputs,
-                                  const Type& out_type,
-                                  const Target& target) {
+                                      const Array<te::Tensor>& inputs,
+                                      const Type& out_type) {
   const StridedSliceAttrs *param = attrs.as<StridedSliceAttrs>();
   CHECK(param != nullptr);
   return Array<te::Tensor>{
@@ -2176,9 +2157,8 @@ bool SplitRel(const Array<Type>& types,
 }
 
 Array<te::Tensor> SplitCompute(const Attrs& attrs,
-                           const Array<te::Tensor>& inputs,
-                           const Type& out_type,
-                           const Target& target) {
+                               const Array<te::Tensor>& inputs,
+                               const Type& out_type) {
   const auto param = attrs.as<SplitAttrs>();
   CHECK(param != nullptr);
 
@@ -2305,9 +2285,8 @@ Expr MakeSliceLike(Expr data,
 }
 
 Array<te::Tensor> SliceLikeCompute(const Attrs& attrs,
-                               const Array<te::Tensor>& inputs,
-                               const Type& out_type,
-                               const Target& target) {
+                                   const Array<te::Tensor>& inputs,
+                                   const Type& out_type) {
   const auto* param = attrs.as<SliceLikeAttrs>();
   CHECK(param != nullptr);
   Array<IndexExpr> src_shape = inputs[0]->shape;
@@ -2371,9 +2350,8 @@ RELAY_REGISTER_OP("slice_like")
 TVM_REGISTER_NODE_TYPE(LayoutTransformAttrs);
 
 Array<te::Tensor> LayoutTransformCompute(const Attrs& attrs,
-                                     const Array<te::Tensor>& inputs,
-                                     const Type& out_type,
-                                     const Target& target) {
+                                         const Array<te::Tensor>& inputs,
+                                         const Type& out_type) {
   const auto* param = attrs.as<LayoutTransformAttrs>();
   CHECK(param != nullptr);
   return Array<te::Tensor>{
@@ -2504,9 +2482,8 @@ bool GatherNDRel(const Array<Type>& types,
 }
 
 Array<te::Tensor> GatherNDCompute(const Attrs& attrs,
-                              const Array<te::Tensor>& inputs,
-                              const Type& out_type,
-                              const Target& target) {
+                                  const Array<te::Tensor>& inputs,
+                                  const Type& out_type) {
   return { topi::gather_nd(inputs[0], inputs[1]) };
 }
 
@@ -2558,9 +2535,8 @@ bool SequenceMaskRel(const Array<Type>& types,
 }
 
 Array<te::Tensor> SequenceMaskCompute(const Attrs& attrs,
-                                  const Array<te::Tensor>& inputs,
-                                  const Type& out_type,
-                                  const Target& target) {
+                                      const Array<te::Tensor>& inputs,
+                                      const Type& out_type) {
   const auto* param = attrs.as<SequenceMaskAttrs>();
   CHECK(param != nullptr);
   return Array<te::Tensor>{
@@ -2671,9 +2647,8 @@ bool OneHotRel(const Array<Type>& types,
 }
 
 Array<te::Tensor> OneHotCompute(const Attrs& attrs,
-                            const Array<te::Tensor>& inputs,
-                            const Type& out_type,
-                            const Target& target) {
+                                const Array<te::Tensor>& inputs,
+                                const Type& out_type) {
   const auto* param = attrs.as<OneHotAttrs>();
   CHECK(param != nullptr);
   return Array<te::Tensor> {
index 7f6db50..caa6451 100644 (file)
@@ -34,9 +34,8 @@ namespace relay {
 
 #define RELAY_UNARY_COMPUTE(FTOPI)                      \
   [] (const Attrs& attrs,                               \
-      const Array<te::Tensor>& inputs,                      \
-      const Type& out_type,                             \
-      const Target& target) -> Array<te::Tensor> {          \
+      const Array<te::Tensor>& inputs,                  \
+      const Type& out_type) -> Array<te::Tensor> {      \
     return {FTOPI(inputs[0])};                          \
   }                                                     \
 
@@ -302,9 +301,8 @@ bool ShapeOfRel(const Array<Type>& types,
 }
 
 Array<te::Tensor> ShapeOfCompute(const Attrs& attrs,
-                             const Array<te::Tensor>& inputs,
-                             const Type& out_type,
-                             const Target& target) {
+                                 const Array<te::Tensor>& inputs,
+                                 const Type& out_type) {
   CHECK_EQ(inputs.size(), 1);
   const auto* param = attrs.as<ShapeOfAttrs>();
   CHECK(param != nullptr);
@@ -353,9 +351,8 @@ bool NdarraySizeRel(const Array<Type>& types,
 }
 
 Array<te::Tensor> NdarraySizeCompute(const Attrs& attrs,
-                          const Array<te::Tensor>& inputs,
-                          const Type& out_type,
-                          const Target& target) {
+                                     const Array<te::Tensor>& inputs,
+                                     const Type& out_type) {
   CHECK_EQ(inputs.size(), 1);
   const auto* param = attrs.as<NdarraySizeAttrs>();
   CHECK(param != nullptr);
index 9c4a285..7d15271 100644 (file)
@@ -83,8 +83,7 @@ Its function is mostly shape transform.")doc" TVM_ADD_FILELINE)
 .add_type_rel("YoloReorg", YoloReorgRel)
 .set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs,
                                          const Array<te::Tensor>& inputs,
-                                         const Type& out_type,
-                                         const Target& target) {
+                                         const Type& out_type) {
   const auto* params = attrs.as<YoloReorgAttrs>();
   CHECK(params != nullptr);
   return Array<te::Tensor>{ topi::vision::reorg(inputs[0], params->stride) };
index 0cc3ff0..fe88625 100644 (file)
@@ -83,7 +83,10 @@ class AlterTransformMemorizer : public TransformMemorizer {
         auto ttype = expr->type_as<TensorTypeNode>();
         tinfos.push_back(tvm::te::placeholder(ttype->shape, ttype->dtype));
       }
-      Expr altered_value = falter_layout[op](ref_call->attrs, new_args, tinfos);
+      // TODO(@kevinthesun, @icemelon9): This won't work if inputs/outputs are dynamic shapes.
+      //   Probably we need to disable the AlterOpLayout when compiling dynamic models.
+      Expr altered_value = falter_layout[op](ref_call->attrs, new_args, tinfos,
+                                             ref_call->checked_type());
       if (altered_value.defined()) {
         new_e = altered_value;
         modified = true;
index d3b448d..7a2e05a 100644 (file)
 /*!
  * \file schedule_lang.cc
  */
+#include <dmlc/thread_local.h>
 #include <tvm/runtime/registry.h>
 #include <tvm/te/schedule.h>
 #include <tvm/te/operation.h>
+#include <stack>
 #include <unordered_set>
 #include "graph.h"
 
@@ -787,6 +789,53 @@ IterVarRelation SingletonNode::make(IterVar iter) {
   return IterVarRelation(n);
 }
 
+SpecializedCondition::SpecializedCondition(Array<PrimExpr> conditions) {
+  ObjectPtr<SpecializedConditionNode> n = make_object<SpecializedConditionNode>();
+  n->clauses = std::move(conditions);
+  data_ = std::move(n);
+}
+
+/*! \brief Entry to hold the SpecializedCondition context stack. */
+struct TVMSpecializationThreadLocalEntry {
+  /*! \brief The current specialized condition */
+  std::stack<SpecializedCondition> condition_stack;
+};
+
+/*! \brief Thread local store to hold the Target context stack. */
+typedef dmlc::ThreadLocalStore<TVMSpecializationThreadLocalEntry> TVMSpecializationThreadLocalStore;
+
+void SpecializedCondition::EnterWithScope() {
+  TVMSpecializationThreadLocalEntry *entry = TVMSpecializationThreadLocalStore::Get();
+  entry->condition_stack.push(*this);
+}
+
+void SpecializedCondition::ExitWithScope() {
+  TVMSpecializationThreadLocalEntry *entry = TVMSpecializationThreadLocalStore::Get();
+  CHECK(!entry->condition_stack.empty());
+  CHECK(entry->condition_stack.top().same_as(*this));
+  entry->condition_stack.pop();
+}
+
+SpecializedCondition SpecializedCondition::Current() {
+  TVMSpecializationThreadLocalEntry *entry = TVMSpecializationThreadLocalStore::Get();
+  SpecializedCondition cond;
+  if (entry->condition_stack.size() > 0) {
+    cond = entry->condition_stack.top();
+  }
+  return cond;
+}
+
+class SpecializedCondition::Internal {
+ public:
+  static void EnterScope(SpecializedCondition cond) {
+    cond.EnterWithScope();
+  }
+
+  static void ExitScope(SpecializedCondition cond) {
+    cond.ExitWithScope();
+  }
+};
+
 TVM_REGISTER_NODE_TYPE(StageNode);
 TVM_REGISTER_NODE_TYPE(IterVarAttrNode);
 TVM_REGISTER_NODE_TYPE(SplitNode);
@@ -794,6 +843,7 @@ TVM_REGISTER_NODE_TYPE(FuseNode);
 TVM_REGISTER_NODE_TYPE(RebaseNode);
 TVM_REGISTER_NODE_TYPE(SingletonNode);
 TVM_REGISTER_NODE_TYPE(ScheduleNode);
+TVM_REGISTER_NODE_TYPE(SpecializedConditionNode);
 
 // Printer
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
@@ -848,7 +898,13 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
 .set_dispatch<ScheduleNode>([](const ObjectRef& node, ReprPrinter* p) {
     auto* op = static_cast<const ScheduleNode*>(node.get());
     p->stream << "schedule(" << op << ")";
-  });
+})
+.set_dispatch<SpecializedConditionNode>([](const ObjectRef& node, ReprPrinter* p) {
+    auto* op = static_cast<const SpecializedConditionNode*>(node.get());
+    p->stream << "specialized_condition(";
+    p->Print(op->clauses);
+    p->stream << ')';
+});
 
 
 TVM_REGISTER_GLOBAL("te.CreateSchedule")
@@ -962,5 +1018,22 @@ TVM_REGISTER_GLOBAL("te.ScheduleCacheWrite")
 
 TVM_REGISTER_GLOBAL("te.ScheduleRFactor")
 .set_body_method(&Schedule::rfactor);
+
+TVM_REGISTER_GLOBAL("te.CreateSpecializedCondition")
+.set_body_typed([](Array<PrimExpr> condition) {
+    return SpecializedCondition(condition);
+});
+
+TVM_REGISTER_GLOBAL("te.GetCurrentSpecialization")
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+    *ret = SpecializedCondition::Current();
+});
+
+TVM_REGISTER_GLOBAL("te.EnterSpecializationScope")
+.set_body_typed(SpecializedCondition::Internal::EnterScope);
+
+TVM_REGISTER_GLOBAL("te.ExitSpecializationScope")
+.set_body_typed(SpecializedCondition::Internal::ExitScope);
+
 }  // namespace te
 }  // namespace tvm
index c274083..b9a8f8f 100644 (file)
 #include <tvm/relay/type.h>
 #include <tvm/relay/analysis.h>
 #include <tvm/relay/transform.h>
+#include <tvm/relay/op_strategy.h>
+#include <tvm/relay/op_attr_types.h>
+#include <topi/broadcast.h>
 #include <topi/generic/injective.h>
 #include <tvm/runtime/packed_func.h>
 #include <tvm/runtime/module.h>
 #include <tvm/runtime/registry.h>
 
-TVM_REGISTER_GLOBAL("test.sch")
-.set_body([](tvm::TVMArgs args, tvm::TVMRetValue *rv) {
-  *rv = topi::generic::schedule_injective(args[0], args[1]);
-  });
+using namespace tvm;
+using namespace tvm::relay;
+
+TVM_REGISTER_GLOBAL("test.strategy")
+.set_body_typed([](const Attrs& attrs, const Array<te::Tensor>& inputs,
+                   const Type& out_type, const Target& target) {
+    FTVMCompute fcompute = [](const Attrs& attrs,
+                              const Array<te::Tensor>& inputs,
+                              const Type& out_type) -> Array<te::Tensor> {
+        CHECK_EQ(inputs.size(), 2U);
+        return {topi::add(inputs[0], inputs[1])};
+    };
+    FTVMSchedule fschedule = [](const Attrs& attrs,
+                                const Array<te::Tensor>& outs,
+                                const Target& target) {
+        With<Target> target_scope(target);
+        return topi::generic::schedule_injective(target, outs);
+    };
+
+    auto n = make_object<OpStrategyNode>();
+    auto strategy = tvm::relay::OpStrategy(std::move(n));
+    strategy.AddImplementation(fcompute, fschedule, "test.strategy", 10);
+    return strategy;
+});
+
+TVM_REGISTER_GLOBAL("relay.backend.lower_call")
+.set_body_typed([](const relay::Call& call, const Array<te::Tensor>& inputs,
+                   const Target& target) {
+    static auto fstrategy = Op::GetAttr<relay::FTVMStrategy>("FTVMStrategy");
+    Op op = Downcast<Op>(call->op);
+    auto out_type = call->checked_type();
+    OpStrategy strategy = fstrategy[op](call->attrs, inputs, out_type, target);
+    auto impl = strategy->specializations[0]->implementations[0];
+    auto outs = impl.Compute(call->attrs, inputs, out_type);
+    auto f = tvm::runtime::Registry::Get("relay.backend._make_LoweredOutput");
+    if (!f) {
+      LOG(FATAL) << "relay.backend._make_LoweredOutput is not registered";
+    }
+    return (*f)(outs, impl);
+});
 
 TEST(Relay, BuildModule) {
-  using namespace tvm;
   auto tensor_type = relay::TensorType({2, 3}, DataType::Float(32));
   auto a = relay::VarNode::make("a", tensor_type);
   auto b = relay::VarNode::make("b", tensor_type);
@@ -59,14 +97,15 @@ TEST(Relay, BuildModule) {
   }
   // get schedule
   auto reg = tvm::runtime::Registry::Get("relay.op._Register");
-  auto s_i = tvm::runtime::Registry::Get("test.sch");
   if (!reg) {
     LOG(FATAL) << "no _Register";
   }
-  if (!s_i) {
-    LOG(FATAL) << "no _Register";
+  auto fs = tvm::runtime::Registry::Get("test.strategy");
+  if (!fs) {
+    LOG(FATAL) << "No test_strategy registered.";
   }
-  (*reg)("add", "FTVMSchedule", *s_i, 10);
+  auto fgeneric = GenericFunc::Get("test.strategy_generic").set_default(*fs);
+  (*reg)("add", "FTVMStrategy", fgeneric, 10);
   // build
   auto pfb = tvm::runtime::Registry::Get("relay.build_module._BuildModule");
   tvm::runtime::Module build_mod = (*pfb)();
index 8a6ceb8..504f700 100644 (file)
@@ -852,17 +852,22 @@ def test_forward_slice():
 
 
 def test_forward_convolution():
-    def verify(data_shape, kernel_size, stride, pad, num_filter):
-        weight_shape=(num_filter, data_shape[1],) + kernel_size
+    def verify(data_shape, kernel_size, stride, pad, num_filter, is_depthwise=False):
+        if is_depthwise:
+            groups = data_shape[1]
+            weight_shape=(data_shape[1], num_filter // groups,) + kernel_size
+        else:
+            groups = 1
+            weight_shape=(num_filter, data_shape[1],) + kernel_size
         x = np.random.uniform(size=data_shape).astype("float32")
         weight = np.random.uniform(size=weight_shape).astype("float32")
         bias = np.random.uniform(size=num_filter).astype("float32")
         ref_res = mx.nd.Convolution(data=mx.nd.array(x), weight=mx.nd.array(weight),
                                     bias=mx.nd.array(bias), kernel=kernel_size, stride=stride,
-                                    pad=pad, num_filter=num_filter)
+                                    pad=pad, num_filter=num_filter, num_group=groups)
         mx_sym = mx.sym.Convolution(mx.sym.var("x"), mx.sym.var("weight"), mx.sym.var("bias"),
                                     kernel=kernel_size, stride=stride,
-                                    pad=pad, num_filter=num_filter)
+                                    pad=pad, num_filter=num_filter, num_group=groups)
         shape_dict = {"x": x.shape, "weight": weight.shape, "bias": bias.shape}
         mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict)
         for target, ctx in ctx_list():
@@ -879,6 +884,8 @@ def test_forward_convolution():
     verify(data_shape=(20, 1, 32, 32), kernel_size=(3, 3), stride=(1, 1), pad=(1, 1), num_filter=2)
     verify(data_shape=(1, 8, 32, 32), kernel_size=(3, 3), stride=(1, 1), pad=(1, 1), num_filter=2)
     verify(data_shape=(20, 8, 32, 32), kernel_size=(3, 3), stride=(1, 1), pad=(1, 1), num_filter=2)
+    verify(data_shape=(1, 8, 32, 32), kernel_size=(3, 3), stride=(1, 1), pad=(1, 1), num_filter=8,
+           is_depthwise=True)
 
 def test_forward_deconvolution():
     def verify(data_shape, kernel_size, stride, pad, num_filter):
index fec7d3e..99f8b47 100644 (file)
@@ -25,7 +25,7 @@ import tvm
 from tvm import autotvm
 from tvm.autotvm.tuner import RandomTuner
 
-@autotvm.template
+@autotvm.register_customized_task("testing/conv2d_no_batching")
 def conv2d_no_batching(N, H, W, CI, CO, KH, KW):
     """An example template for testing"""
     assert N == 1, "Only consider batch_size = 1 in this template"
@@ -114,7 +114,7 @@ def conv2d_no_batching(N, H, W, CI, CO, KH, KW):
 
 def get_sample_task(target=tvm.target.cuda(), target_host=None):
     """return a sample task for testing"""
-    task = autotvm.task.create(conv2d_no_batching,
+    task = autotvm.task.create("testing/conv2d_no_batching",
                                args=(1, 7, 7, 512, 512, 3, 3),
                                target=target, target_host=target_host)
     return task, target
index 3e392a8..24176e4 100644 (file)
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import numpy as np
+import pytest
 
 import tvm
 from tvm import relay
@@ -384,6 +385,8 @@ def verify_any_conv2d_NCHWc(data_shape, kernel_shape, strides, padding, dilation
         assert result.asnumpy().shape == ref_out_shape, \
             "Shape mismatch: expect %s but got %s." % (str(ref_out_shape), str(result.asnumpy().shape))
 
+# TODO(@kevinthesun): Need to fix the compute in conv2d_NCHWc to support any
+@pytest.mark.skip
 def test_any_conv2d_NCHWc():
     verify_any_conv2d_NCHWc((relay.Any(), 8, relay.Any(), relay.Any(), 8), (8, 8, 3, 3, 8, 8), (1, 1), (1, 1), (1, 1),
                             "NCHW8c", "OIHW8i8o", "NCHW8c", (1, 8, 224, 224, 8), (1, 8, 224, 224, 8))
index 8f550d8..e555caa 100644 (file)
@@ -39,25 +39,28 @@ def test_task_extraction():
     target = 'llvm'
     mod_list = []
     params_list = []
+    conv2d = relay.op.get("nn.conv2d")
+    conv2d_transpose = relay.op.get("nn.conv2d_transpose")
+    dense = relay.op.get("nn.dense")
 
     mod, params, _ = get_network('resnet-18', batch_size=1)
     tasks = autotvm.task.extract_from_program(mod["main"], target=target,
                                               params=params,
-                                              ops=(relay.op.nn.conv2d,))
+                                              ops=(conv2d,))
     assert len(tasks) == 12
     tasks = autotvm.task.extract_from_program(mod, target=target,
                                               params=params,
-                                              ops=(relay.op.nn.conv2d,))
+                                              ops=(conv2d,))
     assert len(tasks) == 12
 
     mod, params, _ = get_network('resnet-18', batch_size=1)
     tasks = autotvm.task.extract_from_program(mod["main"], target=target,
                                               params=params,
-                                              ops=(relay.op.nn.dense,))
+                                              ops=(dense,))
     assert len(tasks) == 1
     tasks = autotvm.task.extract_from_program(mod, target=target,
                                               params=params,
-                                              ops=(relay.op.nn.dense,))
+                                              ops=(dense,))
     assert len(tasks) == 1
 
     mod, params, _ = get_network('resnet-18', batch_size=1)
@@ -65,11 +68,14 @@ def test_task_extraction():
     params_list.append(params)
     tasks = autotvm.task.extract_from_program(mod["main"], target=target,
                                               params=params,
-                                              ops=(relay.op.nn.conv2d, relay.op.nn.dense))
+                                              ops=(conv2d, dense))
     assert len(tasks) == 13
     tasks = autotvm.task.extract_from_program(mod, target=target,
                                               params=params,
-                                              ops=(relay.op.nn.conv2d, relay.op.nn.dense))
+                                              ops=(conv2d, dense))
+    assert len(tasks) == 13
+    tasks = autotvm.task.extract_from_program(mod, target=target,
+                                              params=params)
     assert len(tasks) == 13
 
     mod, params, _ = get_network('mobilenet', batch_size=1)
@@ -77,65 +83,19 @@ def test_task_extraction():
     params_list.append(params)
     tasks = autotvm.task.extract_from_program(mod, target=target,
                                               params=params,
-                                              ops=(relay.op.nn.conv2d, relay.op.nn.dense))
+                                              ops=(conv2d, dense))
     assert len(tasks) == 20
 
     mod, params, _ = get_network('dcgan', batch_size=1)
     tasks = autotvm.task.extract_from_program(mod, target=target,
                                               params=params,
-                                              ops=(relay.op.nn.conv2d_transpose,))
+                                              ops=(conv2d_transpose,))
     assert len(tasks) == 4
 
     tasks = autotvm.task.extract_from_multiple_program(mod_list, params_list,
                                                        target=target,
-                                                       ops=(relay.op.nn.conv2d,))
+                                                       ops=(conv2d,))
     assert len(tasks) == 31
 
-def test_template_key_provided():
-    """test task extraction using non-'direct' template_key"""
-    target = 'llvm'
-
-    import topi
-    template_keys = {
-        # topi.nn.conv2d - is left blank to test fallback logic
-        topi.nn.dense: 'direct_nopack',
-        topi.nn.depthwise_conv2d_nchw: 'direct',
-    }
-
-    mod, params, _ = get_network('mobilenet', batch_size=1)
-    tasks = autotvm.task.extract_from_program(mod['main'], target=target,
-                                              params=params,
-                                              ops=(relay.op.nn.conv2d, relay.op.nn.dense),
-                                              template_keys=template_keys)
-    for task in tasks:
-        if 'dense' in task.name:
-            assert task.config_space.template_key == 'direct_nopack'
-        else:
-            assert task.config_space.template_key == 'direct'
-
-def test_template_key_empty():
-    """test task extraction using empty template_key"""
-    target = 'llvm'
-    mod, params, _ = get_network('mobilenet', batch_size=1)
-    tasks = autotvm.task.extract_from_program(mod['main'], target=target,
-                                              params=params,
-                                              ops=(relay.op.nn.conv2d, relay.op.nn.dense),
-                                              template_keys=None)
-    for task in tasks:
-        assert task.config_space.template_key == 'direct'
-
-def test_template_key_default():
-    """test task extraction without template_key"""
-    target = 'llvm'
-    mod, params, _ = get_network('mobilenet', batch_size=1)
-    tasks = autotvm.task.extract_from_program(mod['main'], target=target,
-                                              params=params,
-                                              ops=(relay.op.nn.conv2d, relay.op.nn.dense))
-    for task in tasks:
-        assert task.config_space.template_key == 'direct'
-
 if __name__ == '__main__':
     test_task_extraction()
-    test_template_key_provided()
-    test_template_key_empty()
-    test_template_key_default()
index fd7ec18..4e4122a 100644 (file)
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import numpy as np
 import tvm
 import tvm.testing
-import numpy as np
 from tvm import relay
+from tvm import autotvm
+import topi
+from tvm.relay.testing import run_infer_type
+from tvm.relay.testing.temp_op_attr import TempOpAttr
+
+
+@autotvm.register_topi_compute("test/conv2d_1")
+def _compute_conv2d_1(cfg, input, filter, strides, padding, dilation, out_dtype):
+    return topi.nn.conv2d_nchw(input, filter, strides, padding, dilation, out_dtype)
+
+@autotvm.register_topi_schedule("test/conv2d_1")
+def _schedule_conv2d_1(cfg, outs):
+    return topi.generic.schedule_conv2d_nchw(outs)
+
+@autotvm.register_topi_compute("test/conv2d_2")
+def _compute_conv2d_2(cfg, input, filter, strides, padding, dilation, out_dtype):
+    return topi.nn.conv2d_nchw(input, filter, strides, padding, dilation, out_dtype)
+
+@autotvm.register_topi_schedule("test/conv2d_2")
+def _schedule_conv2d_2(cfg, outs):
+    return topi.generic.schedule_conv2d_nchw(outs)
+
+def _compute_conv2d_3(input, filter, strides, padding, dilation, out_dtype):
+    return topi.nn.conv2d_nchw(input, filter, strides, padding, dilation, out_dtype)
+
+def _schedule_conv2d_3(outs):
+    return topi.generic.schedule_conv2d_nchw(outs)
+
+@tvm.target.override_native_generic_func("test_conv2d_strategy")
+def _tmp_strategy(attrs, inputs, out_type, target):
+    strategy = relay.op.OpStrategy()
+    strategy.add_implementation(
+        relay.op.strategy.wrap_compute_conv2d(_compute_conv2d_1),
+        relay.op.strategy.wrap_topi_schedule(_schedule_conv2d_1),
+        name="conv2d_1",
+        plevel=10)
+    strategy.add_implementation(
+        relay.op.strategy.wrap_compute_conv2d(_compute_conv2d_2),
+        relay.op.strategy.wrap_topi_schedule(_schedule_conv2d_2),
+        name="conv2d_2",
+        plevel=15)
+    ic = inputs[0].shape[1]
+    with tvm.te.SpecializedCondition(ic >= 16):
+        strategy.add_implementation(
+            relay.op.strategy.wrap_compute_conv2d(_compute_conv2d_3),
+            relay.op.strategy.wrap_topi_schedule(_schedule_conv2d_3),
+            name="conv2d_3",
+            plevel=20)
+    return strategy
+
+def _create_record(task_name, dshape, wshape, target, cost):
+    args = [tvm.placeholder(dshape), tvm.placeholder(wshape), (1, 1), (1, 1, 1, 1),
+            (1, 1), 'float32']
+    task = autotvm.task.create(task_name, args, target)
+    cfg = autotvm.ConfigEntity(0, None, {}, [])
+    cfg.cost = cost
+    inp = autotvm.MeasureInput(target=target, task=task, config=cfg)
+    result = autotvm.MeasureResult(costs=(cost,), error_no=0, all_cost=-1, timestamp=-1)
+    return (inp, result)
+
+def test_get_valid_implementations():
+    target = tvm.target.create("llvm")
+
+    def _get_impls(dshape, wshape):
+        data = relay.var("data", shape=dshape)
+        weight = relay.var("wshape", shape=wshape)
+        out = relay.nn.conv2d(data, weight, padding=(1, 1))
+        out = run_infer_type(out)
+        return relay.backend.compile_engine.get_valid_implementations(
+            relay.op.get("nn.conv2d"),
+            out.attrs,
+            [tvm.placeholder(dshape), tvm.placeholder(wshape)],
+            out.checked_type,
+            target)
+
+    with TempOpAttr("nn.conv2d", "FTVMStrategy", _tmp_strategy):
+        impls = _get_impls((1, 8, 7, 7), (32, 8, 3, 3))
+        assert len(impls) == 2
+        impls = _get_impls((1, 16, 7, 7), (32, 16, 3, 3))
+        assert len(impls) == 3
+
+def test_select_implementation():
+    target = tvm.target.create("llvm")
+
+    def _select_impl(dshape, wshape, use_autotvm=False):
+        data = relay.var("data", shape=dshape)
+        weight = relay.var("wshape", shape=wshape)
+        out = relay.nn.conv2d(data, weight, padding=(1, 1))
+        out = run_infer_type(out)
+        return relay.backend.compile_engine.select_implementation(
+            relay.op.get("nn.conv2d"),
+            out.attrs,
+            [tvm.placeholder(dshape), tvm.placeholder(wshape)],
+            out.checked_type,
+            target,
+            use_autotvm)
+
+    with TempOpAttr("nn.conv2d", "FTVMStrategy", _tmp_strategy):
+        impl, _ = _select_impl((1, 8, 7, 7), (32, 8, 3, 3))
+        assert impl.name == "conv2d_2"
+        impl, _ = _select_impl((1, 8, 7, 7), (32, 8, 3, 3), True)
+        assert impl.name == "conv2d_2"
+        impl, _ = _select_impl((1, 16, 7, 7), (32, 16, 3, 3))
+        assert impl.name == "conv2d_3"
+        impl, _ = _select_impl((1, 16, 7, 7), (32, 16, 3, 3), True)
+        assert impl.name == "conv2d_3"
+
+        # add autotvm record
+        records = []
+        records.append(_create_record("test/conv2d_1", (1, 8, 7, 7), (32, 8, 3, 3), target, 0.5))
+        records.append(_create_record("test/conv2d_1", (1, 16, 7, 7), (32, 16, 3, 3), target, 1.0))
+        with target:
+            with autotvm.apply_history_best(records):
+                impl, _ = _select_impl((1, 8, 7, 7), (32, 8, 3, 3), True)
+                assert impl.name == "conv2d_1"
+                impl, _ = _select_impl((1, 16, 7, 7), (32, 16, 3, 3), True)
+                assert impl.name == "conv2d_1"
 
+        records.append(_create_record("test/conv2d_2", (1, 8, 7, 7), (32, 8, 3, 3), target, 0.2))
+        records.append(_create_record("test/conv2d_1", (1, 16, 7, 7), (32, 16, 3, 3), target, 1.2))
+        with target:
+            with autotvm.apply_history_best(records):
+                impl, _ = _select_impl((1, 8, 7, 7), (32, 8, 3, 3), True)
+                assert impl.name == "conv2d_2"
+                impl, _ = _select_impl((1, 16, 7, 7), (32, 16, 3, 3), True)
+                assert impl.name == "conv2d_1"
 
 def test_compile_engine():
     engine = relay.backend.compile_engine.get()
@@ -109,6 +234,8 @@ def test_compile_nhwc_pack():
 
 
 if __name__ == "__main__":
+    test_get_valid_implementations()
+    test_select_implementation()
     test_compile_engine()
     test_compile_placeholder_bypass()
     test_compile_injective_with_tuple()
index 0d3fd4b..d545d0c 100644 (file)
@@ -199,7 +199,7 @@ def test_conv2d_run():
             except_targets = []
 
         x = relay.var("x", shape=dshape, dtype=dtype)
-        w = relay.var("w", dtype=dtype)
+        w = relay.var("w", shape=kshape, dtype=dtype)
         y = relay.nn.conv2d(x, w,
                             padding=padding,
                             dilation=dilation,
@@ -222,7 +222,7 @@ def test_conv2d_run():
                 continue
             intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
             op_res1 = intrp1.evaluate(func)(data, kernel)
-            tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5)
+            tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-4, atol=1e-4)
 
     def compile_test_conv2d_arm_cpu(dtype, out_dtype, scale, dshape, kshape,
                         padding=(1, 1),
@@ -230,7 +230,7 @@ def test_conv2d_run():
                         dilation=(1, 1),
                         **attrs):
         x = relay.var("x", shape=dshape, dtype=dtype)
-        w = relay.var("w", dtype=dtype)
+        w = relay.var("w", shape=kshape, dtype=dtype)
         y = relay.nn.conv2d(x, w,
                             padding=padding,
                             dilation=dilation,
@@ -240,13 +240,13 @@ def test_conv2d_run():
         mod = tvm.IRModule()
         mod["main"] = func
 
-        test_schedule='{"i": ["llvm -device=arm_cpu", "topi_nn_depthwise_conv2d_nchw", \
+        test_schedule='{"i": ["llvm -device=arm_cpu", "depthwise_conv2d_nchw_spatial_pack.arm_cpu", \
                         [["TENSOR", [1, 512, 32, 32], "float32"], \
                         ["TENSOR", [512, 1, 3, 3], "float32"], \
                         [1, 1], [1, 1], [1, 1], "float32"], {}, \
-                        ["depthwise_conv2d_nchw", [1, 512, 32, 32, "float32"], \
+                        ["depthwise_conv2d_nchw_spatial_pack.arm_cpu", [1, 512, 32, 32, "float32"], \
                         [512, 1, 3, 3, "float32"], [1, 1], [1, 1], [1, 1], "float32"], \
-                        {"i": 743640, "t": "contrib_spatial_pack", "c": null, \
+                        {"i": 743640, "t": "", "c": null, \
                         "e": [["tile_co", "sp", [32, 16]], ["tile_oh", "sp", [8, 1]], \
                         ["tile_ow", "sp", [1, 8]], \
                         ["reorder_0", "re", [0, 1, 2, 3, 4, 5, 8, 6, 7]], \
@@ -319,8 +319,8 @@ def test_conv2d_winograd():
             if key in self.memory:
                 return self.memory[key]
             cfg = autotvm.task.space.FallbackConfigEntity()
-            cfg.template_key = 'winograd'
             cfg.is_fallback = False
+            cfg.cost = 0.1 if 'winograd' in workload[0] else 1
             cfg['tile_b'] = autotvm.task.space.SplitEntity([-1, 1, 1, 1])
             cfg['tile_y'] = autotvm.task.space.SplitEntity([-1, 1, 1, 1])
             cfg['tile_x'] = autotvm.task.space.SplitEntity([-1, 1, 1, 1])
@@ -1113,6 +1113,9 @@ def test_conv2d_int8_intrinsics():
         else:
             assert False, "Target should be Skylake or Cascadelake"
 
+    # TODO(@anijain2305, @icemelon9): disable conv2d_int8 for NHWC data layout.
+    #   Re-enable this after adding conv2d_NCHWc_int8 support for NHWC.
+
     # compile conv2d for x86 (skylake, cascadelake) and test assembly contains *pmadd* instructions
     targets = ["llvm -mcpu=skylake-avx512", "llvm -mcpu=cascadelake"]
     llvm_version = tvm.target.codegen.llvm_version_major()
@@ -1127,11 +1130,11 @@ def test_conv2d_int8_intrinsics():
                                dtypes=dtypes)
                 assert _has_fast_int8_instructions(asm, target)
 
-            for ic in [1, 4, 6]:
-                asm = _compile(ic=ic, oc=16, target=target, data_layout="NHWC",
-                               kernel_layout='HWIO',
-                               dtypes=dtypes)
-                assert _has_fast_int8_instructions(asm, target)
+            for ic in [1, 4, 6]:
+                asm = _compile(ic=ic, oc=16, target=target, data_layout="NHWC",
+                               kernel_layout='HWIO',
+                               dtypes=dtypes)
+                assert _has_fast_int8_instructions(asm, target)
 
             # Sweep the output channels to check int8 robustness
             # Output channels should be a multiple of 16 internally.
@@ -1141,20 +1144,20 @@ def test_conv2d_int8_intrinsics():
                                dtypes=dtypes)
                 assert _has_fast_int8_instructions(asm, target)
 
-            for oc in [4, 16, 20]:
-                asm = _compile(ic=8, oc=oc, target=target, data_layout="NHWC",
-                               kernel_layout='HWIO',
-                               dtypes=dtypes)
-                assert _has_fast_int8_instructions(asm, target)
+            for oc in [4, 16, 20]:
+                asm = _compile(ic=8, oc=oc, target=target, data_layout="NHWC",
+                               kernel_layout='HWIO',
+                               dtypes=dtypes)
+                assert _has_fast_int8_instructions(asm, target)
 
             # Check that both non-divisible oc and ic work
             asm = _compile(ic=17, oc=29, target=target, data_layout="NCHW", kernel_layout='OIHW',
                            dtypes=dtypes)
             assert _has_fast_int8_instructions(asm, target)
 
-            asm = _compile(ic=17, oc=29, target=target, data_layout="NHWC", kernel_layout='HWIO',
-                           dtypes=dtypes)
-            assert _has_fast_int8_instructions(asm, target)
+            asm = _compile(ic=17, oc=29, target=target, data_layout="NHWC", kernel_layout='HWIO',
+                           dtypes=dtypes)
+            assert _has_fast_int8_instructions(asm, target)
 
     # Check that int8 x int8 goes through legalization so that fast instructions can be picked up.
     for target in targets:
@@ -1165,16 +1168,16 @@ def test_conv2d_int8_intrinsics():
                            dtypes=dtypes)
             assert _has_fast_int8_instructions(asm, target)
 
-            asm = _compile(ic=17, oc=29, target=target, data_layout="NHWC", kernel_layout='HWIO',
-                           dtypes=dtypes)
-            assert _has_fast_int8_instructions(asm, target)
+            asm = _compile(ic=17, oc=29, target=target, data_layout="NHWC", kernel_layout='HWIO',
+                           dtypes=dtypes)
+            assert _has_fast_int8_instructions(asm, target)
 
     # Ensure that code is generated when datatypes are not HW supported.
-    dtypes = ('uint8', 'uint8', 'int32')
-    asm = _compile(ic=16, oc=32, target=target, data_layout="NHWC", kernel_layout='HWIO',
-                   dtypes=dtypes)
-    # Check that intrinisic is not present in the assembly.
-    assert not _has_fast_int8_instructions(asm, target)
+    dtypes = ('uint8', 'uint8', 'int32')
+    asm = _compile(ic=16, oc=32, target=target, data_layout="NHWC", kernel_layout='HWIO',
+                   dtypes=dtypes)
+    # Check that intrinisic is not present in the assembly.
+    assert not _has_fast_int8_instructions(asm, target)
 
     # Check that a vectorized instruction is generated for older Intel
     # generations, because we default to NCHWc layout.
@@ -1223,7 +1226,7 @@ def test_bitserial_conv2d_infer_type():
     y = relay.nn.bitserial_conv2d(
         x, w, kernel_size=(3, 3), padding=(0, 0), channels=32)
     yy = run_infer_type(y)
-    assert yy.checked_type ==  relay.TensorType(
+    assert yy.checked_type == relay.TensorType(
         (n, 32, 222, 222), "int16")
 
 
@@ -1233,9 +1236,11 @@ def test_bitpack_infer_type():
     x = relay.var("x", relay.ty.TensorType((o, i, h, w), "int16"))
     y = relay.nn.bitpack(x, bit_axis=4, pack_axis=1, pack_type='uint16', bits=1)
     yy = run_infer_type(y)
-    assert yy.checked_type ==  relay.TensorType(
+    assert yy.checked_type == relay.TensorType(
         (32, 2, 128, 128, 1), "uint16")
 
+# TODO(@jwfromm): Need to add bitserial_conv2d & bitpack run test cases
+
 
 if __name__ == "__main__":
     test_pool1d()
index 67a7ef6..e827c72 100644 (file)
@@ -116,23 +116,13 @@ def get_funcs(data_shape,
               data_layout,
               kernel_layout,
               out_dtype,
-              groups=1):
+              groups=1,
+              channels=None):
     data = relay.var("data", shape=data_shape,
             dtype=data_dtype)
     kernel = relay.var("kernel", shape=kernel_shape,
             dtype=kernel_dtype)
 
-    if groups > 1:
-        channels = groups
-    elif kernel_layout == "OIHW":
-        channels = kernel_shape[0]
-    elif kernel_layout == "HWIO":
-        channels = kernel_shape[3]
-    elif kernel_layout == "HWOI":
-        channels = kernel_shape[2]
-    else:
-        raise NotImplementedError
-
     ref_func = get_ref_func(data,
                             kernel,
                             input_zero_point,
@@ -431,9 +421,9 @@ def test_layout():
                 kernel_shape, kernel_dtype)
 
         # NHWC and HWOI layout. Used in depthwise conv.
-        data_shape = (2, 2, 4, 1) # NHWC
+        data_shape = (2, 2, 4, 3) # NHWC
         data_dtype = 'uint8'
-        kernel_shape = (2, 2, 1, 1) # HWOI
+        kernel_shape = (2, 2, 3, 1) # HWOI
         kernel_dtype = 'uint8'
         ref_func, qnn_func = get_funcs(data_shape=data_shape,
                                        data_dtype=data_dtype,
@@ -447,6 +437,7 @@ def test_layout():
                                        padding=(0, 0),
                                        strides=(1, 1),
                                        dilation=(1, 1),
+                                       groups=3,
                                        data_layout="NHWC",
                                        kernel_layout="HWOI",
                                        out_dtype="int32")
@@ -826,7 +817,8 @@ def test_depthwise_depth_multiplier():
                                        data_layout="NCHW",
                                        kernel_layout="OIHW",
                                        out_dtype="int32",
-                                       groups=8)
+                                       groups=4,
+                                       channels=8)
         verify(ref_func, qnn_func, data_shape, data_dtype,
                 kernel_shape, kernel_dtype)
 
@@ -875,7 +867,8 @@ def test_depthwise_depth_multiplier():
                                        data_layout="NHWC",
                                        kernel_layout="HWOI",
                                        out_dtype="int32",
-                                       groups=8)
+                                       groups=4,
+                                       channels=8)
         verify(ref_func, qnn_func, data_shape, data_dtype,
                 kernel_shape, kernel_dtype)
 
index 2ec3f28..df01310 100644 (file)
@@ -15,8 +15,9 @@
 # specific language governing permissions and limitations
 # under the License.
 """Test alter op layout pass"""
-import tvm
+import pytest
 
+import tvm
 from tvm import relay
 from tvm.relay import transform, analysis
 from tvm.relay.testing.temp_op_attr import TempOpAttr
@@ -43,7 +44,7 @@ def test_alter_op():
         y = relay.Function([x, weight], y)
         return y
 
-    def alter_conv2d(attrs, inputs, tinfos):
+    def alter_conv2d(attrs, inputs, tinfos, out_type):
         data, weight = inputs
         weight = relay.multiply(weight, relay.const(2.0, "float32"))
         return relay.nn.conv2d(data, weight, **attrs)
@@ -77,7 +78,7 @@ def test_alter_return_none():
 
     called = [False]
 
-    def alter_conv2d(attrs, inputs, tinfos):
+    def alter_conv2d(attrs, inputs, tinfos, out_type):
         called[0] = True
         return None
 
@@ -109,7 +110,7 @@ def test_alter_layout():
         y = relay.Function(analysis.free_vars(y), y)
         return y
 
-    def alter_conv2d(attrs, inputs, tinfos):
+    def alter_conv2d(attrs, inputs, tinfos, out_type):
         data, weight = inputs
         new_attrs = dict(attrs)
         new_attrs['data_layout'] = 'NCHW16c'
@@ -176,7 +177,7 @@ def test_alter_layout_dual_path():
         y = relay.Function(analysis.free_vars(ret), ret)
         return y
 
-    def alter_conv2d(attrs, inputs, tinfos):
+    def alter_conv2d(attrs, inputs, tinfos, out_type):
         data, weight = inputs
         new_attrs = dict(attrs)
         new_attrs['data_layout'] = 'NCHW16c'
@@ -236,7 +237,7 @@ def test_alter_layout_resnet():
         y = relay.nn.global_max_pool2d(y)
         return relay.Function(analysis.free_vars(y), y)
 
-    def alter_conv2d(attrs, inputs, tinfos):
+    def alter_conv2d(attrs, inputs, tinfos, out_type):
         data, weight = inputs
         new_attrs = dict(attrs)
         new_attrs['data_layout'] = 'NCHW16c'
@@ -285,7 +286,7 @@ def test_alter_layout_broadcast_op():
         y = relay.Function(analysis.free_vars(y), y)
         return y
 
-    def alter_conv2d(attrs, inputs, tinfos):
+    def alter_conv2d(attrs, inputs, tinfos, out_type):
         data, weight = inputs
         new_attrs = dict(attrs)
         new_attrs['data_layout'] = 'NCHW16c'
@@ -342,7 +343,7 @@ def test_alter_layout_broadcast_scalar_op():
         y = relay.Function(analysis.free_vars(y), y)
         return y
 
-    def alter_conv2d(attrs, inputs, tinfos):
+    def alter_conv2d(attrs, inputs, tinfos, out_type):
         data, weight = inputs
         new_attrs = dict(attrs)
         new_attrs['data_layout'] = 'NCHW16c'
@@ -394,7 +395,7 @@ def test_alter_layout_scalar():
         y = relay.Function(analysis.free_vars(y), y)
         return y
 
-    def alter_conv2d(attrs, inputs, tinfos):
+    def alter_conv2d(attrs, inputs, tinfos, out_type):
         data, weight = inputs
         new_attrs = dict(attrs)
         new_attrs['data_layout'] = 'NCHW16c'
@@ -427,7 +428,7 @@ def test_alter_layout_scalar():
 
 def test_alter_layout_concatenate():
     """ NCHW, NHWC and corner case concatenate layout transform."""
-    def alter_conv2d(attrs, inputs, tinfos):
+    def alter_conv2d(attrs, inputs, tinfos, out_type):
         data, weight = inputs
         new_attrs = dict(attrs)
         new_attrs['data_layout'] = 'NCHW16c'
@@ -536,7 +537,7 @@ def test_alter_layout_nchw_upsamping_op():
         y = relay.Function(analysis.free_vars(y), y)
         return y
 
-    def alter_conv2d(attrs, inputs, tinfos):
+    def alter_conv2d(attrs, inputs, tinfos, out_type):
         data, weight = inputs
         new_attrs = dict(attrs)
         new_attrs['data_layout'] = 'NCHW16c'
@@ -572,7 +573,7 @@ def test_alter_layout_strided_slice():
         y = relay.Function(analysis.free_vars(y), y)
         return y
 
-    def alter_conv2d(attrs, inputs, tinfos):
+    def alter_conv2d(attrs, inputs, tinfos, out_type):
         data, weight = inputs
         new_attrs = dict(attrs)
         new_attrs['data_layout'] = 'NCHW4c'
@@ -607,9 +608,9 @@ def test_alter_layout_depthwise_conv2d():
         return y
 
     import topi
-    def alter_conv2d(attrs, inputs, tinfos):
+    def alter_conv2d(attrs, inputs, tinfos, out_type):
         with tvm.target.create("llvm"):
-            return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, relay)
+            return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, out_type)
 
 
     def expected():
@@ -643,7 +644,7 @@ def test_alter_layout_prelu():
         y = relay.Function(analysis.free_vars(y), y)
         return y
 
-    def alter_conv2d(attrs, inputs, tinfos):
+    def alter_conv2d(attrs, inputs, tinfos, out_type):
         data, weight = inputs
         new_attrs = dict(attrs)
         new_attrs['data_layout'] = 'NCHW16c'
@@ -675,7 +676,7 @@ def test_alter_layout_prelu():
 
 def test_alter_layout_pad():
     """ Check NCHW, NHWC and corner case for pad layout conversion"""
-    def alter_conv2d(attrs, inputs, tinfos):
+    def alter_conv2d(attrs, inputs, tinfos, out_type):
         data, weight = inputs
         new_attrs = dict(attrs)
         new_attrs['data_layout'] = 'NCHW16c'
@@ -785,7 +786,7 @@ def test_alter_layout_pad():
 
 def test_alter_layout_pool():
     """ Check NCHW, NHWC pool layout conversion"""
-    def alter_conv2d(attrs, inputs, tinfos):
+    def alter_conv2d(attrs, inputs, tinfos, out_type):
         data, weight = inputs
         new_attrs = dict(attrs)
         new_attrs['data_layout'] = 'NCHW16c'
@@ -862,7 +863,7 @@ def test_alter_layout_pool():
 
 def test_alter_layout_sum():
     """ Check NCHW, NHWC sum layout conversion"""
-    def alter_conv2d(attrs, inputs, tinfos):
+    def alter_conv2d(attrs, inputs, tinfos, out_type):
         data, weight = inputs
         new_attrs = dict(attrs)
         new_attrs['data_layout'] = 'NCHW16c'
@@ -938,11 +939,15 @@ def test_alter_layout_sum():
     assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
 
 
+# TODO(@anijain2305, @icemelon9): We should fix this. This doesn't seem to be the
+#   right behavior of alter_layout
+@pytest.mark.skip
 def test_alter_layout_nhwc_nchw_arm():
     """ Check NHWC to NHCW conversion for a small sequence of ops."""
-    def alter_conv2d(attrs, inputs, tinfos):
-        from topi.arm_cpu.conv2d import _alter_conv2d_layout_arm
-        return _alter_conv2d_layout_arm(attrs, inputs, tinfos, tvm.relay)
+    def alter_conv2d(attrs, inputs, tinfos, out_type):
+        import topi
+        with tvm.target.create("llvm -device=arm_cpu"):
+            return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, out_type)
 
     # Check NHWC conversion.
     def before_nhwc():
@@ -1011,7 +1016,7 @@ def test_alter_op_with_global_var():
         mod["main"] = relay.Function([x, weight], foo(x, weight))
         return mod
 
-    def alter_conv2d(attrs, inputs, tinfos):
+    def alter_conv2d(attrs, inputs, tinfos, out_type):
         data, weight = inputs
         weight = relay.multiply(weight, relay.const(2.0, "float32"))
         return relay.nn.conv2d(data, weight, **attrs)
@@ -1054,5 +1059,5 @@ if __name__ == "__main__":
     test_alter_layout_pad()
     test_alter_layout_pool()
     test_alter_layout_sum()
-    test_alter_layout_nhwc_nchw_arm()
+    test_alter_layout_nhwc_nchw_arm()
     test_alter_op_with_global_var()
index 443d2e4..02438ef 100644 (file)
@@ -15,6 +15,8 @@
 # specific language governing permissions and limitations
 # under the License.
 import numpy as np
+import pytest
+
 import tvm
 from tvm import relay
 from tvm.relay import testing
@@ -55,7 +57,8 @@ def get_calibration_dataset(input_name):
     return dataset
 
 
-def test_calibrate_target(create_target=False):
+@pytest.mark.parametrize("create_target", [True, False])
+def test_calibrate_target(create_target):
     mod, params = testing.resnet.get_workload(num_layers=18)
     dataset = get_calibration_dataset("data")
     with relay.quantize.qconfig(calibrate_mode="kl_divergence"):
index 1399573..bfc3cab 100644 (file)
@@ -131,12 +131,13 @@ def test_fold_fwd_dual_path():
         z = relay.add(y1, y2)
         return relay.Function(args, z)
 
-    def check(shape, channels):
-        x =  relay.var("x", shape=shape)
-        in_channels = shape[-1]
+    def check(dshape, channels):
+        x =  relay.var("x", shape=dshape)
+        in_channels = dshape[-1]
         # test depthwise
         assert in_channels == channels
-        weight = relay.var("weight")
+        wshape = (3, 3, 1, channels) # HWIO
+        weight = relay.var("weight", shape=wshape)
         in_bias = relay.var("in_bias", shape=(in_channels,))
         in_scale = relay.const(_get_positive_scale(in_channels,))
         y1 = before(x, weight, in_bias, in_scale, channels)
index fac9f06..83bbd54 100644 (file)
@@ -36,7 +36,7 @@ class DummyRunner(Runner):
     def get_build_kwargs(self):
         return {}
 
-@autotvm.template
+@autotvm.register_customized_task("testing/matmul")
 def matmul(N, L, M, dtype):
     A = tvm.placeholder((N, L), name='A', dtype=dtype)
     B = tvm.placeholder((L, M), name='B', dtype=dtype)
@@ -63,7 +63,7 @@ def matmul(N, L, M, dtype):
 
     return s, [A, B, C]
 
-@autotvm.template
+@autotvm.register_customized_task("testing/bad_matmul")
 def bad_matmul(N, L, M, dtype):
     if 'bad_device' in tvm.target.Target.current().keys:
         A = tvm.placeholder((N, L), name='A', dtype=dtype)
@@ -85,7 +85,7 @@ def bad_matmul(N, L, M, dtype):
 def get_sample_task(n=128):
     """return a sample task for testing"""
     target = tvm.target.create("llvm")
-    task = autotvm.task.create(matmul, args=(n, n, n, 'float32'), target=target)
+    task = autotvm.task.create("testing/matmul", args=(n, n, n, 'float32'), target=target)
     return task, target
 
 def get_sample_records(n):
index 716ab7f..5a55c4f 100644 (file)
 The dispatcher can choose which template to use according
 to the parameters of workload"""
 
-from collections import namedtuple
 from tvm import autotvm
-from tvm.autotvm.task import dispatcher, DispatchContext
-
-SimpleConfig = namedtuple('SimpleConfig', ('template_key', 'is_fallback'))
-
-def test_dispatch():
-    @dispatcher
-    def my_dispatcher(a, b):
-        return (a, b)
-
-    @my_dispatcher.register("im2col")
-    def _im2col(cfg, a, b):
-        return a
-
-    @my_dispatcher.register("spatial_pack")
-    def _spatial_pack(cfg, a, b):
-        return b
-
-    class SimpleDispatcher(DispatchContext):
-        def query(self, target, workload):
-            a, b = workload
-            tkey = "spatial_pack" if a + b > 2 else "im2col"
-            cfg = SimpleConfig(tkey, False)
-            return cfg
-
-    with SimpleDispatcher():
-        # this will call im2col
-        assert my_dispatcher(1, 0) == 1
-
-        # this will call spatial pack
-        assert my_dispatcher(1, 100) == 100
 
 def test_fallback():
 
-    @autotvm.template
+    @autotvm.register_customized_task("testing/dispatch/fallback")
     def simple_template(a, b):
         cfg = autotvm.get_config()
         assert cfg.is_fallback
@@ -62,5 +31,4 @@ def test_fallback():
 
 
 if __name__ == "__main__":
-    test_dispatch()
     test_fallback()
index 48a1d31..0899f6f 100644 (file)
@@ -64,7 +64,7 @@ def test_check_correctness():
     # a bad template
     n = 128
     target = tvm.target.create("llvm -device=bad_device")
-    task = autotvm.task.create(bad_matmul, args=(n, n, n, 'float32'), target=target)
+    task = autotvm.task.create("testing/bad_matmul", args=(n, n, n, 'float32'), target=target)
 
     def _callback_wrong(tuner, measure_inputs, measure_results):
         for _, res in zip(measure_inputs, measure_results):
index 2e0cee2..c14607d 100644 (file)
@@ -101,4 +101,4 @@ def test_system_lib():
 
 if __name__ == "__main__":
     test_resnet18()
-    test_system_lib()
+    #test_system_lib()
index ec36a5f..8652817 100644 (file)
@@ -305,7 +305,7 @@ def test_cuda_reduction():
         e = topi.elemwise_sum([c, d])
         g = topi.sum(e)
         with tvm.target.cuda():
-            sg = topi.generic.schedule_reduce(g)
+            sg = topi.cuda.schedule_reduce(g)
             ctx = tvm.gpu(0)
             func = tvm.build(sg, [a, b, g], 'cuda')
             a_np = np.random.uniform(size=(m, n)).astype(a.dtype)
index a8b22fd..27e077f 100644 (file)
@@ -31,7 +31,6 @@ from tvm import relay
 from tvm.autotvm.task import ConfigEntity
 from tvm.autotvm.measure import MeasureResult, MeasureInput
 from tvm.autotvm.graph_tuner import DPTuner, PBQPTuner
-from test_graph_tuner_utils import create_workload
 
 
 def _create_data(target, dshape, dtype, layout):
@@ -48,68 +47,53 @@ def _create_data(target, dshape, dtype, layout):
     tasks = autotvm.task.extract_from_program(mod["main"],
                                               target=target,
                                               params=params,
-                                              ops=(relay.op.nn.conv2d,))
-    wkl_list = [
-        create_workload((1, 3, 8, 8), (16, 3, 3, 3), (1, 1), (1, 1, 1, 1), (1, 1), layout, layout, dtype, dtype),
-        create_workload((1, 16, 8, 8), (32, 16, 1, 1), (1, 1), (0, 0, 0, 0), (1, 1), layout, layout, dtype, dtype),
-        create_workload((1, 32, 8, 8), (32, 32, 3, 3), (1, 1), (1, 1, 1, 1), (1, 1), layout, layout, dtype, dtype),
-    ]
+                                              ops=(relay.op.get("nn.conv2d"),))
     costs = [0.04, 0.012, 0.03]
     config_list = []
-    cfg_dict = {"i": -1,
-                "c": None,
-                "e": [["tile_ic", "sp", [3, 1]],
-                      ["tile_oc", "sp", [4, 4]],
-                      ["tile_ow", "sp", [4, 2]],
-                      ["unroll_kw", "ot", True]],
-                "t": ""}
+    cfg_dict = {"index": -1,
+                "code_hash": None,
+                "entity": [["tile_ic", "sp", [3, 1]],
+                           ["tile_oc", "sp", [4, 4]],
+                           ["tile_ow", "sp", [4, 2]],
+                           ["unroll_kw", "ot", True]]}
     config_list.append(ConfigEntity.from_json_dict(cfg_dict))
-    cfg_dict = {"i": -1,
-                "c": None,
-                "e": [["tile_ic", "sp", [2, 8]],
-                      ["tile_oc", "sp", [1, 32]],
-                      ["tile_oh", "ot", 1],
-                      ["tile_ow", "sp", [4, 2]]],
-                "t": ""}
+    cfg_dict = {"index": -1,
+                "code_hash": None,
+                "entity": [["tile_ic", "sp", [2, 8]],
+                           ["tile_oc", "sp", [1, 32]],
+                           ["tile_oh", "ot", 1],
+                           ["tile_ow", "sp", [4, 2]]]}
     config_list.append(ConfigEntity.from_json_dict(cfg_dict))
-    cfg_dict = {"i": -1,
-                "c": None,
-                "e": [["tile_ic", "sp", [8, 4]],
-                      ["tile_oc", "sp", [4, 8]],
-                      ["tile_ow", "sp", [2, 4]],
-                      ["unroll_kw", "ot", False]],
-                "t": ""}
+    cfg_dict = {"index": -1,
+                "code_hash": None,
+                "entity": [["tile_ic", "sp", [8, 4]],
+                           ["tile_oc", "sp", [4, 8]],
+                           ["tile_ow", "sp", [2, 4]],
+                           ["unroll_kw", "ot", False]]}
     config_list.append(ConfigEntity.from_json_dict(cfg_dict))
 
     records = []
-    for wkl, cost, config, task in zip(wkl_list, costs, config_list, tasks):
-        task.workload = wkl
+    for cost, config, task in zip(costs, config_list, tasks):
         ms_input = MeasureInput(target=target, task=task, config=config)
         ms_output = MeasureResult(costs=(cost,), error_no=0, all_cost=-1, timestamp=-1)
         records.append((ms_input, ms_output))
 
     ltf_records = []
     ltf_arg = [tvm.placeholder((1, 64, 16, 16, 8), dtype=dtype), "NCHW8c", "NCHW512c"]
-    ltf_arg = autotvm.task.topi_integration.serialize_args(ltf_arg)
-    ltf_wkl = ('layout_transform',) + autotvm.task.args_to_workload(ltf_arg)
-    ltf_task = copy.deepcopy(tasks[0])
-    ltf_task.workload = ltf_wkl
+    ltf_task = autotvm.task.create('layout_transform', ltf_arg, target)
     ms_input = MeasureInput(target=target, task=ltf_task, config=None)
     ms_output =  MeasureResult(costs=(1.91224744e-05,), error_no=0, all_cost=-1, timestamp=-1)
     ltf_records.append((ms_input, ms_output))
 
     ltf_keys = []
     ltf_arg = [tvm.placeholder((1, 4, 8, 8, 4), dtype=dtype), "NCHW4c", "NCHW8c"]
-    ltf_arg = autotvm.task.topi_integration.serialize_args(ltf_arg)
-    ltf_wkl = ('layout_transform',) + autotvm.task.args_to_workload(ltf_arg)
+    ltf_wkl = autotvm.task.args_to_workload(ltf_arg, 'layout_transform')
     ltf_keys.append(ltf_wkl)
     ltf_arg = [tvm.placeholder((1, 1, 8, 8, 32), dtype=dtype), "NCHW32c", "NCHW4c"]
-    ltf_arg = autotvm.task.topi_integration.serialize_args(ltf_arg)
-    ltf_wkl = ('layout_transform',) + autotvm.task.args_to_workload(ltf_arg)
+    ltf_wkl = autotvm.task.args_to_workload(ltf_arg, 'layout_transform')
     ltf_keys.append(ltf_wkl)
     ltf_arg = [tvm.placeholder((1, 4, 8, 8, 8), dtype=dtype), "NCHW8c", "NCHW32c"]
-    ltf_arg = autotvm.task.topi_integration.serialize_args(ltf_arg)
-    ltf_wkl = ('layout_transform',) + autotvm.task.args_to_workload(ltf_arg)
+    ltf_wkl = autotvm.task.args_to_workload(ltf_arg, 'layout_transform')
     ltf_keys.append(ltf_wkl)
 
     return net, records, ltf_records, ltf_keys, tasks
@@ -121,7 +105,8 @@ def test_graph_tuner_layout_transform():
     dshape = (1, 3, 8, 8)
     dtype = "float32"
     layout = "NCHW"
-    target_ops = [relay.nn.conv2d]
+    conv2d = relay.op.get("nn.conv2d")
+    target_ops = [conv2d]
 
     g, records, ltf_records, ltf_keys, _ = _create_data(target, dshape, dtype, layout)
     executor = DPTuner(g, {"data": dshape}, records, target_ops, target=target, log_file=log_file)
@@ -156,36 +141,34 @@ def test_DPTuner_run():
     dtype = "float32"
     layout = "NCHW"
     dshape = (1, 3, 8, 8)
-    target_ops = [relay.nn.conv2d]
+    conv2d = relay.op.get("nn.conv2d")
+    target_ops = [conv2d]
 
     g, records, ltf_records, ltf_keys, tasks = _create_data(target, dshape, dtype, layout)
     mod = tvm.IRModule()
     mod["main"] = g
     costs = [0.02, 0.02, 0.045]
     config_list = []
-    cfg_dict = {"i": -1,
-                "c": None,
-                "e": [["tile_ic", "sp", [1, 3]],
-                      ["tile_oc", "sp", [2, 8]],
-                      ["tile_ow", "sp", [4, 2]],
-                      ["unroll_kw", "ot", True]],
-                "t": ""}
+    cfg_dict = {"index": -1,
+                "code_hash": None,
+                "entity": [["tile_ic", "sp", [1, 3]],
+                           ["tile_oc", "sp", [2, 8]],
+                           ["tile_ow", "sp", [4, 2]],
+                           ["unroll_kw", "ot", True]]}
     config_list.append(ConfigEntity.from_json_dict(cfg_dict))
-    cfg_dict = {"i": -1,
-                "c": None,
-                "e": [["tile_ic", "sp", [4, 4]],
-                      ["tile_oc", "sp", [2, 16]],
-                      ["tile_oh", "ot", 1],
-                      ["tile_ow", "sp", [4, 2]]],
-                "t": ""}
+    cfg_dict = {"index": -1,
+                "code_hash": None,
+                "entity": [["tile_ic", "sp", [4, 4]],
+                           ["tile_oc", "sp", [2, 16]],
+                           ["tile_oh", "ot", 1],
+                           ["tile_ow", "sp", [4, 2]]]}
     config_list.append(ConfigEntity.from_json_dict(cfg_dict))
-    cfg_dict = {"i": -1,
-                "c": None,
-                "e": [["tile_ic", "sp", [16, 2]],
-                      ["tile_oc", "sp", [8, 4]],
-                      ["tile_ow", "sp", [2, 4]],
-                      ["unroll_kw", "ot", False]],
-                "t": ""}
+    cfg_dict = {"index": -1,
+                "code_hash": None,
+                "entity": [["tile_ic", "sp", [16, 2]],
+                           ["tile_oc", "sp", [8, 4]],
+                           ["tile_ow", "sp", [2, 4]],
+                           ["unroll_kw", "ot", False]]}
     config_list.append(ConfigEntity.from_json_dict(cfg_dict))
     for cost, config, task in zip(costs, config_list, tasks):
         ms_input = MeasureInput(target=target, task=task, config=config)
@@ -207,34 +190,32 @@ def test_PBQPTuner_run():
     dtype = "float32"
     layout = "NCHW"
     dshape = (1, 3, 8, 8)
-    target_ops = [relay.nn.conv2d]
+    conv2d = relay.op.get("nn.conv2d")
+    target_ops = [conv2d]
 
     g, records, ltf_records, ltf_keys, tasks = _create_data(target, dshape, dtype, layout)
     costs = [0.02, 0.02, 0.045]
     config_list = []
-    cfg_dict = {"i": -1,
-                "c": None,
-                "e": [["tile_ic", "sp", [1, 3]],
-                      ["tile_oc", "sp", [2, 8]],
-                      ["tile_ow", "sp", [4, 2]],
-                      ["unroll_kw", "ot", True]],
-                "t": ""}
+    cfg_dict = {"index": -1,
+                "code_hash": None,
+                "entity": [["tile_ic", "sp", [1, 3]],
+                           ["tile_oc", "sp", [2, 8]],
+                           ["tile_ow", "sp", [4, 2]],
+                           ["unroll_kw", "ot", True]]}
     config_list.append(ConfigEntity.from_json_dict(cfg_dict))
-    cfg_dict = {"i": -1,
-                "c": None,
-                "e": [["tile_ic", "sp", [4, 4]],
-                      ["tile_oc", "sp", [2, 16]],
-                      ["tile_oh", "ot", 1],
-                      ["tile_ow", "sp", [4, 2]]],
-                "t": ""}
+    cfg_dict = {"index": -1,
+                "code_hash": None,
+                "entity": [["tile_ic", "sp", [4, 4]],
+                           ["tile_oc", "sp", [2, 16]],
+                           ["tile_oh", "ot", 1],
+                           ["tile_ow", "sp", [4, 2]]]}
     config_list.append(ConfigEntity.from_json_dict(cfg_dict))
-    cfg_dict = {"i": -1,
-                "c": None,
-                "e": [["tile_ic", "sp", [16, 2]],
-                      ["tile_oc", "sp", [8, 4]],
-                      ["tile_ow", "sp", [2, 4]],
-                      ["unroll_kw", "ot", False]],
-                "t": ""}
+    cfg_dict = {"index": -1,
+                "code_hash": None,
+                "entity": [["tile_ic", "sp", [16, 2]],
+                           ["tile_oc", "sp", [8, 4]],
+                           ["tile_ow", "sp", [2, 4]],
+                           ["unroll_kw", "ot", False]]}
     config_list.append(ConfigEntity.from_json_dict(cfg_dict))
     for cost, config, task in zip(costs, config_list, tasks):
         ms_input = MeasureInput(target=target, task=task, config=config)
@@ -255,7 +236,8 @@ def test_many_sub_graphs():
     dtype = "float32"
     dshape = (1, 8, 8, 3)
     layout = "NCHW"
-    target_ops = [relay.nn.conv2d]
+    conv2d = relay.op.get("nn.conv2d")
+    target_ops = [conv2d]
 
     data = relay.var("data", shape=dshape, dtype=dtype)
     t0 = relay.transpose(data, (0, 3, 1, 2))
@@ -277,79 +259,63 @@ def test_many_sub_graphs():
     tasks = autotvm.task.extract_from_program(net["main"],
                                               target=target,
                                               params=params,
-                                              ops=(relay.op.nn.conv2d,))
-    wkl_list = [
-        create_workload((1, 3, 8, 8), (16, 3, 3, 3), (1, 1), (1, 1, 1, 1), (1, 1), layout, layout, dtype, dtype),
-        create_workload((1, 16, 8, 8), (32, 16, 1, 1), (1, 1), (0, 0, 0, 0), (1, 1), layout, layout, dtype, dtype),
-        create_workload((1, 32, 8, 8), (32, 32, 3, 3), (1, 1), (1, 1, 1, 1), (1, 1), layout, layout, dtype, dtype),
-    ]
+                                              ops=(conv2d,))
     costs = [0.04, 0.012, 0.03, 0.02, 0.02, 0.045]
     config_list = []
-    cfg_dict = {"i": -1,
-                "c": None,
-                "e": [["tile_ic", "sp", [3, 1]],
-                      ["tile_oc", "sp", [4, 4]],
-                      ["tile_ow", "sp", [4, 2]],
-                      ["unroll_kw", "ot", True]],
-                "t": ""}
+    cfg_dict = {"index": -1,
+                "code_hash": None,
+                "entity": [["tile_ic", "sp", [3, 1]],
+                           ["tile_oc", "sp", [4, 4]],
+                           ["tile_ow", "sp", [4, 2]],
+                           ["unroll_kw", "ot", True]]}
     config_list.append(ConfigEntity.from_json_dict(cfg_dict))
-    cfg_dict = {"i": -1,
-                "c": None,
-                "e": [["tile_ic", "sp", [2, 8]],
-                      ["tile_oc", "sp", [1, 32]],
-                      ["tile_oh", "ot", 1],
-                      ["tile_ow", "sp", [4, 2]]],
-                "t": ""}
+    cfg_dict = {"index": -1,
+                "code_hash": None,
+                "entity": [["tile_ic", "sp", [2, 8]],
+                           ["tile_oc", "sp", [1, 32]],
+                           ["tile_oh", "ot", 1],
+                           ["tile_ow", "sp", [4, 2]]]}
     config_list.append(ConfigEntity.from_json_dict(cfg_dict))
-    cfg_dict = {"i": -1,
-                "c": None,
-                "e": [["tile_ic", "sp", [8, 4]],
-                      ["tile_oc", "sp", [4, 8]],
-                      ["tile_ow", "sp", [2, 4]],
-                      ["unroll_kw", "ot", False]],
-                "t": ""}
+    cfg_dict = {"index": -1,
+                "code_hash": None,
+                "entity": [["tile_ic", "sp", [8, 4]],
+                           ["tile_oc", "sp", [4, 8]],
+                           ["tile_ow", "sp", [2, 4]],
+                           ["unroll_kw", "ot", False]]}
     config_list.append(ConfigEntity.from_json_dict(cfg_dict))
-    cfg_dict = {"i": -1,
-                "c": None,
-                "e": [["tile_ic", "sp", [1, 3]],
-                      ["tile_oc", "sp", [2, 8]],
-                      ["tile_ow", "sp", [4, 2]],
-                      ["unroll_kw", "ot", True]],
-                "t": ""}
+    cfg_dict = {"index": -1,
+                "code_hash": None,
+                "entity": [["tile_ic", "sp", [1, 3]],
+                           ["tile_oc", "sp", [2, 8]],
+                           ["tile_ow", "sp", [4, 2]],
+                           ["unroll_kw", "ot", True]]}
     config_list.append(ConfigEntity.from_json_dict(cfg_dict))
-    cfg_dict = {"i": -1,
-                "c": None,
-                "e": [["tile_ic", "sp", [4, 4]],
-                      ["tile_oc", "sp", [2, 16]],
-                      ["tile_oh", "ot", 1],
-                      ["tile_ow", "sp", [4, 2]]],
-                "t": ""}
+    cfg_dict = {"index": -1,
+                "code_hash": None,
+                "entity": [["tile_ic", "sp", [4, 4]],
+                           ["tile_oc", "sp", [2, 16]],
+                           ["tile_oh", "ot", 1],
+                           ["tile_ow", "sp", [4, 2]]]}
     config_list.append(ConfigEntity.from_json_dict(cfg_dict))
-    cfg_dict = {"i": -1,
-                "c": None,
-                "e": [["tile_ic", "sp", [16, 2]],
-                      ["tile_oc", "sp", [8, 4]],
-                      ["tile_ow", "sp", [2, 4]],
-                      ["unroll_kw", "ot", False]],
-                "t": ""}
+    cfg_dict = {"index": -1,
+                "code_hash": None,
+                "entity": [["tile_ic", "sp", [16, 2]],
+                           ["tile_oc", "sp", [8, 4]],
+                           ["tile_ow", "sp", [2, 4]],
+                           ["unroll_kw", "ot", False]]}
     config_list.append(ConfigEntity.from_json_dict(cfg_dict))
 
     records = []
 
-    wkl_list = wkl_list + wkl_list
     tasks = tasks + tasks
-    for wkl, cost, config, task in zip(wkl_list, costs, config_list, tasks):
-        task.workload = wkl
+    for cost, config, task in zip(costs, config_list, tasks):
         ms_input = MeasureInput(target=target, task=task, config=config)
         ms_output = MeasureResult(costs=(cost,), error_no=0, all_cost=-1, timestamp=-1)
         records.append((ms_input, ms_output))
 
     ltf_records = []
     ltf_arg = [tvm.placeholder((1, 64, 16, 16, 8), dtype=dtype), "NCHW8c", "NCHW512c"]
-    ltf_arg = autotvm.task.topi_integration.serialize_args(ltf_arg)
-    ltf_wkl = ('layout_transform',) + autotvm.task.args_to_workload(ltf_arg)
-    ltf_task = copy.deepcopy(tasks[0])
-    ltf_task.workload = ltf_wkl
+    ltf_task = autotvm.task.create('layout_transform', ltf_arg, target)
     ms_input = MeasureInput(target=target, task=ltf_task, config=None)
     ms_output =  MeasureResult(costs=(1.91224744e-05,), error_no=0, all_cost=-1, timestamp=-1)
     ltf_records.append((ms_input, ms_output))
@@ -376,7 +342,8 @@ def test_tuple():
     dtype = "float32"
     dshape = (1, 5, 32, 32)
     layout = "NCHW"
-    target_ops = [relay.nn.conv2d]
+    conv2d = relay.op.get("nn.conv2d")
+    target_ops = [conv2d]
 
     data = relay.var("data", shape=dshape, dtype=dtype)
     w0 = relay.var("w0_weight")
@@ -390,62 +357,48 @@ def test_tuple():
     tasks = autotvm.task.extract_from_program(net["main"],
                                               target=target,
                                               params=params,
-                                              ops=(relay.op.nn.conv2d,))
-    wkl_list = [
-        create_workload((1, 5, 32, 32), (2, 5, 3, 3), (1, 1), (1, 1, 1, 1), (1, 1), layout, layout, dtype, dtype),
-        create_workload((1, 5, 32, 32), (3, 5, 3, 3), (1, 1), (1, 1, 1, 1), (1, 1), layout, layout, dtype, dtype),
-    ]
+                                              ops=(conv2d,))
     costs = [0.01, 0.012, 0.03, 0.04]
     config_list = []
-    cfg_dict = {"i": -1,
-                "c": None,
-                "e": [["tile_ic", "sp", [1, 5]],
-                      ["tile_oc", "sp", [1, 2]],
-                      ["tile_ow", "sp", [4, 8]],
-                      ["unroll_kw", "ot", True]],
-                "t": ""}
+    cfg_dict = {"index": -1,
+                "code_hash": None,
+                "entity": [["tile_ic", "sp", [1, 5]],
+                           ["tile_oc", "sp", [1, 2]],
+                           ["tile_ow", "sp", [4, 8]],
+                           ["unroll_kw", "ot", True]]}
     config_list.append(ConfigEntity.from_json_dict(cfg_dict))
-    cfg_dict = {"i": -1,
-                "c": None,
-                "e": [["tile_ic", "sp", [1, 5]],
-                      ["tile_oc", "sp", [1, 3]],
-                      ["tile_ow", "sp", [2, 16]],
-                      ["unroll_kw", "ot", False]],
-                "t": ""}
+    cfg_dict = {"index": -1,
+                "code_hash": None,
+                "entity": [["tile_ic", "sp", [1, 5]],
+                           ["tile_oc", "sp", [1, 3]],
+                           ["tile_ow", "sp", [2, 16]],
+                           ["unroll_kw", "ot", False]]}
     config_list.append(ConfigEntity.from_json_dict(cfg_dict))
-    cfg_dict = {"i": -1,
-                "c": None,
-                "e": [["tile_ic", "sp", [1, 5]],
-                      ["tile_oc", "sp", [2, 1]],
-                      ["tile_ow", "sp", [4, 8]],
-                      ["unroll_kw", "ot", True]],
-                "t": ""}
+    cfg_dict = {"index": -1,
+                "code_hash": None,
+                "entity": [["tile_ic", "sp", [1, 5]],
+                           ["tile_oc", "sp", [2, 1]],
+                           ["tile_ow", "sp", [4, 8]],
+                           ["unroll_kw", "ot", True]]}
     config_list.append(ConfigEntity.from_json_dict(cfg_dict))
-    cfg_dict = {"i": -1,
-                "c": None,
-                "e": [["tile_ic", "sp", [1, 5]],
-                      ["tile_oc", "sp", [3, 1]],
-                      ["tile_ow", "sp", [2, 16]],
-                      ["unroll_kw", "ot", False]],
-                "t": ""}
+    cfg_dict = {"index": -1,
+                "code_hash": None,
+                "entity": [["tile_ic", "sp", [1, 5]],
+                           ["tile_oc", "sp", [3, 1]],
+                           ["tile_ow", "sp", [2, 16]],
+                           ["unroll_kw", "ot", False]]}
     config_list.append(ConfigEntity.from_json_dict(cfg_dict))
 
     records = []
-
-    wkl_list = wkl_list + wkl_list
     tasks = tasks + tasks
-    for wkl, cost, config, task in zip(wkl_list, costs, config_list, tasks):
-        task.workload = wkl
+    for cost, config, task in zip(costs, config_list, tasks):
         ms_input = MeasureInput(target=target, task=task, config=config)
         ms_output = MeasureResult(costs=(cost,), error_no=0, all_cost=-1, timestamp=-1)
         records.append((ms_input, ms_output))
 
     ltf_records = []
     ltf_arg = [tvm.placeholder((1, 64, 16, 16, 8), dtype=dtype), "NCHW8c", "NCHW512c"]
-    ltf_arg = autotvm.task.topi_integration.serialize_args(ltf_arg)
-    ltf_wkl = ('layout_transform',) + autotvm.task.args_to_workload(ltf_arg)
-    ltf_task = copy.deepcopy(tasks[0])
-    ltf_task.workload = ltf_wkl
+    ltf_task = autotvm.task.create('layout_transform', ltf_arg, target)
     ms_input = MeasureInput(target=target, task=ltf_task, config=None)
     ms_output =  MeasureResult(costs=(1.91224744e-05,), error_no=0, all_cost=-1, timestamp=-1)
     ltf_records.append((ms_input, ms_output))
@@ -472,7 +425,8 @@ def test_triangle_block():
     dtype = "float32"
     dshape = (1, 3, 8, 8)
     layout = "NCHW"
-    target_ops = [relay.nn.conv2d]
+    conv2d = relay.op.get("nn.conv2d")
+    target_ops = [conv2d]
 
     data = relay.var("data", shape=dshape, dtype=dtype)
     w0 = relay.var("w0_weight")
@@ -488,79 +442,63 @@ def test_triangle_block():
     tasks = autotvm.task.extract_from_program(net["main"],
                                               target=target,
                                               params=params,
-                                              ops=(relay.op.nn.conv2d,))
-    wkl_list = [
-        create_workload((1, 3, 8, 8), (16, 3, 3, 3), (1, 1), (1, 1, 1, 1), (1, 1), layout, layout, dtype, dtype),
-        create_workload((1, 16, 8, 8), (32, 16, 1, 1), (1, 1), (0, 0, 0, 0), (1, 1), layout, layout, dtype, dtype),
-        create_workload((1, 3, 8, 8), (32, 3, 3, 3), (1, 1), (1, 1, 1, 1), (1, 1), layout, layout, dtype, dtype),
-    ]
+                                              ops=(conv2d,))
     costs = [0.04, 0.012, 0.03, 0.02, 0.02, 0.045]
     config_list = []
-    cfg_dict = {"i": -1,
-                "c": None,
-                "e": [["tile_ic", "sp", [3, 1]],
-                      ["tile_oc", "sp", [4, 4]],
-                      ["tile_ow", "sp", [4, 2]],
-                      ["unroll_kw", "ot", True]],
-                "t": ""}
+    cfg_dict = {"index": -1,
+                "code_hash": None,
+                "entity": [["tile_ic", "sp", [3, 1]],
+                           ["tile_oc", "sp", [4, 4]],
+                           ["tile_ow", "sp", [4, 2]],
+                           ["unroll_kw", "ot", True]]}
     config_list.append(ConfigEntity.from_json_dict(cfg_dict))
-    cfg_dict = {"i": -1,
-                "c": None,
-                "e": [["tile_ic", "sp", [2, 8]],
-                      ["tile_oc", "sp", [1, 32]],
-                      ["tile_oh", "ot", 1],
-                      ["tile_ow", "sp", [4, 2]]],
-                "t": ""}
+    cfg_dict = {"index": -1,
+                "code_hash": None,
+                "entity": [["tile_ic", "sp", [2, 8]],
+                           ["tile_oc", "sp", [1, 32]],
+                           ["tile_oh", "ot", 1],
+                           ["tile_ow", "sp", [4, 2]]]}
     config_list.append(ConfigEntity.from_json_dict(cfg_dict))
-    cfg_dict = {"i": -1,
-                "c": None,
-                "e": [["tile_ic", "sp", [8, 4]],
-                      ["tile_oc", "sp", [4, 8]],
-                      ["tile_ow", "sp", [2, 4]],
-                      ["unroll_kw", "ot", False]],
-                "t": ""}
+    cfg_dict = {"index": -1,
+                "code_hash": None,
+                "entity": [["tile_ic", "sp", [8, 4]],
+                           ["tile_oc", "sp", [4, 8]],
+                           ["tile_ow", "sp", [2, 4]],
+                           ["unroll_kw", "ot", False]]}
     config_list.append(ConfigEntity.from_json_dict(cfg_dict))
-    cfg_dict = {"i": -1,
-                "c": None,
-                "e": [["tile_ic", "sp", [1, 3]],
-                      ["tile_oc", "sp", [2, 8]],
-                      ["tile_ow", "sp", [4, 2]],
-                      ["unroll_kw", "ot", True]],
-                "t": ""}
+    cfg_dict = {"index": -1,
+                "code_hash": None,
+                "entity": [["tile_ic", "sp", [1, 3]],
+                           ["tile_oc", "sp", [2, 8]],
+                           ["tile_ow", "sp", [4, 2]],
+                           ["unroll_kw", "ot", True]]}
     config_list.append(ConfigEntity.from_json_dict(cfg_dict))
-    cfg_dict = {"i": -1,
-                "c": None,
-                "e": [["tile_ic", "sp", [4, 4]],
-                      ["tile_oc", "sp", [2, 16]],
-                      ["tile_oh", "ot", 1],
-                      ["tile_ow", "sp", [4, 2]]],
-                "t": ""}
+    cfg_dict = {"index": -1,
+                "code_hash": None,
+                "entity": [["tile_ic", "sp", [4, 4]],
+                           ["tile_oc", "sp", [2, 16]],
+                           ["tile_oh", "ot", 1],
+                           ["tile_ow", "sp", [4, 2]]]}
     config_list.append(ConfigEntity.from_json_dict(cfg_dict))
-    cfg_dict = {"i": -1,
-                "c": None,
-                "e": [["tile_ic", "sp", [16, 2]],
-                      ["tile_oc", "sp", [8, 4]],
-                      ["tile_ow", "sp", [2, 4]],
-                      ["unroll_kw", "ot", False]],
-                "t": ""}
+    cfg_dict = {"index": -1,
+                "code_hash": None,
+                "entity": [["tile_ic", "sp", [16, 2]],
+                           ["tile_oc", "sp", [8, 4]],
+                           ["tile_ow", "sp", [2, 4]],
+                           ["unroll_kw", "ot", False]]}
     config_list.append(ConfigEntity.from_json_dict(cfg_dict))
 
     records = []
 
-    wkl_list = wkl_list + wkl_list
     tasks = tasks + tasks
-    for wkl, cost, config, task in zip(wkl_list, costs, config_list, tasks):
-        task.workload = wkl
+    for cost, config, task in zip(costs, config_list, tasks):
         ms_input = MeasureInput(target=target, task=task, config=config)
         ms_output = MeasureResult(costs=(cost,), error_no=0, all_cost=-1, timestamp=-1)
         records.append((ms_input, ms_output))
 
     ltf_records = []
     ltf_arg = [tvm.placeholder((1, 64, 16, 16, 8), dtype=dtype), "NCHW8c", "NCHW512c"]
-    ltf_arg = autotvm.task.topi_integration.serialize_args(ltf_arg)
-    ltf_wkl = ('layout_transform',) + autotvm.task.args_to_workload(ltf_arg)
-    ltf_task = copy.deepcopy(tasks[0])
-    ltf_task.workload = ltf_wkl
+    ltf_task = autotvm.task.create('layout_transform', ltf_arg, target)
     ms_input = MeasureInput(target=target, task=ltf_task, config=None)
     ms_output =  MeasureResult(costs=(1.91224744e-05,), error_no=0, all_cost=-1, timestamp=-1)
     ltf_records.append((ms_input, ms_output))
index 397ea23..112c5b8 100644 (file)
@@ -26,17 +26,7 @@ from tvm import autotvm, relay
 from tvm.relay.testing import resnet
 from tvm.autotvm.graph_tuner.utils import has_multiple_inputs, get_direct_ancestor, get_in_nodes, \
     get_out_nodes, expr2graph, bind_inputs
-from tvm.relay.expr import Call, TupleGetItem, Tuple
-from topi.nn.conv2d import conv2d
-
-
-def create_workload(dshape, kshape, strides,
-                    padding, dilation, layout,
-                    out_layout, dtype, out_dtype):
-    data = tvm.placeholder(dshape, dtype=dtype)
-    kernel = tvm.placeholder(kshape, dtype=dtype)
-    return autotvm.task.args_to_workload([data, kernel, strides, padding, dilation, layout,
-                                          out_dtype], conv2d)
+from tvm.relay.expr import Call, TupleGetItem, Tuple, Var
 
 
 def verify_has_multiple_inputs(node_list, node_idx, input_names, expected_result):
@@ -53,7 +43,7 @@ def test_has_multiple_inputs():
     out = relay.add(out1, out2)
     net = relay.Function(relay.analysis.free_vars(out), out)
     net = bind_inputs(net, {"data": (1, 16, 224, 224), "w0": (16, 16, 1, 1)})
-    target_ops = ["conv2d"]
+    target_ops = [relay.op.get("nn.conv2d")]
     node_list = []
     node_dict = {}
     expr2graph(net, target_ops, node_dict, node_list)
@@ -67,22 +57,17 @@ def test_expr2graph():
     mod, _ = resnet.get_workload(num_layers=50, batch_size=1)
     node_dict = {}
     node_list = []
-    target_ops = ["conv2d"]
+    target_ops = [relay.op.get("nn.conv2d")]
     op_name_list = []
     def _count_node(node):
-        if not isinstance(node, relay.op.op.Op,):
-            return
         if isinstance(node, Call):
-            op_name_list.append(node.op.name.split(".")[-1])
-        elif isinstance(node, TupleGetItem):
-            op_name_list.append("TupleGetItem")
-        elif isinstance(node, Tuple):
-            op_name_list.append("Tuple")
-        else:
-            op_name_list.append("null")
+            op_name_list.append(node.op)
+        elif isinstance(node, (Var, TupleGetItem, Tuple)):
+            op_name_list.append(None)
     relay.analysis.post_order_visit(mod["main"], _count_node)
 
     expr2graph(mod["main"], target_ops, node_dict, node_list)
+    assert len(node_list) == len(op_name_list)
     for i, item in enumerate(zip(op_name_list, node_list)):
         op_name, node = item
         assert op_name == node["op"], "%dth Node operator mismatch: expecting %s but got %s" \
@@ -99,7 +84,7 @@ def test_get_direct_ancestor():
     out = relay.nn.conv2d(out3, w1)
     net = relay.Function(relay.analysis.free_vars(out), out)
     net = bind_inputs(net, {"data": (1, 16, 224, 224), "w0": (16, 16, 1, 1), "w1": (16, 16, 1, 1)})
-    target_ops = ["conv2d"]
+    target_ops = [relay.op.get("nn.conv2d")]
     node_list = []
     node_dict = {}
     expr2graph(net, target_ops, node_dict, node_list)
@@ -119,7 +104,7 @@ def test_get_in_nodes():
     out = relay.nn.conv2d(out3, w1)
     net = relay.Function(relay.analysis.free_vars(out), out)
     net = bind_inputs(net, {"data": (1, 16, 224, 224), "w0": (16, 16, 1, 1), "w1": (16, 16, 1, 1)})
-    target_ops = ["conv2d"]
+    target_ops = [relay.op.get("nn.conv2d")]
     input_names = ["data"]
     node_list = []
     node_dict = {}
index 98fdeaa..01c0d26 100644 (file)
@@ -108,7 +108,7 @@ def verify_tensor_scalar_bop(shape, typ="add"):
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
-            s = topi.generic.schedule_elemwise(B)
+            s = topi.testing.get_elemwise_schedule(device)(B)
 
         k_ = 2
         foo = tvm.build(s, [A, B, k] + sh, device, name="tensor_scalar_" + typ)
@@ -154,7 +154,7 @@ def verify_broadcast_bop(lhs_shape, rhs_shape, typ="add"):
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
-            s = topi.generic.schedule_broadcast(C)
+            s = topi.testing.get_broadcast_schedule(device)(C)
 
         foo = tvm.build(s, [A, B, C], device, name="broadcast_binary" + "_" + typ)
         lhs_npy = np.random.uniform(size=lhs_shape).astype(A.dtype)
@@ -190,12 +190,14 @@ def verify_conv2d_scalar_bop(batch, in_size, in_channel, num_filter, kernel, str
             return
         print("Running on target: %s" % device)
 
+        conv2d_nchw, schedule_conv2d_nchw = topi.testing.get_conv2d_nchw_implement(device)
+
         k = 10.0
         dilation = (1, 1)
         with tvm.target.create(device):
             A = tvm.placeholder((batch, in_channel, in_size, in_size), name='A')
             W = tvm.placeholder((num_filter, in_channel, kernel, kernel), name='W')
-            B = topi.nn.conv2d(A, W, stride, padding, dilation)
+            B = conv2d_nchw(A, W, stride, padding, dilation, A.dtype)
             if typ == "add":
                 C = B + k
             elif typ == "sub":
@@ -206,7 +208,7 @@ def verify_conv2d_scalar_bop(batch, in_size, in_channel, num_filter, kernel, str
                 C = B / k
             else:
                 raise NotImplementedError()
-            s = topi.generic.schedule_conv2d_nchw([C])
+            s = schedule_conv2d_nchw([C])
 
         foo = tvm.build(s, [A, W, B, C], device, name="conv2d_scalar_" + typ)
 
index 1b42308..bfc209d 100644 (file)
@@ -35,13 +35,10 @@ using namespace tvm::te;
 namespace cuda {
 /*!
 * \brief Create a CUDA schedule for LRN
-*
-* \param target The target to generate a schedule for.
 * \param outs The output tensors.
-*
 * \return A schedule for the given ops.
 */
-inline Schedule schedule_lrn(const Target &target, const Array<Tensor>& outs) {
+inline Schedule schedule_lrn(const Array<Tensor>& outs) {
   Array<Operation> out_ops;
   for (auto t : outs) {
     out_ops.push_back(t->op);
index 692370d..303f4a8 100644 (file)
@@ -34,14 +34,11 @@ using namespace tvm::te;
 namespace rocm {
 /*!
 * \brief Create a rocm schedule for LRN
-*
-* \param target The target to generate a schedule for.
 * \param outs The output tensors.
-*
 * \return A schedule for the given ops.
 */
-inline Schedule schedule_lrn(const Target &target, const Array<Tensor>& outs) {
-  return topi::cuda::schedule_lrn(target, outs);
+inline Schedule schedule_lrn(const Array<Tensor>& outs) {
+  return topi::cuda::schedule_lrn(outs);
 }
 
 }  // namespace rocm
index a0c6ab0..f1019e6 100644 (file)
@@ -40,6 +40,7 @@ from .transform import *
 from .broadcast import *
 from .sort import *
 from .argwhere import *
+from . import generic
 from . import nn
 from . import x86
 from . import cuda
index 32f4e87..c2a9ade 100644 (file)
@@ -16,7 +16,6 @@
 # under the License.
 # pylint: disable=invalid-name, too-many-arguments, too-many-nested-blocks
 """Argwhere operator"""
-import tvm
 from tvm import hybrid
 
 @hybrid.script
@@ -164,7 +163,6 @@ def hybrid_argwhere_5d(output_shape, condition):
                             valid_index += 1
     return a
 
-@tvm.target.generic_func
 def argwhere(output_shape, condition):
     """Find the indices of elements of a tensor that are non-zero.
 
index 517941c..eb05dd8 100644 (file)
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-
+# pylint: disable=wildcard-import
 """Schedule for ARM CPU"""
 
-from . import conv2d
-from . import depthwise_conv2d
-from . import conv2d_transpose
-from . import conv2d_int8
-from . import bitserial_conv2d
-from . import bitserial_dense
-from . import injective
+from .conv2d import *
+from .depthwise_conv2d import *
+from .conv2d_transpose import *
+from .conv2d_int8 import *
+from . import conv2d_alter_op
+from .bitserial_conv2d import *
+from .bitserial_dense import *
+from .injective import *
index 4de2b14..d28ec09 100644 (file)
@@ -22,11 +22,10 @@ from tvm import autotvm
 from tvm import relay
 from .. import tag
 from ..nn.pad import pad
-from ..nn.bitserial_conv2d import bitserial_conv2d_nhwc, bitserial_conv2d_legalize
+from ..nn.bitserial_conv2d import bitserial_conv2d_legalize
 from ..nn.bitserial_util import bitpack, binary_op_multiplier
 from ..nn.util import get_pad_tuple
 from ..util import get_const_int, get_const_tuple
-from .. import generic
 
 def _kernel_vec_spatial_pack_nhwc(kernel, kernel_bits, VC, use_bitpack=True):
     if use_bitpack:
@@ -38,9 +37,9 @@ def _kernel_vec_spatial_pack_nhwc(kernel, kernel_bits, VC, use_bitpack=True):
     return tvm.compute(kvshape, lambda co, dh, dw, b, vc, ci: \
         kernel_q[dh][dw][b][ci][co*VC+vc], name='kernel_vec')
 
-@autotvm.register_topi_compute(bitserial_conv2d_nhwc, 'arm_cpu', 'direct')
-def spatial_pack_nhwc(cfg, data, kernel, stride, padding, activation_bits, weight_bits,
-                      pack_dtype, out_dtype, unipolar):
+@autotvm.register_topi_compute("bitserial_conv2d_nhwc.arm_cpu")
+def bitserial_conv2d_nhwc(cfg, data, kernel, stride, padding, activation_bits, weight_bits,
+                          pack_dtype, out_dtype, unipolar):
     """ Compute convolution with pack on spatial axes. """
     assert data.shape[0].value == 1, "spatial pack convolution only support batch size=1"
     assert pack_dtype == 'uint8', "only support packing into uint8 bits"
@@ -323,7 +322,7 @@ def _schedule_spatial_conv2d_nhwc(cfg, s, data_pad, data_vec, kernel_vec,
     s[last].parallel(oh)
     return s
 
-@autotvm.register_topi_schedule(generic.nn.schedule_bitserial_conv2d_nhwc, 'arm_cpu', 'direct')
+@autotvm.register_topi_schedule("bitserial_conv2d_nhwc.arm_cpu")
 def schedule_bitserial_conv2d_nhwc(cfg, outs):
     """Arm cpu schedule for bitserial conv2d"""
     s = tvm.create_schedule([x.op for x in outs])
index 8bd6c5d..3f1889c 100644 (file)
@@ -21,15 +21,13 @@ import tvm
 from tvm import autotvm
 from topi.util import get_const_tuple
 from .. import tag
-from .. import generic
 from .bitserial_conv2d import _intrin_popcount
 from ..nn.pad import pad
-from ..nn.bitserial_dense import bitserial_dense
 from ..nn.bitserial_util import bitpack, binary_op_multiplier
 
-@autotvm.register_topi_compute(bitserial_dense, ['arm_cpu'], 'direct')
-def bitserial_dense_generic(cfg, data, weight, data_bits, weight_bits, pack_dtype, out_dtype,
-                            unipolar):
+@autotvm.register_topi_compute('bitserial_dense.arm_cpu')
+def bitserial_dense(cfg, data, weight, data_bits, weight_bits, pack_dtype, out_dtype,
+                    unipolar):
     """The default implementation of bitserial dense in topi.
 
     Parameters
@@ -111,7 +109,7 @@ def bitserial_dense_generic(cfg, data, weight, data_bits, weight_bits, pack_dtyp
     return matmul
 
 
-@autotvm.register_topi_schedule(generic.nn.schedule_bitserial_dense, ['arm_cpu'], 'direct')
+@autotvm.register_topi_schedule('bitserial_dense.arm_cpu')
 def schedule_bitserial_dense(cfg, outs):
     """Schedule for binary_dense.
 
index f0d650a..2144d26 100644 (file)
 """Conv2D schedule for ARM CPU"""
 from __future__ import absolute_import as _abs
 
-import logging
-
 import tvm
 from tvm import autotvm
 import tvm.contrib.nnpack
 
-from ..generic import schedule_conv2d_nchw, schedule_conv2d_nhwc, \
-                      schedule_conv2d_winograd_without_weight_transform, \
-                      schedule_conv2d_winograd_nnpack_without_weight_transform
 from ..util import traverse_inline, get_const_tuple
-from ..nn import dilate, pad, conv2d, conv2d_alter_layout, \
-                 conv2d_winograd_without_weight_transform, \
-                 conv2d_winograd_nnpack_without_weight_transform, \
-                 depthwise_conv2d_nchw
+from .. import nn
 from ..nn.util import get_const_int, get_pad_tuple
 from ..nn.winograd_util import winograd_transform_matrices
 from .conv2d_spatial_pack import conv2d_spatial_pack_nchw, \
@@ -39,75 +31,17 @@ from .conv2d_spatial_pack import conv2d_spatial_pack_nchw, \
                                  schedule_conv2d_spatial_pack_nchw, \
                                  schedule_conv2d_spatial_pack_nhwc
 
-logger = logging.getLogger('topi')
-
-@autotvm.register_topi_compute(conv2d, 'arm_cpu', ['direct'])
-def conv2d_arm_cpu(cfg, data, kernel, strides, padding, dilation, layout, out_dtype):
-    """TOPI compute callback for conv2d
-
-    Parameters
-    ----------
-    cfg: ConfigEntity
-        The config for this template
-
-    data : tvm.Tensor
-        4-D with shape [batch, in_channel, in_height, in_width]
-
-    kernel : tvm.Tensor
-        4-D with shape [num_filter, in_channel, filter_height, filter_width] or
-        pre-packed 5-D with shape [num_filter_chunk, in_channel, filter_height,
-        filter_width, num_filter_block]
 
-    strides : list of two ints
-        [stride_height, stride_width]
+@autotvm.register_topi_compute("conv2d_nchw_spatial_pack.arm_cpu")
+def conv2d_nchw_spatial_pack(cfg, data, kernel, strides, padding, dilation, out_dtype):
+    """Compute conv2d with NCHW layout"""
+    return conv2d_spatial_pack_nchw(cfg, data, kernel, strides, padding,
+                                    dilation, out_dtype, num_tile=2)
 
-    padding : list of two ints
-        [pad_height, pad_width]
-
-    dilation : list of two ints
-        [dilation_height, dilation_width]
-
-    layout : str
-        layout of data
-
-    out_dtype: str
-        The output type. This is used for mixed precision.
-
-    Returns
-    -------
-    output : tvm.Tensor
-        4-D with shape [batch, out_channel, out_height, out_width]
-    """
-    if layout == 'NCHW':
-        return conv2d_spatial_pack_nchw(cfg, data, kernel, strides, padding,
-                                        dilation, out_dtype, num_tile=2)
-    elif layout == 'NHWC':
-        return conv2d_spatial_pack_nhwc(cfg, data, kernel, strides, padding,
-                                        dilation, out_dtype)
-    else:
-        raise ValueError("Unsupported layout {}".format(layout))
 
-
-@autotvm.register_topi_schedule(
-    schedule_conv2d_nchw, 'arm_cpu',
-    ['direct', 'winograd', 'winograd_nnpack_fp16', 'winograd_nnpack_fp32'])
-def schedule_conv2d_nchw_arm_cpu(cfg, outs):
-    """TOPI schedule callback for conv2d
-
-    Parameters
-    ----------
-    cfg: ConfigEntity
-        The config for this template
-
-    outs: Array of Tensor
-        The computation graph description of conv2d
-        in the format of an array of tensors.
-
-    Returns
-    -------
-    s: Schedule
-        The computation schedule for conv2d.
-    """
+@autotvm.register_topi_schedule("conv2d_nchw_spatial_pack.arm_cpu")
+def schedule_conv2d_nchw_spatial_pack(cfg, outs):
+    """Create schedule for conv2d_nchw"""
     s = tvm.create_schedule([x.op for x in outs])
 
     def _callback(op):
@@ -131,35 +65,20 @@ def schedule_conv2d_nchw_arm_cpu(cfg, outs):
             schedule_conv2d_spatial_pack_nchw(cfg, s, data_vec, kernel_vec,
                                               conv, output, outs[0])
 
-        if 'winograd_conv2d_output' in op.tag:
-            output = op.output(0)
-            _schedule_winograd(cfg, s, output, outs[0])
-
-        if 'winograd_nnpack_conv2d_output' in op.tag:
-            output = op.output(0)
-            _schedule_winograd_nnpack(cfg, s, output, outs[0])
-
     traverse_inline(s, outs[0].op, _callback)
     return s
 
-@autotvm.register_topi_schedule(schedule_conv2d_nhwc, 'arm_cpu', ['direct'])
-def schedule_conv2d_nhwc_arm_cpu(cfg, outs):
-    """TOPI schedule callback for conv2d
 
-    Parameters
-    ----------
-    cfg: ConfigEntity
-        The config for this template
+@autotvm.register_topi_compute("conv2d_nhwc_spatial_pack.arm_cpu")
+def conv2d_nhwc_spatial_pack(cfg, data, kernel, strides, padding, dilation, out_dtype):
+    """Compute conv2d with NHWC layout"""
+    return conv2d_spatial_pack_nhwc(cfg, data, kernel, strides, padding,
+                                    dilation, out_dtype)
 
-    outs: Array of Tensor
-        The computation graph description of conv2d
-        in the format of an array of tensors.
 
-    Returns
-    -------
-    s: Schedule
-        The computation schedule for conv2d.
-    """
+@autotvm.register_topi_schedule("conv2d_nhwc_spatial_pack.arm_cpu")
+def schedule_conv2d_nhwc_spatial_pack(cfg, outs):
+    """Create schedule for conv2d_nhwc"""
     s = tvm.create_schedule([x.op for x in outs])
 
     def _callback(op):
@@ -170,14 +89,29 @@ def schedule_conv2d_nhwc_arm_cpu(cfg, outs):
     return s
 
 
-@autotvm.register_topi_compute(conv2d, 'arm_cpu', ['winograd'])
-def conv2d_arm_cpu_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dtype):
-    """ TOPI compute callback. Use winograd template """
+@autotvm.register_topi_compute("conv2d_nchw_winograd.arm_cpu")
+def conv2d_nchw_winograd(cfg, data, kernel, strides, padding, dilation, out_dtype):
+    """Compute conv2d_nchw layout using Winograd with weight transform"""
     tile_size = 4
-    return _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout,
+    return _decl_winograd(cfg, data, kernel, strides, padding, dilation,
                           out_dtype, tile_size)
 
-def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dtype, tile_size):
+
+@autotvm.register_topi_schedule("conv2d_nchw_winograd.arm_cpu")
+def schedule_conv2d_nchw_winograd(cfg, outs):
+    """Create schedule for conv2d_nchw_winograd"""
+    s = tvm.create_schedule([x.op for x in outs])
+
+    def _callback(op):
+        if 'winograd_conv2d_output' in op.tag:
+            output = op.output(0)
+            _schedule_winograd(cfg, s, output, outs[0])
+
+    traverse_inline(s, outs[0].op, _callback)
+    return s
+
+
+def _decl_winograd(cfg, data, kernel, strides, padding, dilation, out_dtype, tile_size):
     N, CI, IH, IW = get_const_tuple(data.shape)
 
     if isinstance(dilation, int):
@@ -187,7 +121,7 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt
 
     if len(kernel.shape) == 4:
         if dilation_h != 1 or dilation_w != 1:
-            kernel = dilate(kernel, (1, 1, dilation_h, dilation_w))
+            kernel = nn.dilate(kernel, (1, 1, dilation_h, dilation_w))
         pre_computed = False
         CO, _, KH, KW = get_const_tuple(kernel.shape)
     else:
@@ -199,9 +133,8 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt
     HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
     pt, pl, pb, pr = get_pad_tuple(padding, (KH, KW))
 
-    assert layout == 'NCHW'
     assert KH == 3 and KW == 3 and HSTR == 1 and WSTR == 1
-    data_pad = pad(data, (0, 0, pt, pl), (0, 0, pb, pr), name="data_pad")
+    data_pad = nn.pad(data, (0, 0, pt, pl), (0, 0, pb, pr), name="data_pad")
 
     idxd = tvm.indexdiv
     idxm = tvm.indexmod
@@ -272,6 +205,7 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt
     cfg.add_flop(2 * N * K * H * W * KH * KW * C)
     return output
 
+
 def _schedule_winograd(cfg, s, output, last):
     Y = output.op.input_tensors[0]
     M, A = Y.op.input_tensors
@@ -356,26 +290,39 @@ def _schedule_winograd(cfg, s, output, last):
         s[output].compute_inline()
 
 
-@autotvm.register_topi_compute(conv2d, 'arm_cpu', ['winograd_nnpack_fp16'])
-def conv2d_arm_cpu_winograd_nnpack_fp16(
-        cfg, data, kernel, strides, padding, dilation, layout, out_dtype):
-    """ TOPI compute callback. Use winograd_nnpack_fp16 template """
-    return conv2d_arm_cpu_winograd_nnpack(
-        cfg, data, kernel, strides, padding, dilation, layout, out_dtype,
-        tvm.contrib.nnpack.ConvolutionAlgorithm.WT_8x8_FP16)
+@autotvm.register_topi_compute("conv2d_nchw_winograd_nnpack.arm_cpu")
+def conv2d_nchw_winograd_nnpack(cfg, data, kernel, strides, padding, dilation, out_dtype):
+    """Compute conv2d_nchw using nnpack Winograd implementation"""
+    dtype = data.dtype
+    if dtype == "float32":
+        return _conv2d_arm_cpu_winograd_nnpack(
+            cfg, data, kernel, strides, padding, dilation, out_dtype,
+            tvm.contrib.nnpack.ConvolutionAlgorithm.WT_8x8)
+    elif dtype == "float16":
+        return _conv2d_arm_cpu_winograd_nnpack(
+            cfg, data, kernel, strides, padding, dilation, out_dtype,
+            tvm.contrib.nnpack.ConvolutionAlgorithm.WT_8x8_FP16)
+    else:
+        raise ValueError("Unsupported data type {} for conv2d winograd nnpack".
+                         format(dtype))
+
+
+@autotvm.register_topi_schedule("conv2d_nchw_winograd_nnpack.arm_cpu")
+def schedule_conv2d_nchw_winograd_nnpack(cfg, outs):
+    """Create schedule for conv2d_nchw_winograd_nnpack"""
+    s = tvm.create_schedule([x.op for x in outs])
 
+    def _callback(op):
+        if 'winograd_nnpack_conv2d_output' in op.tag:
+            output = op.output(0)
+            _schedule_winograd_nnpack(cfg, s, output, outs[0])
 
-@autotvm.register_topi_compute(conv2d, 'arm_cpu', ['winograd_nnpack_fp32'])
-def conv2d_arm_cpu_winograd_nnpack_fp32(
-        cfg, data, kernel, strides, padding, dilation, layout, out_dtype):
-    """ TOPI compute callback. Use winograd_nnpack_fp32 template """
-    return conv2d_arm_cpu_winograd_nnpack(
-        cfg, data, kernel, strides, padding, dilation, layout, out_dtype,
-        tvm.contrib.nnpack.ConvolutionAlgorithm.WT_8x8)
+    traverse_inline(s, outs[0].op, _callback)
+    return s
 
 
-def conv2d_arm_cpu_winograd_nnpack(
-        cfg, data, kernel, strides, padding, dilation, layout, out_dtype, convolution_algorithm):
+def _conv2d_arm_cpu_winograd_nnpack(
+        cfg, data, kernel, strides, padding, dilation, out_dtype, convolution_algorithm):
     """ TOPI compute callback. Use winograd NNPACK template """
     N, CI, IH, IW = get_const_tuple(data.shape)
 
@@ -389,7 +336,6 @@ def conv2d_arm_cpu_winograd_nnpack(
     HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
     pt, pl, pb, pr = get_pad_tuple(padding, (KH, KW))
 
-    assert layout == 'NCHW'
     assert KH == 3 and KW == 3 and pt == 1 and pb == 1 and pl == 1 and pr == 1 and HSTR == 1\
         and WSTR == 1
     H = (IH + pt + pb - 3) // HSTR + 1
@@ -416,6 +362,7 @@ def conv2d_arm_cpu_winograd_nnpack(
     cfg.add_flop(2 * N * CI * H * W * KH * KW * CO)
     return output
 
+
 def _schedule_winograd_nnpack(cfg, s, output, last):
     # Could have bias.
 
@@ -429,36 +376,10 @@ def _schedule_winograd_nnpack(cfg, s, output, last):
         s[TK].pragma(s[TK].op.axis[0], 'debug_skip_region')
 
 
-##### REGISTER TOPI COMPUTE / SCHEDULE FOR WINOGRAD WITH WEIGHT TRANSFORM #####
-@autotvm.register_topi_compute(conv2d_winograd_without_weight_transform, 'arm_cpu', ['winograd'])
-def conv2d_winograd_ww(cfg, data, kernel, strides, padding, dilation, layout, out_dtype, tile_size):
-    """TOPI compute callback"""
-    return _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dtype,\
-                          tile_size)
-
-
-@autotvm.register_topi_schedule(schedule_conv2d_winograd_without_weight_transform,
-                                'arm_cpu', ['winograd'])
-def schedule_conv2d_winograd_without_weight_transform_(cfg, outs):
-    """TOPI schedule callback"""
-    s = tvm.create_schedule([x.op for x in outs])
-
-    def _callback(op):
-        if 'winograd_conv2d_output' in op.tag:
-            output = op.output(0)
-            _schedule_winograd(cfg, s, output, outs[0])
-
-    traverse_inline(s, outs[0].op, _callback)
-    return s
-
-
-##### REGISTER TOPI COMPUTE / SCHEDULE FOR WINOGRAD NNPACK WITHOUT WEIGHT TRANSFORM #####
-@autotvm.register_topi_compute(conv2d_winograd_nnpack_without_weight_transform,
-                               'arm_cpu',
-                               ['winograd_nnpack_fp16', 'winograd_nnpack_fp32'])
-def conv2d_winograd_nnpack_ww(cfg, data, transformed_kernel, bias, strides,
-                              padding, dilation, layout, out_dtype):
-    """ TOPI compute callback. Use winograd NNPACK template """
+@autotvm.register_topi_compute("conv2d_nchw_winograd_nnpack_without_weight_transform.arm_cpu")
+def conv2d_nchw_winograd_nnpack_without_weight_transform(
+        cfg, data, transformed_kernel, bias, strides, padding, dilation, out_dtype):
+    """Compute conv2d_nchw using NNPack winograd without weight transform"""
     N, CI, IH, IW = get_const_tuple(data.shape)
     if isinstance(dilation, int):
         dilation_h = dilation_w = dilation
@@ -471,7 +392,6 @@ def conv2d_winograd_nnpack_ww(cfg, data, transformed_kernel, bias, strides,
     KH, KW = 3, 3
     pt, pl, pb, pr = get_pad_tuple(padding, (KH, KW))
 
-    assert layout == 'NCHW'
     assert KH == 3 and KW == 3 and pt == 1 and pb == 1 and pl == 1 and pr == 1 and HSTR == 1\
         and WSTR == 1
     H = (IH + pt + pb - 3) // HSTR + 1
@@ -492,9 +412,8 @@ def conv2d_winograd_nnpack_ww(cfg, data, transformed_kernel, bias, strides,
     return output
 
 
-@autotvm.register_topi_schedule(schedule_conv2d_winograd_nnpack_without_weight_transform,
-                                'arm_cpu', ['winograd_nnpack_fp16', 'winograd_nnpack_fp32'])
-def schedule_conv2d_winograd_nnpack_without_weight_transform_(cfg, outs):
+@autotvm.register_topi_schedule("conv2d_nchw_winograd_nnpack_without_weight_transform.arm_cpu")
+def schedule_conv2d_nchw_winograd_nnpack_without_weight_transform(cfg, outs):
     """TOPI schedule callback"""
     s = tvm.create_schedule([x.op for x in outs])
 
@@ -505,226 +424,3 @@ def schedule_conv2d_winograd_nnpack_without_weight_transform_(cfg, outs):
 
     traverse_inline(s, outs[0].op, _callback)
     return s
-
-
-##### REGISTER ALTER OP LAYOUT #####
-@conv2d_alter_layout.register(["arm_cpu"])
-def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
-    """Alter op layout for pre-computing kernel transformation
-
-    Parameters
-    ----------
-    attrs : tvm.ir.Attrs
-        Attributes of current convolution
-    inputs : tvm.relay.Expr
-        Grouped input symbols
-    tinfos : list
-        Input shape and dtype
-    F: symbol
-        The context, can be either relay.op
-
-    Note
-    ----
-    Unlike other TOPI functions, this function operates on both graph level and operator level,
-    so we have to pass 'F' to make it support our two versions of graph IR,  Relay.
-    """
-    copy_inputs = list(inputs)
-    new_attrs = {k: attrs[k] for k in attrs.keys()}
-
-    if F.__name__ == 'tvm.relay.op':
-        # Derive channels for frontends (e.g ONNX) that miss "channel" field.
-        new_attrs["channels"] = inputs[1].checked_type.shape[attrs['kernel_layout'].index('O')]
-
-    dilation = attrs.get_int_tuple("dilation")
-    strides = attrs.get_int_tuple("strides")
-    padding = attrs.get_int_tuple("padding")
-    groups = attrs.get_int('groups')
-    data_layout_key = "data_layout" if "data_layout" in new_attrs else "layout"
-    layout = attrs[data_layout_key]
-    kernel_layout = attrs['kernel_layout']
-    out_dtype = attrs["out_dtype"]
-    if out_dtype in ("same", ""):
-        out_dtype = tinfos[0].dtype
-
-    if dilation != (1, 1):
-        logger.warning("Does not support weight pre-transform for dilated convolution.")
-        return None
-
-    # query config of this workload
-    data, kernel = tinfos[0:2]
-    if groups == 1:
-        workload = autotvm.task.args_to_workload(
-            [data, kernel, strides, padding, dilation, layout, out_dtype], conv2d)
-    else:
-        workload = autotvm.task.args_to_workload(
-            [data, kernel, strides, padding, dilation, out_dtype], depthwise_conv2d_nchw)
-
-    if layout == 'NCHW' and kernel_layout == 'OIHW':
-        N, CI, H, W = get_const_tuple(data.shape)
-        CO, _, KH, KW = get_const_tuple(kernel.shape)
-    elif layout == 'NHWC' and kernel_layout == 'HWIO':
-        N, H, W, CI = get_const_tuple(data.shape)
-        KH, KW, _, CO = get_const_tuple(kernel.shape)
-        # Also modify the workload to pick up because later we convert to NCHW
-        # layout.
-        new_data = tvm.placeholder((N, CI, H, W), dtype=data.dtype)
-        new_kernel = tvm.placeholder((CO, CI, KH, KW), dtype=kernel.dtype)
-        new_layout = 'NCHW'
-        workload = autotvm.task.args_to_workload(
-            [new_data, new_kernel, strides, padding, dilation, new_layout, out_dtype], conv2d)
-    elif layout == 'NHWC' and kernel_layout == 'HWOI':
-        # This is the case for depthwise convolution.
-        N, H, W, CI = get_const_tuple(data.shape)
-        KH, KW, CO, M = get_const_tuple(kernel.shape)
-        # Also modify the workload to pick up because later we convert to NCHW
-        # layout.
-        new_data = tvm.placeholder((N, CI, H, W), dtype=data.dtype)
-        new_kernel = tvm.placeholder((CO, M, KH, KW), dtype=kernel.dtype)
-        workload = autotvm.task.args_to_workload(
-            [new_data, new_kernel, strides, padding, dilation, out_dtype], depthwise_conv2d_nchw)
-    else:
-        return None
-
-    idxd = tvm.indexdiv
-
-    if groups == 1:
-        target = tvm.target.Target.current()
-        dispatch_ctx = autotvm.DispatchContext.current
-        cfg = dispatch_ctx.query(target, workload)
-
-        if cfg.is_fallback:  # if is fallback, clear query cache and return None
-            autotvm.task.clear_fallback_cache(target, workload)
-            if layout == 'NHWC' and kernel_layout == 'HWIO':
-                new_attrs['data_layout'] = 'NCHW'
-                new_attrs['kernel_layout'] = 'OIHW'
-                return F.nn.conv2d(*copy_inputs, **new_attrs)
-            return None
-
-        if cfg.template_key == 'direct':  # pack weight tensor
-            VC = cfg['tile_co'].size[-1]
-            new_attrs['kernel_layout'] = 'OIHW%do' % VC
-
-            # Store the same config for the altered operator (workload)
-            new_data = tvm.placeholder((N, CI, H, W), dtype=data.dtype)
-            new_attrs[data_layout_key] = 'NCHW'
-            new_kernel = tvm.placeholder((idxd(CO, VC), CI, KH, KW, VC), dtype=kernel.dtype)
-            new_workload = autotvm.task.args_to_workload(
-                [new_data, new_kernel, strides, padding, dilation, 'NCHW', out_dtype], conv2d)
-            dispatch_ctx.update(target, new_workload, cfg)
-
-            return F.nn.conv2d(*copy_inputs, **new_attrs)
-        elif cfg.template_key == "winograd":  # pre-compute weight transformation in winograd
-            if "-device=arm_cpu" in target.options:
-                tile_size = 4
-                VC = cfg['tile_k'].size[-1]
-            elif "-device=bifrost" in target.options:
-                tile_size = 2
-                VC = 0
-            else:
-                from ..mali.conv2d import _pick_tile_size
-                tile_size = _pick_tile_size(tinfos[0], tinfos[1])
-                VC = cfg['tile_bna'].val
-
-            weight = copy_inputs[1]
-            if kernel_layout != 'OIHW':
-                weight = F.transpose(weight, axes=(2, 3, 0, 1))
-            weight = F.nn.contrib_conv2d_winograd_weight_transform(weight,
-                                                                   tile_size=tile_size)
-            if VC > 0:
-                weight = F.reshape(weight,
-                                   newshape=(KH + tile_size - 1,
-                                             KW + tile_size - 1,
-                                             idxd(CO, VC), VC, CI))
-                weight = F.transpose(weight, axes=[0, 1, 2, 4, 3])
-                new_weight = tvm.placeholder((KH + tile_size - 1,
-                                              KW + tile_size -1,
-                                              idxd(CO, VC), CI, VC),
-                                             kernel.dtype)
-            else:
-                weight = F.reshape(weight,
-                                   newshape=(KH + tile_size - 1, KW + tile_size - 1, CO, CI))
-                new_weight = tvm.placeholder(
-                    (KH + tile_size - 1, KW + tile_size -1, CO, CI), kernel.dtype
-                )
-
-            copy_inputs[1] = weight
-            new_attrs['tile_size'] = tile_size
-            new_attrs[data_layout_key] = 'NCHW'
-
-            # Store the same config for the altered operator (workload)
-            new_data = tvm.placeholder((N, CI, H, W), dtype=data.dtype)
-            new_workload = autotvm.task.args_to_workload(
-                [new_data, new_weight, strides, padding, dilation,
-                 new_attrs[data_layout_key], out_dtype, tile_size],
-                conv2d_winograd_without_weight_transform)
-            dispatch_ctx.update(target, new_workload, cfg)
-
-            return F.nn.contrib_conv2d_winograd_without_weight_transform(*copy_inputs, **new_attrs)
-        elif cfg.template_key in ["winograd_nnpack_fp16", "winograd_nnpack_fp32"]:
-            # pre-compute winograd_nnpack transform
-            # for winograd_nnpack_fp16, the the precomputeprune pass must run on device,
-            # where float16 is supported
-            weight_dtype = 'float32'
-            weight = copy_inputs[1]
-            if kernel_layout != 'OIHW':
-                weight = F.transpose(weight, axes=(2, 3, 0, 1))
-            weight = F.nn.contrib_conv2d_winograd_weight_transform(weight,
-                                                                   tile_size=tile_size)
-            transformed_kernel = F.nn.contrib_conv2d_winograd_nnpack_weight_transform(
-                weight,
-                convolution_algorithm=cfg['winograd_nnpack_algorithm'].val,
-                out_dtype=weight_dtype)
-            copy_inputs[1] = transformed_kernel
-
-            new_data = tvm.placeholder((N, CI, H, W), dtype=data.dtype)
-            new_kernel = tvm.placeholder((CO, CI, 8, 8), "float32")
-            bias = tvm.placeholder((CO, ), "float32")
-            new_attrs[data_layout_key] = 'NCHW'
-            new_workload = autotvm.task.args_to_workload(
-                [new_data, new_kernel, bias, strides,
-                 padding, dilation, new_attrs[data_layout_key], out_dtype]
-                if len(copy_inputs) == 3 else
-                [new_data, new_kernel, strides,
-                 padding, dilation, new_attrs[data_layout_key], out_dtype],
-                conv2d_winograd_nnpack_without_weight_transform)
-            dispatch_ctx.update(target, new_workload, cfg)
-            return F.nn.contrib_conv2d_winograd_nnpack_without_weight_transform(
-                *copy_inputs, **new_attrs)
-        else:
-            raise RuntimeError("Unsupported template_key '%s'" % cfg.template_key)
-    else:
-        target = tvm.target.Target.current()
-        dispatch_ctx = autotvm.DispatchContext.current
-        cfg = dispatch_ctx.query(target, workload)
-
-        if cfg.is_fallback:  # if is fallback, clear query cache and return None
-            autotvm.task.clear_fallback_cache(tvm.target.Target.current(), workload)
-            if layout == 'NHWC' and kernel_layout == 'HWOI':
-                new_attrs['data_layout'] = 'NCHW'
-                new_attrs['kernel_layout'] = 'OIHW'
-                return F.nn.conv2d(*copy_inputs, **new_attrs)
-            return None
-        if cfg.template_key == 'contrib_spatial_pack':
-            VC = cfg['tile_co'].size[-1]
-            new_attrs['kernel_layout'] = 'OIHW%do' % (cfg['tile_co'].size[-1])
-
-            # Store the same config for the altered operator (workload)
-            new_data = tvm.placeholder((N, CI, H, W), dtype=data.dtype)
-            new_attrs[data_layout_key] = 'NCHW'
-            if attrs['kernel_layout'] == 'OIHW':
-                CO, M, KH, KW = get_const_tuple(kernel.shape)
-            elif attrs['kernel_layout'] == 'HWOI':
-                KH, KW, CO, M = get_const_tuple(kernel.shape)
-            else:
-                raise RuntimeError("Depthwise conv should either have OIHW/HWIO kernel layout")
-            new_kernel = tvm.placeholder((idxd(CO, VC), M, KH, KW, VC), dtype=kernel.dtype)
-            new_workload = autotvm.task.args_to_workload(
-                [new_data, new_kernel, strides, padding, dilation, out_dtype],
-                depthwise_conv2d_nchw)
-            dispatch_ctx.update(target, new_workload, cfg)
-
-            return F.nn.conv2d(*copy_inputs, **new_attrs)
-        else:
-            # currently we only have contrib_spatial_pack and direct template
-            # add more schedule templates.
-            return None
diff --git a/topi/python/topi/arm_cpu/conv2d_alter_op.py b/topi/python/topi/arm_cpu/conv2d_alter_op.py
new file mode 100644 (file)
index 0000000..bfbf5d6
--- /dev/null
@@ -0,0 +1,171 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,unused-variable,unused-argument,no-member
+"""Conv2D alter op and legalize functions for arm cpu"""
+
+import logging
+
+import tvm
+from tvm import relay
+from tvm import autotvm
+
+from ..nn import conv2d_alter_layout
+from ..util import get_const_tuple
+
+
+logger = logging.getLogger('topi')
+
+
+@conv2d_alter_layout.register(["arm_cpu"])
+def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
+    target = tvm.target.Target.current(allow_none=False)
+    dispatch_ctx = autotvm.task.DispatchContext.current
+
+    _, outs = relay.backend.compile_engine.select_implementation(
+        relay.op.get("nn.conv2d"), attrs, tinfos, out_type, target)
+    workload = autotvm.task.get_workload(outs)
+    if workload is None:
+        # The best implementation is not an AutoTVM template,
+        # we then assume it's not necessary to alter this op.
+        return None
+    cfg = dispatch_ctx.query(target, workload)
+    if cfg.is_fallback:  # if is fallback, clear query cache and return None
+        autotvm.task.clear_fallback_cache(target, workload)
+        return None
+
+    topi_tmpl = workload[0]
+    new_attrs = {k: attrs[k] for k in attrs.keys()}
+
+    strides = attrs.get_int_tuple("strides")
+    padding = attrs.get_int_tuple("padding")
+    dilation = attrs.get_int_tuple("dilation")
+    data_layout = attrs["data_layout"]
+    kernel_layout = attrs["kernel_layout"]
+    data, kernel = tinfos
+    out_dtype = out_type.dtype
+
+    idxd = tvm.indexdiv
+
+    if topi_tmpl == "conv2d_nchw_spatial_pack.arm_cpu":
+        assert data_layout == "NCHW" and kernel_layout == "OIHW"
+        N, CI, H, W = get_const_tuple(data.shape)
+        CO, _, KH, KW = get_const_tuple(kernel.shape)
+        VC = cfg['tile_co'].size[-1]
+
+        new_attrs['kernel_layout'] = 'OIHW%do' % VC
+
+        new_data = data
+        new_kernel = tvm.placeholder((idxd(CO, VC), CI, KH, KW, VC), dtype=kernel.dtype)
+        new_workload = autotvm.task.args_to_workload(
+            [new_data, new_kernel, strides, padding, dilation, out_dtype],
+            "conv2d_nchw_spatial_pack.arm_cpu")
+        dispatch_ctx.update(target, new_workload, cfg)
+
+        return relay.nn.conv2d(*inputs, **new_attrs)
+
+    if topi_tmpl == "conv2d_nhwc_spatial_pack.arm_cpu":
+        assert data_layout == "NHWC" and kernel_layout == "HWIO"
+        N, H, W, CI = get_const_tuple(data.shape)
+        KH, KW, _, CO = get_const_tuple(kernel.shape)
+        VC = cfg['tile_co'].size[-1]
+
+        new_attrs['kernel_layout'] = 'OHWI%do' % VC
+
+        new_data = data
+        new_kernel = tvm.placeholder((idxd(CO, VC), KH, KW, CI, VC), dtype=kernel.dtype)
+        new_workload = autotvm.task.args_to_workload(
+            [new_data, new_kernel, strides, padding, dilation, out_dtype],
+            "conv2d_nhwc_spatial_pack.arm_cpu")
+        dispatch_ctx.update(target, new_workload, cfg)
+
+        return relay.nn.conv2d(*inputs, **new_attrs)
+
+    if topi_tmpl == "conv2d_nchw_winograd.arm_cpu":
+        assert data_layout == "NCHW" and kernel_layout == "OIHW"
+        N, CI, H, W = get_const_tuple(data.shape)
+        CO, _, KH, KW = get_const_tuple(kernel.shape)
+        VC = cfg['tile_k'].size[-1]
+        tile_size = 4
+
+        weight_expr = inputs[1]
+        weight_expr = relay.nn.contrib_conv2d_winograd_weight_transform(
+            weight_expr, tile_size=tile_size)
+        weight_expr = relay.reshape(weight_expr,
+                                    newshape=(KH + tile_size - 1,
+                                              KW + tile_size - 1,
+                                              idxd(CO, VC), VC, CI))
+        weight_expr = relay.transpose(weight_expr, axes=[0, 1, 2, 4, 3])
+
+        new_attrs['tile_size'] = tile_size
+
+        new_data = data
+        new_kernel = tvm.placeholder((KH + tile_size - 1,
+                                      KW + tile_size -1,
+                                      idxd(CO, VC), CI, VC),
+                                     kernel.dtype)
+        new_workload = autotvm.task.args_to_workload(
+            [new_data, new_kernel, strides, padding, dilation, out_dtype],
+            'conv2d_nchw_winograd.arm_cpu')
+        dispatch_ctx.update(target, new_workload, cfg)
+
+        return relay.nn.contrib_conv2d_winograd_without_weight_transform(
+            inputs[0], weight_expr, **new_attrs)
+
+    if topi_tmpl == "conv2d_nchw_winograd_nnpack.arm_cpu":
+        assert data_layout == "NCHW" and kernel_layout == "OIHW"
+        N, CI, H, W = get_const_tuple(data.shape)
+        CO, _, KH, KW = get_const_tuple(kernel.shape)
+
+        # pre-compute winograd_nnpack transform
+        # for winograd_nnpack_fp16, the the precompute prune pass must run on device,
+        # where float16 is supported
+        weight_dtype = 'float32'
+        weight_expr = inputs[1]
+        transformed_weight = relay.nn.contrib_conv2d_winograd_nnpack_weight_transform(
+            weight_expr,
+            convolution_algorithm=cfg['winograd_nnpack_algorithm'].val,
+            out_dtype=weight_dtype)
+
+        new_data = data
+        new_kernel = tvm.placeholder((CO, CI, 8, 8), "float32")
+
+        new_workload = autotvm.task.args_to_workload(
+            [new_data, new_kernel, None, strides, padding, dilation, out_dtype],
+            "conv2d_nchw_winograd_nnpack_without_weight_transform.arm_cpu")
+        dispatch_ctx.update(target, new_workload, cfg)
+        return relay.nn.contrib_conv2d_winograd_without_weight_transform(
+            inputs[0], transformed_weight, **new_attrs)
+
+    if topi_tmpl == "depthwise_conv2d_nchw_spatial_pack.arm_cpu":
+        assert data_layout == "NCHW" and kernel_layout == "OIHW"
+        N, CI, H, W = get_const_tuple(data.shape)
+        CO, _, KH, KW = get_const_tuple(kernel.shape)
+        VC = cfg['tile_co'].size[-1]
+
+        new_attrs['kernel_layout'] = 'OIHW%do' % (cfg['tile_co'].size[-1])
+
+        # Store the same config for the altered operator (workload)
+        new_data = data
+        new_kernel = tvm.placeholder((idxd(CO, VC), CI, KH, KW, VC), dtype=kernel.dtype)
+        new_workload = autotvm.task.args_to_workload(
+            [new_data, new_kernel, strides, padding, dilation, out_dtype],
+            "depthwise_conv2d_nchw_spatial_pack.arm_cpu")
+        dispatch_ctx.update(target, new_workload, cfg)
+
+        return relay.nn.conv2d(*inputs, **new_attrs)
+
+    return None
index 8f43f5c..5d177fe 100644 (file)
@@ -19,9 +19,8 @@
 
 import tvm
 from tvm import autotvm
-from .. import generic, tag
+from .. import tag
 from ..util import get_const_tuple
-from ..nn.conv2d import conv2d_NCHWc_int8
 from ..generic import conv2d as conv2d_generic
 from .. import nn
 from ..nn.conv2d import _get_workload as _get_conv2d_workload
@@ -42,9 +41,10 @@ def _get_default_config(cfg, data, kernel, strides, padding, out_dtype):
             cfg, wkl, int32_lanes=2, num_int8_elements=4)
 
 
-@autotvm.register_topi_compute(conv2d_NCHWc_int8, ['arm_cpu'], 'direct')
-def _declaration_conv_NCHWc_int8(cfg, data, kernel, strides,
-                                 padding, dilation, layout, out_layout, out_dtype):
+@autotvm.register_topi_compute("conv2d_NCHWc_int8.arm_cpu")
+def conv2d_NCHWc_int8(cfg, data, kernel, strides,
+                      padding, dilation, layout, out_layout, out_dtype):
+    """Compute conv2d int8 with NCHWc layout"""
     # layout and out_layout are not used here,
     # we keep them for debug convenience when dumping autotvm workload
     n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape)
@@ -68,8 +68,8 @@ def _declaration_conv_NCHWc_int8(cfg, data, kernel, strides,
                                         out_dtype)
 
 
-@autotvm.register_topi_schedule(generic.schedule_conv2d_NCHWc_int8, ['arm_cpu'], ['direct'])
-def _schedule_conv2d_NCHWc_int8(cfg, outs):
+@autotvm.register_topi_schedule("conv2d_NCHWc_int8.arm_cpu")
+def schedule_conv2d_NCHWc_int8(cfg, outs):
     """Create schedule for tensors"""
     s = tvm.create_schedule([x.op for x in outs])
     scheduled_ops = []
@@ -86,7 +86,7 @@ def _schedule_conv2d_NCHWc_int8(cfg, outs):
 
         if 'conv2d_NCHWc_int8' in op.tag:
             conv_out = op.output(0)
-            kernel = conv_out.op.input_tensors[1]
+            kernel_vec = conv_out.op.input_tensors[1]
             data_vec = conv_out.op.input_tensors[0]
             data = data_vec.op.input_tensors[0] \
                 if isinstance(data_vec.op, tvm.tensor.ComputeOp) and "pad" not in data_vec.op.tag \
@@ -95,9 +95,9 @@ def _schedule_conv2d_NCHWc_int8(cfg, outs):
                 data_pad = data
                 data = data_pad.op.input_tensors[0]
 
-            args = [s, cfg, data_vec, conv_out, outs[0]]
+            args = [s, cfg, data_vec, kernel_vec, conv_out, outs[0]]
             # int8 conv kernel is 7-dim
-            _, _, kh, kw, _, _, _ = get_const_tuple(kernel.shape)
+            _, _, kh, kw, _, _, _ = get_const_tuple(kernel_vec.shape)
             dtype = "uint" if data.dtype == "uint8" else "int"
             if kh == 1 and kw == 1:
                 conv2d_generic.schedule_conv_NCHWc_cpu_1x1_int8(
index 350a022..032ac76 100644 (file)
@@ -78,10 +78,12 @@ def conv2d_spatial_pack_nchw(cfg, data, kernel, strides, padding, dilation,
     # fallback support
     if cfg.is_fallback:
         if num_tile == 2:     # arm cpu
-            ref_log = autotvm.tophub.load_reference_log('arm_cpu', 'rk3399', 'conv2d', 'direct')
+            ref_log = autotvm.tophub.load_reference_log(
+                'arm_cpu', 'rk3399', 'conv2d_nchw_spatial_pack.arm_cpu')
             cfg.fallback_with_reference_log(ref_log)
         elif num_tile == 3:  # mali gpu
-            ref_log = autotvm.tophub.load_reference_log('mali', 'rk3399', 'conv2d', 'direct')
+            ref_log = autotvm.tophub.load_reference_log(
+                'mali', 'rk3399', 'conv2d_nchw_spatial_pack.mali')
             cfg.fallback_with_reference_log(ref_log)
     # ====================================================================
 
index 65f1024..93ff029 100644 (file)
@@ -21,13 +21,12 @@ from __future__ import absolute_import as _abs
 import tvm
 from tvm import autotvm
 
-from ..generic import schedule_conv2d_transpose_nchw
-from ..nn import conv2d_transpose_nchw, dilate, pad, get_pad_tuple
+from ..nn import dilate, pad, get_pad_tuple
 from ..util import get_const_tuple, traverse_inline
 from .conv2d_spatial_pack import schedule_conv2d_spatial_pack_nchw
 
-@autotvm.task.register_topi_compute(conv2d_transpose_nchw, "arm_cpu", "direct")
-def conv2d_transpose_nchw_arm(cfg, Input, Filter, strides, padding, out_dtype):
+@autotvm.register_topi_compute("conv2d_transpose_nchw.arm_cpu")
+def conv2d_transpose_nchw(cfg, Input, Filter, strides, padding, out_dtype):
     """Transposed 2D convolution nchw forward operator.
 
     Parameters
@@ -135,8 +134,8 @@ def _decl_spatial_pack(cfg, data, kernel, strides, padding, layout, out_dtype, n
 
 
 # register customized schedule for arm cpu.
-@autotvm.task.register_topi_schedule(schedule_conv2d_transpose_nchw, "arm_cpu", "direct")
-def schedule_conv2d_transpose_arm(cfg, outs):
+@autotvm.register_topi_schedule("conv2d_transpose_nchw.arm_cpu")
+def schedule_conv2d_transpose_nchw(cfg, outs):
     """Schedule conv2d transpose for arm cpu"""
     s = tvm.create_schedule([x.op for x in outs])
 
index 207fc71..8d668f3 100644 (file)
 import tvm
 from tvm import autotvm
 
-from ..generic import schedule_depthwise_conv2d_nchw
-from ..nn import depthwise_conv2d_nchw, pad
+from .. import nn
 from ..util import traverse_inline, get_const_tuple, get_const_int
 from ..nn.util import get_pad_tuple
 
-# register original implementation of depthwise_conv2d_nchw since we don't need to change this part
-autotvm.register_topi_compute(depthwise_conv2d_nchw, 'arm_cpu', 'direct',
-                              depthwise_conv2d_nchw.fdefault)
 
-# register customized schedule for arm cpu.
-@autotvm.register_topi_schedule(schedule_depthwise_conv2d_nchw, 'arm_cpu',
-                                ['direct', 'contrib_spatial_pack'])
-def schedule_depthwise_conv2d_nchw_arm(cfg, outs):
+@autotvm.register_topi_compute("depthwise_conv2d_nchw.arm_cpu")
+def depthwise_conv2d_nchw(_, data, kernel, strides, padding, dilation, out_dtype):
+    """Compute depthwise_conv2d with NCHW layout"""
+    return nn.depthwise_conv2d_nchw(data, kernel, strides, padding, dilation, out_dtype)
+
+
+@autotvm.register_topi_schedule("depthwise_conv2d_nchw.arm_cpu")
+def schedule_depthwise_conv2d_nchw(cfg, outs):
     """Schedule depthwise conv2d
 
     Parameters
@@ -65,7 +65,7 @@ def schedule_depthwise_conv2d_nchw_arm(cfg, outs):
         # fallback support
         if cfg.is_fallback:
             ref_log = autotvm.tophub.load_reference_log(
-                'arm_cpu', 'rk3399', 'depthwise_conv2d_nchw', 'direct')
+                'arm_cpu', 'rk3399', 'depthwise_conv2d_nchw.arm_cpu')
             cfg.fallback_with_reference_log(ref_log)
         ##### space definition end #####
 
@@ -134,25 +134,12 @@ def schedule_depthwise_conv2d_nchw_arm(cfg, outs):
                 data = data_pad.op.input_tensors[0]
             _schedule(cfg, s, data, data_pad, kernel, output)
 
-        if op.tag == 'spatial_depthwise_conv2d_nchw_output':
-            output = op.output(0)
-            conv = op.input_tensors[0]
-            data_vec = conv.op.input_tensors[0]
-            kernel_vec = conv.op.input_tensors[1]
-            if kernel_vec.op.name == 'kernel_vec':
-                kernel = kernel_vec.op.input_tensors[0]
-            else:
-                kernel = kernel_vec
-            if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
-                s[kernel].compute_inline()
-
-            _schedule_spatial_pack(cfg, s, data_vec, kernel_vec, conv, output, outs[0])
-
     traverse_inline(s, outs[0].op, _callback)
     return s
 
-@autotvm.register_topi_compute(depthwise_conv2d_nchw, 'arm_cpu', ['contrib_spatial_pack'])
-def depthwise_conv2d_arm_cpu(cfg, data, kernel, strides, padding, dilation, out_dtype):
+
+@autotvm.register_topi_compute("depthwise_conv2d_nchw_spatial_pack.arm_cpu")
+def depthwise_conv2d_nchw_spatial_pack(cfg, data, kernel, strides, padding, dilation, out_dtype):
     """TOPI compute callback for depthwise_conv2d nchw
 
     Parameters
@@ -189,6 +176,30 @@ def depthwise_conv2d_arm_cpu(cfg, data, kernel, strides, padding, dilation, out_
     return _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, out_dtype, num_tile=2)
 
 
+@autotvm.register_topi_schedule("depthwise_conv2d_nchw_spatial_pack.arm_cpu")
+def schedule_depthwise_conv2d_nchw_spatial_pack(cfg, outs):
+    """Create the schedule for depthwise_conv2d_nchw_spatial_pack"""
+    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
+    s = tvm.create_schedule([x.op for x in outs])
+
+    def _callback(op):
+        if op.tag == 'spatial_depthwise_conv2d_nchw_output':
+            output = op.output(0)
+            conv = op.input_tensors[0]
+            data_vec = conv.op.input_tensors[0]
+            kernel_vec = conv.op.input_tensors[1]
+            if kernel_vec.op.name == 'kernel_vec':
+                kernel = kernel_vec.op.input_tensors[0]
+            else:
+                kernel = kernel_vec
+            if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
+                s[kernel].compute_inline()
+            _schedule_spatial_pack(cfg, s, data_vec, kernel_vec, conv, output, outs[0])
+
+    traverse_inline(s, outs[0].op, _callback)
+    return s
+
+
 def _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, out_dtype, num_tile):
     out_dtype = out_dtype or data.dtype
 
@@ -220,16 +231,16 @@ def _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, out_dtype,
     WPAD = pad_left + pad_right
     DOPAD = (HPAD != 0 or WPAD != 0)
     if DOPAD:
-        data_pad = pad(data, (0, 0, pad_top, pad_left), (0, 0, pad_down, pad_right),
-                       name="data_pad")
+        data_pad = nn.pad(data, (0, 0, pad_top, pad_left), (0, 0, pad_down, pad_right),
+                          name="data_pad")
     else:
         data_pad = data
 
     # fallback support
     # Currently, Mali schedule doesn't use it like conv2d.
     if cfg.is_fallback:
-        ref_log = autotvm.tophub.load_reference_log('arm_cpu', 'rk3399', 'depthwise_conv2d_nchw',
-                                                    'contrib_spatial_pack')
+        ref_log = autotvm.tophub.load_reference_log(
+            'arm_cpu', 'rk3399', 'depthwise_conv2d_nchw_spatial_pack.arm_cpu')
         cfg.fallback_with_reference_log(ref_log)
 
     # ==================== define configuration space ====================
index 0b6a16d..644a7e3 100644 (file)
 # pylint: disable=invalid-name, unused-variable
 """Schedule for pooling operators"""
 import tvm
-from .. import generic
 from ..util import is_empty_shape
 
-@generic.schedule_injective_from_existing.register(["arm_cpu"])
 def schedule_injective_from_existing(sch, out):
     """Schedule for injective op from existing schedule.
 
@@ -46,7 +44,6 @@ def schedule_injective_from_existing(sch, out):
         sch[out].parallel(sch[out].op.axis[0])
     return sch
 
-@generic.schedule_injective.register(["arm_cpu"])
 def schedule_injective(outs):
     """ARM CPU schedule for injective op.
 
@@ -74,7 +71,6 @@ def schedule_injective(outs):
         schedule_injective_from_existing(s, x)
     return s
 
-@generic.schedule_concatenate.register(["arm_cpu"])
 def schedule_concatenate(outs):
     """Schedule for concatenate op.
 
index 2ae6580..816024e 100644 (file)
 # specific language governing permissions and limitations
 # under the License.
 
-# pylint: disable=invalid-name,unused-variable,unused-argument
+# pylint: disable=invalid-name,unused-variable,unused-argument,no-else-return
 """conv2d schedule on ARM Mali (Bifrost) GPU"""
 
 import tvm
+from tvm import relay
 from tvm import autotvm
 
 from .gemm import decl_winograd_gemm, schedule_gemm
 from .transforms import tile_and_bind, tile_and_bind3d
-from ..generic import schedule_conv2d_nchw, schedule_conv2d_winograd_without_weight_transform
 from ..util import traverse_inline, get_const_int, get_const_tuple
-from ..nn import conv2d, conv2d_winograd_without_weight_transform, \
-    get_pad_tuple, pad, conv2d_alter_layout, dilate
+from .. import nn
 from ..nn.winograd_util import winograd_transform_matrices
 
 # reuse some compute declarations from ARM CPU
 from ..arm_cpu.conv2d_spatial_pack import conv2d_spatial_pack_nchw
-from ..arm_cpu.conv2d import _alter_conv2d_layout_arm
 
 
-@autotvm.register_topi_compute(conv2d, 'bifrost', ['direct'])
-def conv2d_bifrost(cfg, data, kernel, strides, padding, dilation, layout, out_dtype):
+@autotvm.register_topi_compute("conv2d_nchw_spatial_pack.bifrost")
+def conv2d_nchw_spatial_pack(cfg, data, kernel, strides, padding, dilation, out_dtype):
     """TOPI compute callback for conv2d
 
     Parameters
@@ -60,9 +58,6 @@ def conv2d_bifrost(cfg, data, kernel, strides, padding, dilation, layout, out_dt
     dilation : list of two ints
         [dilation_height, dilation_width]
 
-    layout : str
-        layout of data
-
     out_dtype: str
         The output type. This is used for mixed precision.
 
@@ -71,14 +66,12 @@ def conv2d_bifrost(cfg, data, kernel, strides, padding, dilation, layout, out_dt
     output : tvm.Tensor
         4-D with shape [batch, out_channel, out_height, out_width]
     """
-    if layout == 'NCHW':
-        return conv2d_spatial_pack_nchw(cfg, data, kernel, strides, padding,
-                                        dilation, out_dtype, num_tile=3)
-    raise ValueError("Unsupported layout {}".format(layout))
+    return conv2d_spatial_pack_nchw(cfg, data, kernel, strides, padding,
+                                    dilation, out_dtype, num_tile=3)
 
 
-@autotvm.register_topi_schedule(schedule_conv2d_nchw, 'bifrost', ['direct', 'winograd'])
-def schedule_conv2d_nchw_bifrost(cfg, outs):
+@autotvm.register_topi_schedule("conv2d_nchw_spatial_pack.bifrost")
+def schedule_conv2d_nchw_spatial_pack(cfg, outs):
     """TOPI schedule callback for conv2d
 
     Parameters
@@ -116,9 +109,6 @@ def schedule_conv2d_nchw_bifrost(cfg, outs):
 
             _schedule_spatial_pack(cfg, s, output, conv, data_vec, kernel_vec)
 
-        if 'winograd_conv2d_output' in op.tag:
-            _schedule_winograd(cfg, s, op)
-
     traverse_inline(s, outs[0].op, _callback)
     return s
 
@@ -195,10 +185,22 @@ def _schedule_spatial_pack(cfg, s, output, conv, data_vec, kernel_vec):
     return s
 
 
-@autotvm.register_topi_compute(conv2d, 'bifrost', ['winograd'])
-def conv2d_bifrost_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dtype):
+@autotvm.register_topi_compute("conv2d_nchw_winograd.bifrost")
+def conv2d_nchw_winograd(cfg, data, kernel, strides, padding, dilation, out_dtype):
     """Use Winograd as the convolution method"""
-    return _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dtype)
+    return _decl_winograd(cfg, data, kernel, strides, padding, dilation, out_dtype)
+
+
+@autotvm.register_topi_schedule("conv2d_nchw_winograd.bifrost")
+def schedule_conv2d_nchw_winograd(cfg, outs):
+    s = tvm.create_schedule([x.op for x in outs])
+
+    def _callback(op):
+        if 'winograd_conv2d_output' in op.tag:
+            _schedule_winograd(cfg, s, op)
+
+    traverse_inline(s, outs[0].op, _callback)
+    return s
 
 
 def _decl_winograd_kernel_transform(kernel, tile_size, G):
@@ -256,7 +258,7 @@ def _decl_winograd_kernel_transform(kernel, tile_size, G):
     return U
 
 
-def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dtype, tile_size=2):
+def _decl_winograd(cfg, data, kernel, strides, padding, dilation, out_dtype, tile_size=2):
     """Declare a winograd convolution - only tile_size=2 is currently supported"""
     N, CI, IH, IW = get_const_tuple(data.shape)
     if isinstance(dilation, int):
@@ -266,7 +268,7 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt
 
     if int(kernel.shape[2]) == 3:
         if dilation_h != 1 or dilation_w != 1:
-            kernel = dilate(kernel, (1, 1, dilation_h, dilation_w))
+            kernel = nn.dilate(kernel, (1, 1, dilation_h, dilation_w))
         pre_computed = False
         CO, _, KH, KW = get_const_tuple(kernel.shape)
     else:
@@ -275,11 +277,10 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt
         H_CAT, W_CAT, CO, CI = get_const_tuple(kernel.shape)
         KH, KW = H_CAT - tile_size + 1, W_CAT - tile_size + 1
     HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
-    pt, pl, pb, pr = get_pad_tuple(padding, (KH, KW))
+    pt, pl, pb, pr = nn.get_pad_tuple(padding, (KH, KW))
 
-    assert layout == 'NCHW'
     assert KH == 3 and KW == 3 and HSTR == 1 and WSTR == 1
-    data_pad = pad(data, (0, 0, pt, pl), (0, 0, pb, pr), name="data_pad")
+    data_pad = nn.pad(data, (0, 0, pt, pl), (0, 0, pb, pr), name="data_pad")
 
     r = KW
     m = tile_size
@@ -454,31 +455,78 @@ def _schedule_winograd(cfg, s, op):
     tile_and_bind3d(s, output, k, h, w, 1, 2, 2)
 
 
-##### REGISTER TOPI COMPUTE / SCHEDULE FOR WINOGRAD WITH WEIGHT TRANSFORM #####
-@autotvm.register_topi_compute(conv2d_winograd_without_weight_transform, 'bifrost', ['winograd'])
-def conv2d_winograd_ww(cfg, data, kernel, strides, padding, dilation, layout, out_dtype, tile_size):
-    """TOPI compute callback"""
-    return _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dtype)
 
+##### REGISTER ALTER OP LAYOUT #####
+@nn.conv2d_alter_layout.register("bifrost")
+def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
+    target = tvm.target.Target.current(allow_none=False)
+    dispatch_ctx = autotvm.task.DispatchContext.current
+
+    _, outs = relay.backend.compile_engine.select_implementation(
+        relay.op.get("nn.conv2d"), attrs, tinfos, out_type, target)
+    workload = autotvm.task.get_workload(outs)
+    if workload is None:
+        # The best implementation is not an AutoTVM template,
+        # we then assume it's not necessary to alter this op.
+        return None
+    cfg = dispatch_ctx.query(target, workload)
+    if cfg.is_fallback:  # if is fallback, clear query cache and return None
+        autotvm.task.clear_fallback_cache(target, workload)
+        return None
 
-@autotvm.register_topi_schedule(schedule_conv2d_winograd_without_weight_transform,
-                                'bifrost', ['winograd'])
-def schedule_conv2d_winograd_without_weight_transform_(cfg, outs):
-    """TOPI schedule callback"""
-    s = tvm.create_schedule([x.op for x in outs])
+    topi_tmpl = workload[0]
+    new_attrs = {k: attrs[k] for k in attrs.keys()}
 
-    def _callback(op):
-        if 'winograd_conv2d_output' in op.tag:
-            _schedule_winograd(cfg, s, op)
+    strides = attrs.get_int_tuple("strides")
+    padding = attrs.get_int_tuple("padding")
+    dilation = attrs.get_int_tuple("dilation")
+    data_layout = attrs["data_layout"]
+    kernel_layout = attrs["kernel_layout"]
+    data, kernel = tinfos
+    out_dtype = out_type.dtype
 
-    traverse_inline(s, outs[0].op, _callback)
-    return s
+    idxd = tvm.indexdiv
 
+    if topi_tmpl == "conv2d_nchw_spatial_pack.bifrost":
+        assert data_layout == "NCHW" and kernel_layout == "OIHW"
+        N, CI, H, W = get_const_tuple(data.shape)
+        CO, _, KH, KW = get_const_tuple(kernel.shape)
+        VC = cfg['tile_co'].size[-1]
 
-##### REGISTER ALTER OP LAYOUT #####
-@conv2d_alter_layout.register(["bifrost"])
-def _alter_conv2d_layout(attrs, inputs, tinfos, F):
-    try:
-        return _alter_conv2d_layout_arm(attrs, inputs, tinfos, F)
-    except KeyError:  # to filter out fallback opencl templates
-        return None
+        new_attrs['kernel_layout'] = 'OIHW%do' % VC
+
+        new_data = data
+        new_kernel = tvm.placeholder((idxd(CO, VC), CI, KH, KW, VC), dtype=kernel.dtype)
+        new_workload = autotvm.task.args_to_workload(
+            [new_data, new_kernel, strides, padding, dilation, out_dtype],
+            "conv2d_nchw_spatial_pack.bifrost")
+        dispatch_ctx.update(target, new_workload, cfg)
+
+        return relay.nn.conv2d(*inputs, **new_attrs)
+
+    if topi_tmpl == "conv2d_nchw_winograd.bifrost":
+        assert data_layout == "NCHW" and kernel_layout == "OIHW"
+        N, CI, H, W = get_const_tuple(data.shape)
+        CO, _, KH, KW = get_const_tuple(kernel.shape)
+        tile_size = 2
+
+        weight_expr = inputs[1]
+        weight_expr = relay.nn.contrib_conv2d_winograd_weight_transform(
+            weight_expr, tile_size=tile_size)
+        weight_expr = relay.reshape(
+            weight_expr, newshape=(KH + tile_size - 1, KW + tile_size - 1, CO, CI))
+
+        new_attrs['tile_size'] = tile_size
+
+        new_data = data
+        new_kernel = tvm.placeholder(
+            (KH + tile_size - 1, KW + tile_size -1, CO, CI), kernel.dtype)
+        new_workload = autotvm.task.args_to_workload(
+            [new_data, new_kernel, strides, padding, dilation, out_dtype],
+            'conv2d_nchw_winograd.bifrost')
+        dispatch_ctx.update(target, new_workload, cfg)
+
+        return relay.nn.contrib_conv2d_winograd_without_weight_transform(
+            inputs[0], weight_expr, **new_attrs)
+
+    return None
index 114168f..2a85db7 100644 (file)
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=invalid-name,unused-variable
-"""dense schedule on ARM Mali GPU"""
+"""dense schedule on ARM Mali Biforst GPU"""
 
 from __future__ import absolute_import as _abs
 
 import tvm
 from tvm import autotvm
 
-from .. import generic, nn
+from .. import nn
 from ..util import traverse_inline
 
-autotvm.register_topi_compute(nn.dense, 'bifrost', 'direct', nn.dense.fdefault)
+@autotvm.register_topi_compute('dense.biforst')
+def dense(_, data, weight, bias=None, out_dtype=None):
+    """Dense operator on Biforst"""
+    return nn.dense(data, weight, bias, out_dtype)
 
-@autotvm.register_topi_schedule(generic.schedule_dense, 'bifrost', 'direct')
+@autotvm.register_topi_schedule('dense.bifrost')
 def schedule_dense(cfg, outs):
     """Schedule for dense operator.
 
@@ -52,11 +55,11 @@ def schedule_dense(cfg, outs):
             vec_size = [1, 2, 4, 8, 16]
             max_unroll = 32
 
-            dense = op.output(0)
+            dense_out = op.output(0)
             output = outs[0]
 
             y, x = s[output].op.axis
-            c = s[dense].op.reduce_axis[0]
+            c = s[dense_out].op.reduce_axis[0]
 
             ##### space definition begin #####
             cfg.define_split('tile_y', y, num_outputs=3)
@@ -66,12 +69,12 @@ def schedule_dense(cfg, outs):
             # fallback support
             if cfg.is_fallback:
                 ref_log = autotvm.tophub.load_reference_log(
-                    'mali', 'rk3399', 'dense', 'direct')
+                    'mali', 'rk3399', 'dense.bifrost')
                 cfg.fallback_with_reference_log(ref_log)
             ##### space definition end #####
 
-            if dense.op in s.outputs:
-                dense = s.cache_write(output, 'local')
+            if dense_out.op in s.outputs:
+                dense_out = s.cache_write(output, 'local')
 
             by, ty, yi = cfg['tile_y'].apply(s, output, y)
             bx, tx, xi = cfg['tile_x'].apply(s, output, x)
@@ -85,17 +88,17 @@ def schedule_dense(cfg, outs):
                 s[output].unroll(yi)
             if cfg['tile_x'].size[-1] in vec_size:
                 s[output].vectorize(xi)
-            s[dense].compute_at(s[output], tx)
+            s[dense_out].compute_at(s[output], tx)
 
-            k = s[dense].op.reduce_axis[0]
-            y, x = s[dense].op.axis
-            k, k_unroll = cfg['c_unroll'].apply(s, dense, k)
-            s[dense].reorder(k, k_unroll, y, x)
-            s[dense].unroll(k_unroll)
+            k = s[dense_out].op.reduce_axis[0]
+            y, x = s[dense_out].op.axis
+            k, k_unroll = cfg['c_unroll'].apply(s, dense_out, k)
+            s[dense_out].reorder(k, k_unroll, y, x)
+            s[dense_out].unroll(k_unroll)
             if cfg['tile_y'].size[-1] < max_unroll:
-                s[dense].unroll(y)
+                s[dense_out].unroll(y)
             if cfg['tile_x'].size[-1] in vec_size:
-                s[dense].vectorize(x)
+                s[dense_out].vectorize(x)
 
     traverse_inline(s, outs[0].op, _callback)
     return s
index 305abee..4f7b0db 100644 (file)
 from __future__ import absolute_import as _abs
 import tvm
 
-from .. import generic
 from .. import util
 from .. import tag
 
-@generic.schedule_depthwise_conv2d_nchw.register(["bifrost"])
 def schedule_depthwise_conv2d_nchw(outs):
     """Schedule for depthwise_conv2d nchw forward.
 
index 4c20dd0..6e38318 100644 (file)
 """CUDA specific declaration and schedules."""
 from __future__ import absolute_import as _abs
 
-from . import conv1d, conv2d, depthwise_conv2d, conv2d_transpose_nchw, \
-              deformable_conv2d, group_conv2d_nchw, dense, conv1d_transpose_ncw
-from . import conv3d
-from .conv2d_hwcn import schedule_conv2d_hwcn
-from .depthwise_conv2d import schedule_depthwise_conv2d_backward_input_nhwc
-from .depthwise_conv2d import schedule_depthwise_conv2d_backward_weight_nhwc
-from .group_conv2d_nchw import schedule_conv2d_nchw_cuda
+from .conv1d import *
+from .conv1d_transpose_ncw import *
+from .conv2d import *
+from .conv2d_hwcn import *
+from .conv2d_int8 import *
+from .conv2d_winograd import *
+from .depthwise_conv2d import *
+from .group_conv2d_nchw import *
+from . import conv2d_alter_op
+from .conv2d_transpose_nchw import *
+from .deformable_conv2d import *
+from .conv3d import *
 from .reduction import schedule_reduce
 from .softmax import schedule_softmax
 from .injective import schedule_injective, schedule_elemwise, schedule_broadcast
-from .dense import schedule_dense
-from .pooling import schedule_pool, schedule_adaptive_pool
+from .dense import *
+from .pooling import *
 from .nn import schedule_lrn
-from .batch_matmul import schedule_batch_matmul
+from .batch_matmul import *
 from .vision import *
-from . import ssd
 from .ssd import *
-from .nms import *
+from .nms import get_valid_counts, non_max_suppression
 from .rcnn import *
 from .sort import *
index 24fc2a1..e293c7a 100644 (file)
 from __future__ import absolute_import as _abs
 import tvm
 from tvm.contrib import cublas
-from topi.nn import batch_matmul, batch_matmul_default
-from .. import generic
 from ..util import traverse_inline, get_const_tuple, get_max_power2_factor
 
-@batch_matmul.register(["cuda", "gpu"])
-def batch_matmul_cuda(x, y):
-    """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
-    data in batch.
-
-    Parameters
-    ----------
-    x : tvm.Tensor
-        3-D with shape [batch, M, K]
-
-    y : tvm.Tensor
-        3-D with shape [batch, N, K]
-
-    Returns
-    -------
-    output : tvm.Tensor
-        3-D with shape [batch, M, N]
-    """
-    target = tvm.target.Target.current()
-    if target.target_name == "cuda" and "cublas" in target.libs:
-        return cublas.batch_matmul(x, y, False, True)
-    return batch_matmul_default(x, y)
-
-@generic.schedule_batch_matmul.register(["cuda", "gpu"])
 def schedule_batch_matmul(outs):
     """Schedule for batch_matmul
 
@@ -61,10 +35,6 @@ def schedule_batch_matmul(outs):
     s: Schedule
         The computation schedule for the op.
     """
-    target = tvm.target.Target.current()
-    if target.target_name == "cuda" and "cublas" in target.libs:
-        return generic.schedule_extern(outs)
-
     outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
     s = tvm.create_schedule([x.op for x in outs])
 
@@ -134,3 +104,22 @@ def schedule_batch_matmul(outs):
 
     traverse_inline(s, outs[0].op, _callback)
     return s
+
+def batch_matmul_cublas(x, y):
+    """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
+    data in batch.
+
+    Parameters
+    ----------
+    x : tvm.Tensor
+        3-D with shape [batch, M, K]
+
+    y : tvm.Tensor
+        3-D with shape [batch, N, K]
+
+    Returns
+    -------
+    output : tvm.Tensor
+        3-D with shape [batch, M, N]
+    """
+    return cublas.batch_matmul(x, y, False, True)
index 43754a3..56918e2 100644 (file)
 import tvm
 from tvm import autotvm
 
-from .. import nn, generic
+from .. import nn
 from ..util import traverse_inline, get_const_tuple
 
 
-@autotvm.register_topi_compute(nn.conv1d, ['cuda', 'gpu'], ['direct'])
-def conv1d_cuda(cfg,
-                data,
-                kernel,
-                strides,
-                padding,
-                dilation,
-                layout='NCW',
-                out_dtype='float32'):
-    """ 1D convolution forward operator for cuda backend.
+@autotvm.register_topi_compute("conv1d_ncw.cuda")
+def conv1d_ncw(cfg,
+               data,
+               kernel,
+               strides,
+               padding,
+               dilation,
+               out_dtype='float32'):
+    return nn.conv1d_ncw(data, kernel, strides, padding, dilation, out_dtype)
 
-    Parameters
-    ----------
-    cfg : ConfigEntity
-        The config for this template
-
-    data : tvm.Tensor
-        3-D input shape [batch, in_channel, in_width] for layout == 'NCW'
-        and [batch, in_width, in_channel] for layout == 'NWC'
-
-    kernel : tvm.Tensor
-        3-D kernel with shape [num_filter, in_channel, filter_size] for layout == 'NCW'
-        and [filter_size, in_channel, num_filter] for layout == 'NWC'
-
-    strides : int or tuple
-        The spatial stride along width
 
-    padding : int or str
-        Padding size, or ['VALID', 'SAME']
-
-    dilation : int or tuple
-        Dilation rate if convolution should be dilated.
-
-    layout : str
-        How input data is laid out, must be one of ['NCW', 'NWC']
-
-    out_dtype : str
-        The output data type. If None then output is same type as input.
-    """
-    if out_dtype is None:
-        out_dtype = data.dtype
-    if isinstance(strides, (tuple, list)):
-        strides = strides[0]
-    if isinstance(dilation, (tuple, list)):
-        dilation = dilation[0]
-
-    if layout == 'NCW':
-        return nn.conv1d_ncw(data, kernel, strides, padding, dilation,
-                             out_dtype)
-    if layout == 'NWC':
-        return nn.conv1d_nwc(data, kernel, strides, padding, dilation,
-                             out_dtype)
-    raise ValueError("This layout is not yet supported: {}".format(layout))
-
-
-@autotvm.register_topi_schedule(generic.schedule_conv1d_ncw, ["cuda", "gpu"],
-                                ["direct"])
+@autotvm.register_topi_schedule("conv1d_ncw.cuda")
 def schedule_conv1d_ncw(cfg, outs):
     """TOPI schedule callback of conv1d ncw for cuda gpu
 
@@ -193,8 +148,18 @@ def schedule_conv1d_ncw(cfg, outs):
     return s
 
 
-@autotvm.register_topi_schedule(generic.schedule_conv1d_nwc, ["cuda", "gpu"],
-                                ["direct"])
+@autotvm.register_topi_compute("conv1d_nwc.cuda")
+def conv1d_nwc(cfg,
+               data,
+               kernel,
+               strides,
+               padding,
+               dilation,
+               out_dtype='float32'):
+    return nn.conv1d_nwc(data, kernel, strides, padding, dilation, out_dtype)
+
+
+@autotvm.register_topi_schedule("conv1d_nwc.cuda")
 def schedule_conv1d_nwc(cfg, outs):
     """TOPI schedule callback of conv1d nwc for cuda gpu
 
index 4cedbd5..4802a0d 100644 (file)
 
 import tvm
 from tvm import autotvm
-from .. import nn, generic
+from .. import nn
 from ..util import get_const_tuple, traverse_inline
 
-@autotvm.task.register_topi_compute(nn.conv1d_transpose_ncw, ['cuda', 'gpu'], "direct")
-def conv1d_transpose_ncw_cuda(cfg, data, kernel, stride, padding, out_dtype):
+@autotvm.task.register_topi_compute("conv1d_transpose_nchw.cuda")
+def conv1d_transpose_ncw(cfg, data, kernel, stride, padding, out_dtype):
     """Transposed 1D convolution ncw forward operator.
 
     Parameters
@@ -79,9 +79,8 @@ def conv1d_transpose_ncw_cuda(cfg, data, kernel, stride, padding, out_dtype):
 
     return data_out
 
-@autotvm.task.register_topi_schedule(generic.schedule_conv1d_transpose_ncw,
-                                     ['cuda', 'gpu'], 'direct')
-def schedule_conv1d_transpose_ncw_cuda(cfg, outs):
+@autotvm.task.register_topi_schedule("conv1d_transpose_nchw.cuda")
+def schedule_conv1d_transpose_ncw(cfg, outs):
     """TOPI Schedule callback for conv1d_transpose operator.
 
     Parameters
index f26069c..e1ada32 100644 (file)
@@ -14,7 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=invalid-name
+# pylint: disable=invalid-name, unused-argument
 """Compute definition for conv2d with cuda backend"""
 import tvm
 from tvm import autotvm
@@ -23,179 +23,95 @@ from tvm.contrib import cudnn
 from .. import nn, generic
 from ..nn.util import get_pad_tuple
 from ..util import get_const_tuple, traverse_inline
-
 from .conv2d_direct import schedule_direct_cuda
-from .conv2d_winograd import winograd_cuda, schedule_winograd_cuda
-from .conv2d_int8 import conv2d_NCHWc_int8, schedule_conv2d_NCHWc_int8
-
-
-@autotvm.register_topi_compute(nn.conv2d, ['cuda', 'gpu'], ['direct', 'winograd', 'int8'])
-def conv2d_cuda(cfg, data, kernel, strides, padding, dilation, layout='NCHW', out_dtype='float32'):
-    """Conv2D operator for cuda backend.
-
-    Parameters
-    ----------
-    cfg: ConfigEntity
-        The config for this template
-
-    data : tvm.Tensor
-        4-D with shape [batch, in_channel, in_height, in_width] or
-        5-D with shape [batch, ic_chunk, in_height, in_width, ic_block]
-
-    kernel : tvm.Tensor
-        4-D with shape [num_filter, in_channel, filter_height, filter_width] or
-        6-D with shape [num_filter_chunk, in_channel_chunk, filter_height,
-        filter_width, num_filter_block, in_channel_block]
-
-    strides : int or a list/tuple of two ints
-        stride size, or [stride_height, stride_width]
-
-    padding : int or a list/tuple of 2 or 4 ints
-        padding size, or
-        [pad_height, pad_width] for 2 ints, or
-        [pad_top, pad_left, pad_bottom, pad_right] for 4 ints
 
-    dilation: int or a list/tuple of two ints
-        dilation size, or [dilation_height, dilation_width]
 
-    layout : str
-        layout of data
+@autotvm.register_topi_compute("conv2d_nchw.cuda")
+def conv2d_nchw(cfg, data, kernel, strides, padding, dilation, out_dtype='float32'):
+    """Compute conv2d with NCHW layout"""
+    return nn.conv2d_nchw(data, kernel, strides, padding, dilation, out_dtype)
 
-    out_dtype: str
-        The output type. This is used for mixed precision.
-
-    Returns
-    -------
-    output : tvm.Tensor
-        4-D with shape [batch, out_channel, out_height, out_width]
-    """
-    target = tvm.target.Target.current()
-
-    if "cudnn" in target.libs:
-        if layout == 'NCHW':
-            tensor_format = 0 # CUDNN_TENSOR_NCHW
-            N, _, H, W = get_const_tuple(data.shape)
-        elif layout == 'NHWC':
-            tensor_format = 1 # CUDNN_TENSOR_NHWC
-            N, H, W, _ = get_const_tuple(data.shape)
-        else:
-            raise ValueError("Unsupported layout %s in cudnn" % layout)
-        CO, CI, KH, KW = get_const_tuple(kernel.shape)
-
-        # handle dilation
-        stride_h, stride_w = (strides, strides) if isinstance(strides, int) else strides
-        dilation_h, dilation_w = (dilation, dilation) if isinstance(dilation, int) else dilation
-
-        if isinstance(padding, (list, tuple)) and len(padding) == 4 and \
-           (padding[0] != padding[2] or padding[1] != padding[3]):
-            raise ValueError("Cudnn doesn't support asymmetric padding.")
-        pt, pl, pb, pr = get_pad_tuple(padding, (KH, KW))
-        OH = (H + pt + pb - KH) // stride_h + 1
-        OW = (W + pl + pr - KW) // stride_w + 1
-        cfg.add_flop(2 * N * OH * OW * CO * CI * ((KH - 1) * dilation_h + 1) *\
-                    ((KW - 1) * dilation_w + 1))
-
-        if data.dtype == "int8" or kernel.dtype == "int8":
-            if layout == 'NCHW':
-                raise ValueError("NCHW layout do not support int8 in cudnn")
-            dtype = "int32"
-        else:
-            dtype = data.dtype
-
-        return cudnn.conv_forward(data,
-                                  kernel,
-                                  [pt, pl], # cudnn padding pt, pl on both sides of input
-                                  [stride_h, stride_w],
-                                  [dilation_h, dilation_w],
-                                  conv_mode=1,
-                                  tensor_format=tensor_format,
-                                  algo=-1,         # let CUDNN choose the best algo
-                                  conv_dtype=dtype)
-
-    if cfg.template_key == 'winograd':
-        return winograd_cuda(cfg, data, kernel, strides, padding, dilation, layout, out_dtype,
-                             pre_computed=False)
-    if cfg.template_key == 'int8':
-        if (data.dtype == 'int8' or data.dtype == 'uint8'):
-            return conv2d_NCHWc_int8(
-                cfg, data, kernel, strides, padding, dilation, layout, out_dtype)
-
-    if layout == 'NCHW':
-        return nn.conv2d_nchw(data, kernel, strides, padding, dilation, out_dtype)
-    if layout == 'HWCN':
-        return nn.conv2d_hwcn(data, kernel, strides, padding, dilation, out_dtype)
-    if layout == 'NHWC':
-        return nn.conv2d_nhwc(data, kernel, strides, padding, dilation, out_dtype)
-    raise ValueError("not support this layout {} yet".format(layout))
-
-
-@autotvm.register_topi_schedule(generic.schedule_conv2d_nchw, ["cuda", "gpu"],
-                                ["direct", 'winograd', "int8"])
-def schedule_conv2d_nchw_cuda(cfg, outs):
-    """TOPI schedule callback of conv2d for cuda gpu
-
-    Parameters
-    ----------
-    cfg: ConfigEntity
-        The config for this template
-
-    outs: Array of Tensor
-        The computation graph description of conv2d
-        in the format of an array of tensors.
-
-    Returns
-    -------
-    s: Schedule
-        The computation schedule for conv2d.
-    """
-    target = tvm.target.Target.current()
-    if 'cudnn' in target.libs:
-        return generic.schedule_extern(outs)
 
+@autotvm.register_topi_schedule("conv2d_nchw.cuda")
+def schedule_conv2d_nchw(cfg, outs):
+    """Create the schedule for conv2d_nchw"""
     outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
     s = tvm.create_schedule([x.op for x in outs])
 
     def _callback(op):
         if op.tag == 'conv2d_nchw':
             schedule_direct_cuda(cfg, s, op.output(0))
-        if op.tag == 'conv2d_nchw_winograd':
-            schedule_winograd_cuda(cfg, s, op.output(0), pre_computed=False)
-        if op.tag == "conv2d_NCHWc_int8":
-            schedule_conv2d_NCHWc_int8(cfg, s, op.output(0))
 
     traverse_inline(s, outs[0].op, _callback)
     return s
 
 
-@autotvm.register_topi_schedule(generic.schedule_conv2d_nhwc, ["cuda", "gpu"],
-                                ["direct"])
-def schedule_conv2d_nhwc_cuda(cfg, outs):
-    """TOPI schedule for CUDA conv2d_nhwc
-
-    Parameters
-    ----------
-    cfg: ConfigEntity
-        The config for this template
-
-    outs: Array of Tensor
-        The computation graph description of conv2d
-        in the format of an array of tensors.
-
-    Returns
-    -------
-    s: Schedule
-        The computation schedule for conv2d.
-    """
-    target = tvm.target.Target.current()
-    if 'cudnn' in target.libs:
-        return generic.schedule_extern(outs)
-
-    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
-    s = tvm.create_schedule([x.op for x in outs])
+# TODO(@alexgl-github): It's invalid to call schedule_direct_cuda for NHWC layout
+#  as it assumes the input layout to be NCHW. Please fix this.
+# @autotvm.register_topi_compute("conv2d_nhwc.cuda")
+# def conv2d_nhwc(cfg, data, kernel, strides, padding, dilation, out_dtype='float32'):
+#     return nn.conv2d_nhwc(data, kernel, strides, padding, dilation, out_dtype)
+#
+#
+# @autotvm.register_topi_schedule("conv2d_nhwc.cuda")
+# def schedule_conv2d_nhwc(cfg, outs):
+#     outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
+#     s = tvm.create_schedule([x.op for x in outs])
+#
+#     def _callback(op):
+#         if op.tag == 'conv2d_nhwc':
+#             schedule_direct_cuda(cfg, s, op.output(0))
+#
+#     traverse_inline(s, outs[0].op, _callback)
+#     return s
 
-    def _callback(op):
-        if op.tag == 'conv2d_nhwc':
-            schedule_direct_cuda(cfg, s, op.output(0))
 
-    traverse_inline(s, outs[0].op, _callback)
-    return s
+@autotvm.register_topi_compute("conv2d_cudnn.cuda")
+def conv2d_cudnn(cfg, data, kernel, strides, padding, dilation, layout='NCHW',
+                 out_dtype='float32'):
+    """Compute conv2d using CuDNN library"""
+    if layout == 'NCHW':
+        tensor_format = 0 # CUDNN_TENSOR_NCHW
+        N, _, H, W = get_const_tuple(data.shape)
+    elif layout == 'NHWC':
+        tensor_format = 1 # CUDNN_TENSOR_NHWC
+        N, H, W, _ = get_const_tuple(data.shape)
+    else:
+        raise ValueError("Unsupported layout %s in cudnn" % layout)
+    CO, CI, KH, KW = get_const_tuple(kernel.shape)
+
+    # handle dilation
+    stride_h, stride_w = (strides, strides) if isinstance(strides, int) else strides
+    dilation_h, dilation_w = (dilation, dilation) if isinstance(dilation, int) else dilation
+
+    if isinstance(padding, (list, tuple)) and len(padding) == 4 and \
+            (padding[0] != padding[2] or padding[1] != padding[3]):
+        raise ValueError("Cudnn doesn't support asymmetric padding.")
+    pt, pl, pb, pr = get_pad_tuple(padding, (KH, KW))
+    OH = (H + pt + pb - KH) // stride_h + 1
+    OW = (W + pl + pr - KW) // stride_w + 1
+    cfg.add_flop(2 * N * OH * OW * CO * CI * ((KH - 1) * dilation_h + 1) * \
+                 ((KW - 1) * dilation_w + 1))
+
+    if data.dtype == "int8" or kernel.dtype == "int8":
+        if layout == 'NCHW':
+            raise ValueError("NCHW layout do not support int8 in cudnn")
+        dtype = "int32"
+    else:
+        dtype = data.dtype
+
+    return cudnn.conv_forward(data,
+                              kernel,
+                              [pt, pl], # cudnn padding pt, pl on both sides of input
+                              [stride_h, stride_w],
+                              [dilation_h, dilation_w],
+                              conv_mode=1,
+                              tensor_format=tensor_format,
+                              algo=-1,         # let CUDNN choose the best algo
+                              conv_dtype=dtype)
+
+
+@autotvm.register_topi_schedule("conv2d_cudnn.cuda")
+def schedule_conv2d_cudnn(cfg, outs):
+    """Create the schedule for conv2d_cudnn"""
+    return generic.schedule_extern(outs)
diff --git a/topi/python/topi/cuda/conv2d_alter_op.py b/topi/python/topi/cuda/conv2d_alter_op.py
new file mode 100644 (file)
index 0000000..f3e4f4c
--- /dev/null
@@ -0,0 +1,136 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,unused-variable,unused-argument
+"""Conv2D alter op and legalize functions for cuda backend"""
+
+import logging
+import tvm
+from tvm import relay
+from tvm import autotvm
+
+from .. import nn
+from ..util import get_const_tuple
+from .conv2d_winograd import _infer_tile_size
+
+logger = logging.getLogger('topi')
+
+@nn.conv2d_alter_layout.register(["cuda", "gpu"])
+def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
+    target = tvm.target.Target.current(allow_none=False)
+    dispatch_ctx = autotvm.task.DispatchContext.current
+
+    _, outs = relay.backend.compile_engine.select_implementation(
+        relay.op.get("nn.conv2d"), attrs, tinfos, out_type, target)
+    workload = autotvm.task.get_workload(outs)
+    if workload is None:
+        # The best implementation is not an AutoTVM template,
+        # we then assume it's not necessary to alter this op.
+        return None
+    cfg = dispatch_ctx.query(target, workload)
+    if cfg.is_fallback:  # if is fallback, clear query cache and return None
+        autotvm.task.clear_fallback_cache(target, workload)
+        return None
+
+    topi_tmpl = workload[0]
+    new_attrs = {k: attrs[k] for k in attrs.keys()}
+
+    strides = attrs.get_int_tuple("strides")
+    padding = attrs.get_int_tuple("padding")
+    dilation = attrs.get_int_tuple("dilation")
+    groups = attrs.get_int('groups')
+    data_layout = attrs["data_layout"]
+    kernel_layout = attrs["kernel_layout"]
+    data, kernel = tinfos
+    out_dtype = out_type.dtype
+
+    if topi_tmpl == "conv2d_NCHWc_int8.cuda":
+        assert data_layout == "NCHW" and kernel_layout == "OIHW"
+        N, CI, H, W = get_const_tuple(data.shape)
+        CO, _, KH, KW = get_const_tuple(kernel.shape)
+
+        new_layout = 'NCHW4c'
+        new_attrs["channels"] = CO
+        new_attrs["data_layout"] = new_layout
+        new_attrs['out_layout'] = new_layout
+        new_attrs['kernel_layout'] = 'OIHW4o4i'
+        ic_block_factor = oc_block_factor = 4
+
+        # Store the same config for the altered operator (workload)
+        new_data = tvm.placeholder((N, CI // ic_block_factor, H, W, ic_block_factor),
+                                   dtype=data.dtype)
+        new_kernel = tvm.placeholder((CO // oc_block_factor, CI // ic_block_factor, KH, KW, \
+                                      oc_block_factor, ic_block_factor), dtype=kernel.dtype)
+        new_workload = autotvm.task.args_to_workload(
+            [new_data, new_kernel, strides, padding, dilation, new_layout, out_dtype],
+            "conv2d_NCHWc_int8.cuda")
+        dispatch_ctx.update(target, new_workload, cfg)
+        return relay.nn.conv2d(*inputs, **new_attrs)
+
+    if topi_tmpl == "conv2d_nchw_winograd.cuda":
+        if dilation != (1, 1):
+            logger.warning("Does not support weight pre-transform for dilated convolution.")
+            return None
+
+        assert data_layout == "NCHW" and kernel_layout == "OIHW"
+        N, CI, H, W = get_const_tuple(data.shape)
+        CO, _, KH, KW = get_const_tuple(kernel.shape)
+
+        # pre-compute weight transformation in winograd
+        tile_size = _infer_tile_size(tinfos[0], tinfos[1])
+
+        weight = relay.nn.contrib_conv2d_winograd_weight_transform(inputs[1],
+                                                                   tile_size=tile_size)
+        weight = relay.transpose(weight, axes=[0, 1, 3, 2])
+        new_attrs['tile_size'] = tile_size
+        new_attrs['channels'] = CO
+
+        # Store the same config for the altered operator (workload)
+        new_data = data
+        new_weight = tvm.placeholder((KH + tile_size - 1, KW + tile_size - 1, CI, CO),
+                                     dtype=kernel.dtype)
+        new_workload = autotvm.task.args_to_workload(
+            [new_data, new_weight, strides, padding, dilation, out_dtype],
+            "conv2d_nchw_winograd_without_weight_transform.cuda")
+        dispatch_ctx.update(target, new_workload, cfg)
+        return relay.nn.contrib_conv2d_winograd_without_weight_transform(
+            inputs[0], weight, **new_attrs)
+
+    if topi_tmpl == "group_conv2d_NCHWc_int8.cuda":
+        assert data_layout == "NCHW" and kernel_layout == "OIHW"
+        N, CI, H, W = get_const_tuple(data.shape)
+        CO, _, KH, KW = get_const_tuple(kernel.shape)
+
+        new_layout = 'NCHW4c'
+        new_attrs["channels"] = CO
+        new_attrs["data_layout"] = new_layout
+        new_attrs['out_layout'] = new_layout
+        new_attrs['kernel_layout'] = 'OIHW4o4i'
+        ic_block_factor = oc_block_factor = 4
+
+        # Store the same config for the altered operator (workload)
+        new_data = tvm.placeholder((N, CI // ic_block_factor, H, W, ic_block_factor),
+                                   dtype=data.dtype)
+        new_kernel = tvm.placeholder((CO // oc_block_factor, CI // ic_block_factor // groups,
+                                      KH, KW, oc_block_factor, ic_block_factor),
+                                     dtype=kernel.dtype)
+        new_workload = autotvm.task.args_to_workload(
+            [new_data, new_kernel, strides, padding, dilation, groups, out_dtype],
+            "group_conv2d_NCHWc_int8.cuda")
+        dispatch_ctx.update(target, new_workload, cfg)
+        return relay.nn.conv2d(*inputs, **new_attrs)
+
+    return None
index b7df885..2fab8cf 100644 (file)
@@ -43,7 +43,7 @@ def schedule_direct_cuda(cfg, s, conv):
     # fallback support
     if cfg.is_fallback:
         ref_log = autotvm.tophub.load_reference_log(
-            target.target_name, target.model, 'conv2d', 'direct')
+            target.target_name, target.model, 'conv2d_nchw.cuda')
         cfg.fallback_with_reference_log(ref_log)
     ##### space definition end #####
 
index 18a624a..b0925ae 100644 (file)
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=invalid-name, too-many-locals, too-many-statements
+# pylint: disable=invalid-name, too-many-locals, too-many-statements, unused-argument
 """Schedule for conv2d_hwcn with auto fusion"""
 import tvm
 from tvm import autotvm
+
 from tvm.autotvm.task.space import SplitEntity
 
-from .. import generic, tag
+from .. import nn, tag
+
+@autotvm.register_topi_compute("conv2d_hwcn.cuda")
+def conv2d_hwcn(cfg, data, kernel, strides, padding, dilation, out_dtype='float32'):
+    """Compute conv2d with HWCN layout on CUDA"""
+    return nn.conv2d_hwcn(data, kernel, strides, padding, dilation, out_dtype)
 
 
-@autotvm.register_topi_schedule(generic.schedule_conv2d_hwcn, ["cuda", "gpu"], ["direct"])
+@autotvm.register_topi_schedule("conv2d_hwcn.cuda")
 def schedule_conv2d_hwcn(cfg, outs):
     """Schedule for conv2d_hwcn and any element-wise operations.
 
index 580cf96..53a7bd9 100644 (file)
@@ -23,9 +23,10 @@ from .injective import schedule_injective_from_existing
 from .tensor_intrin import dp4a
 from ..nn.pad import pad
 from ..nn.util import get_pad_tuple
-from ..util import get_const_tuple
+from ..util import get_const_tuple, traverse_inline
 
 
+@autotvm.register_topi_compute("conv2d_NCHWc_int8.cuda")
 def conv2d_NCHWc_int8(cfg, data, kernel, stride, padding, dilation, layout, out_dtype):
     """Convolution operator in NCHW[x]c layout for int8.
 
@@ -152,8 +153,21 @@ def conv2d_NCHWc_int8(cfg, data, kernel, stride, padding, dilation, layout, out_
 _dp4a = dp4a('shared', 'shared', 'local')
 
 
-def schedule_conv2d_NCHWc_int8(cfg, s, output):
+@autotvm.register_topi_schedule("conv2d_NCHWc_int8.cuda")
+def schedule_conv2d_NCHWc_int8(cfg, outs):
     """Schedule conv2d int8 NCHWc template"""
+    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
+    s = tvm.create_schedule([x.op for x in outs])
+
+    def _callback(op):
+        if op.tag == 'conv2d_NCHWc_int8':
+            _schedule_conv2d_NCHWc_int8(cfg, s, op.output(0))
+
+    traverse_inline(s, outs[0].op, _callback)
+    return s
+
+
+def _schedule_conv2d_NCHWc_int8(cfg, s, output):
     conv = output.op.input_tensors[0]
     packed_data, packed_kernel = conv.op.input_tensors
 
index be9f315..8751800 100644 (file)
 import tvm
 from tvm import autotvm
 from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity
-from .. import nn, generic
+from .. import nn
 from ..util import get_const_tuple, traverse_inline
 
 
-@autotvm.task.register_topi_compute(nn.conv2d_transpose_nchw, ['cuda', 'gpu'], "direct")
-def conv2d_transpose_nchw_cuda(cfg, data, kernel, stride, padding, out_dtype):
+@autotvm.register_topi_compute("conv2d_transpose_nchw.cuda")
+def conv2d_transpose_nchw(cfg, data, kernel, stride, padding, out_dtype):
     """Transposed 2D convolution nchw forward operator.
 
     Parameters
@@ -101,9 +101,8 @@ def conv2d_transpose_nchw_cuda(cfg, data, kernel, stride, padding, out_dtype):
 
     return data_out
 
-@autotvm.task.register_topi_schedule(generic.schedule_conv2d_transpose_nchw,
-                                     ['cuda', 'gpu'], 'direct')
-def schedule_conv2d_transpose_nchw_cuda(cfg, outs):
+@autotvm.register_topi_schedule("conv2d_transpose_nchw.cuda")
+def schedule_conv2d_transpose_nchw(cfg, outs):
     """TOPI Schedule callback for conv2d transpose operator.
 
     Parameters
index 37307d6..6e09be9 100644 (file)
@@ -22,9 +22,7 @@ import tvm
 from tvm import autotvm
 
 from .. import nn
-from ..nn import conv2d, group_conv2d_nchw, conv2d_winograd_without_weight_transform
 from ..util import get_const_int, get_const_tuple, traverse_inline
-from ..generic import schedule_conv2d_winograd_without_weight_transform
 from ..nn.winograd_util import winograd_transform_matrices
 
 
@@ -37,10 +35,9 @@ def _infer_tile_size(data, kernel):
         return 4
     return 2
 
-def winograd_cuda(cfg, data, kernel, strides, padding, dilation, layout, out_dtype, pre_computed):
+def winograd_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype,
+                  pre_computed):
     """Compute declaration for winograd"""
-    assert layout == 'NCHW'
-
     tile_size = _infer_tile_size(data, kernel)
 
     N, CI, H, W = get_const_tuple(data.shape)
@@ -53,7 +50,7 @@ def winograd_cuda(cfg, data, kernel, strides, padding, dilation, layout, out_dty
 
     if not pre_computed: # kernel tensor is raw tensor, do strict check
         if dilation_h != 1 or dilation_w != 1:
-            kernel = dilation(kernel, (1, 1, dilation_h, dilation_w))
+            kernel = nn.dilate(kernel, (1, 1, dilation_h, dilation_w))
         CO, CI, KH, KW = get_const_tuple(kernel.shape)
         alpha = KW + tile_size - 1
         assert HSTR == 1 and WSTR == 1 and KH == KW
@@ -282,161 +279,38 @@ def schedule_winograd_cuda(cfg, s, output, pre_computed):
 
     return s
 
-##### REGISTER TOPI COMPUTE / SCHEDULE FOR WINOGRAD WITH WEIGHT TRANSFORM #####
-@autotvm.register_topi_compute(conv2d_winograd_without_weight_transform,
-                               ['cuda', 'gpu'], ['winograd'])
-def conv2d_winograd_ww(cfg, data, kernel, strides, padding, dilation, layout, out_dtype, tile_size):
-    return winograd_cuda(cfg, data, kernel, strides, padding, dilation, layout, out_dtype,
-                         pre_computed=True)
-
+@autotvm.register_topi_compute("conv2d_nchw_winograd.cuda")
+def conv2d_nchw_winograd(cfg, data, kernel, strides, padding, dilation, out_dtype):
+    return winograd_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype,
+                         pre_computed=False)
 
-@autotvm.register_topi_schedule(schedule_conv2d_winograd_without_weight_transform,
-                                ['cuda', 'gpu'], ['winograd'])
-def schedule_conv2d_winograd_without_weight_transform_cuda(cfg, outs):
-    """TOPI schedule callback"""
+@autotvm.register_topi_schedule("conv2d_nchw_winograd.cuda")
+def schedule_conv2d_nchw_winograd(cfg, outs):
     s = tvm.create_schedule([x.op for x in outs])
 
     def _callback(op):
         if 'conv2d_nchw_winograd' in op.tag:
-            schedule_winograd_cuda(cfg, s, op.output(0), pre_computed=True)
+            schedule_winograd_cuda(cfg, s, op.output(0), pre_computed=False)
 
     traverse_inline(s, outs[0].op, _callback)
     return s
 
 
-##### REGISTER ALTER OP LAYOUT #####
-@nn.conv2d_alter_layout.register(["cuda", "gpu"])
-def _alter_conv2d_layout(attrs, inputs, tinfos, F):
-    """Alter op layout for pre-computing kernel transformation
-
-    Parameters
-    ----------
-    attrs : tvm.ir.Attrs
-        Attributes of current convolution
-    inputs : tvm.relay.Expr
-        Grouped input symbols
-    tinfos : list
-        Input shape and dtype
-    F: symbol
-        The context, can be relay.op
-
-    Note
-    ----
-    Unlike other TOPI functions, this function operates on both graph level and operator level,
-    so we have to pass 'F' to make it support our two versions of graph IR,  Relay.
-    """
-    if 'cudnn' in tvm.target.Target.current().libs or 'miopen' in tvm.target.Target.current().libs:
-        return None
-
-    copy_inputs = list(inputs)
-    new_attrs = {k: attrs[k] for k in attrs.keys()}
-
-
-    new_attrs["channels"] = inputs[1].checked_type.shape[attrs['kernel_layout'].index('O')]
-
-    strides = attrs.get_int_tuple("strides")
-    padding = attrs.get_int_tuple("padding")
-    dilation = attrs.get_int_tuple("dilation")
-    groups = attrs.get_int('groups')
-    data_layout_key = "data_layout" if "data_layout" in new_attrs else "layout"
-    layout = attrs[data_layout_key]
-    out_dtype = attrs["out_dtype"]
-    if out_dtype in ("", "same"):
-        out_dtype = tinfos[0].dtype
-
-    data, kernel = tinfos[0:2]
-    N, CI, H, W = get_const_tuple(data.shape)
-    CO, _, KH, KW = get_const_tuple(kernel.shape)
+@autotvm.register_topi_compute("conv2d_nchw_winograd_without_weight_transform.cuda")
+def conv2d_nchw_winograd_without_weight_transform(cfg, data, kernel, strides,
+                                                  padding, dilation, out_dtype):
+    return winograd_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype,
+                         pre_computed=True)
 
-    dispatch_ctx = autotvm.DispatchContext.current
-    target = tvm.target.Target.current()
 
-    if groups == 1:
-        # query config of this workload
-        workload = autotvm.task.args_to_workload(
-            [tinfos[0], tinfos[1], strides, padding, dilation, layout, out_dtype], conv2d)
-        cfg = autotvm.DispatchContext.current.query(target, workload)
-
-        if cfg.is_fallback:  # if is fallback, clear query cache and return None
-            autotvm.task.clear_fallback_cache(target, workload)
-            return None
-
-        if cfg.template_key == 'direct':
-            return None
-
-        if cfg.template_key == 'int8':
-            assert 'cuda' in target.keys
-            new_layout = 'NCHW4c'
-            new_attrs[data_layout_key] = new_layout
-            new_attrs['out_layout'] = new_layout
-            new_attrs['kernel_layout'] = 'OIHW4o4i'
-            ic_block_factor = oc_block_factor = 4
-
-            # Store the same config for the altered operator (workload)
-            new_data = tvm.placeholder((N, CI // ic_block_factor, H, W, ic_block_factor),
-                                       dtype=data.dtype)
-            new_kernel = tvm.placeholder((CO // oc_block_factor, CI // ic_block_factor, KH, KW,\
-                                         oc_block_factor, ic_block_factor), dtype=kernel.dtype)
-            new_workload = autotvm.task.args_to_workload(
-                [new_data, new_kernel, strides, padding, dilation, new_layout, out_dtype],
-                conv2d
-            )
-            dispatch_ctx.update(target, new_workload, cfg)
-            return F.nn.conv2d(*copy_inputs, **new_attrs)
-
-        if attrs.get_int_tuple("dilation") != (1, 1):
-            logger.warning("Does not support weight pre-transform for dilated convolution.")
-            return None
-
-        # pre-compute weight transformation in winograd
-        tile_size = _infer_tile_size(tinfos[0], tinfos[1])
-
-        weight = F.nn.contrib_conv2d_winograd_weight_transform(copy_inputs[1],
-                                                               tile_size=tile_size)
-        weight = F.transpose(weight, axes=[0, 1, 3, 2])
-        copy_inputs[1] = weight
-        new_attrs['tile_size'] = tile_size
-
-        # Store the same config for the altered operator (workload)
-        new_data = data
-        new_weight = tvm.placeholder((KH + tile_size - 1, KW + tile_size - 1, CI, CO),
-                                     dtype=kernel.dtype)
-        new_workload = autotvm.task.args_to_workload(
-            [new_data, new_weight, strides, padding, dilation, layout, out_dtype, tile_size],
-            conv2d_winograd_without_weight_transform
-        )
-        dispatch_ctx.update(target, new_workload, cfg)
-        return F.nn.contrib_conv2d_winograd_without_weight_transform(*copy_inputs, **new_attrs)
-    if groups != CI:
-        workload = autotvm.task.args_to_workload(
-            [tinfos[0], tinfos[1], strides, padding, dilation, groups, out_dtype],
-            group_conv2d_nchw)
-        cfg = autotvm.DispatchContext.current.query(target, workload)
-
-        if cfg.is_fallback:  # if is fallback, clear query cache and return None
-            autotvm.task.clear_fallback_cache(target, workload)
-            return None
-
-        if cfg.template_key == 'int8':
-            assert 'cuda' in target.keys
-            new_layout = 'NCHW4c'
-            new_attrs[data_layout_key] = new_layout
-            new_attrs['out_layout'] = new_layout
-            new_attrs['kernel_layout'] = 'OIHW4o4i'
-            ic_block_factor = oc_block_factor = 4
-
-            # Store the same config for the altered operator (workload)
-            new_data = tvm.placeholder((N, CI // ic_block_factor, H, W, ic_block_factor),
-                                       dtype=data.dtype)
-            new_kernel = tvm.placeholder((CO // oc_block_factor, CI // ic_block_factor // groups,\
-                                         KH, KW, oc_block_factor, ic_block_factor),
-                                         dtype=kernel.dtype)
-            new_workload = autotvm.task.args_to_workload(
-                [new_data, new_kernel, strides, padding, dilation, groups, out_dtype],
-                group_conv2d_nchw
-            )
-            dispatch_ctx.update(target, new_workload, cfg)
-            return F.nn.conv2d(*copy_inputs, **new_attrs)
-
-    # do nothing for depthwise convolution
-    return None
+@autotvm.register_topi_schedule("conv2d_nchw_winograd_without_weight_transform.cuda")
+def schedule_conv2d_nchw_winograd_without_weight_transform(cfg, outs):
+    """TOPI schedule callback"""
+    s = tvm.create_schedule([x.op for x in outs])
+
+    def _callback(op):
+        if 'conv2d_nchw_winograd' in op.tag:
+            schedule_winograd_cuda(cfg, s, op.output(0), pre_computed=True)
+
+    traverse_inline(s, outs[0].op, _callback)
+    return s
index b46f284..0a6a71c 100644 (file)
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=invalid-name
+# pylint: disable=invalid-name, unused-argument
 """Compute definition for conv3d with cuda backend"""
 import tvm
 from tvm import autotvm
 from tvm.contrib import cudnn
 
 from .. import nn, generic
-from ..nn.util import get_pad_tuple3d
 from ..util import get_const_tuple, traverse_inline
+from .conv3d_direct import schedule_direct_conv3d_cuda
 
-from .conv3d_direct import schedule_direct_3d_cuda
 
-
-@autotvm.register_topi_compute(nn.conv3d, ['cuda', 'gpu'], ['direct'])
-def conv3d_cuda(cfg, data, kernel, strides, padding, dilation, layout='NCDHW', out_dtype='float32'):
-    """Conv3D operator for cuda backend.
+@autotvm.register_topi_compute("conv3d_ncdhw.cuda")
+def conv3d_ncdhw(cfg, data, kernel, strides, padding, dilation, out_dtype='float32'):
+    """Conv3D operator in NCDHW layout for cuda backend.
 
     Parameters
     ----------
@@ -45,17 +43,12 @@ def conv3d_cuda(cfg, data, kernel, strides, padding, dilation, layout='NCDHW', o
     strides : int or a list/tuple of three ints
         stride size, or [stride_depth, stride_height, stride_width]
 
-    padding : int or a list/tuple of 3 or 6 ints
-        padding size, or
-        [pad_depth, pad_height, pad_width] for 3 ints, or
-        [pad_front, pad_top, pad_left, pad_back, pad_bottom, pad_right] for 6 ints
+    padding : int or a list/tuple of three ints
+        padding size, or [pad_depth, pad_height, pad_width]
 
     dilation: int or a list/tuple of three ints
         dilation size, or [dilation_depth, dilation_height, dilation_width]
 
-    layout : str
-        layout of data
-
     out_dtype: str
         The output type. This is used for mixed precision.
 
@@ -64,52 +57,11 @@ def conv3d_cuda(cfg, data, kernel, strides, padding, dilation, layout='NCDHW', o
     output : tvm.Tensor
         5-D with shape [batch, out_channel, out_depth, out_height, out_width]
     """
-    target = tvm.target.Target.current()
-
-    if "cudnn" in target.libs:
-        if layout == 'NCDHW':
-            tensor_format = 0 # CUDNN_TENSOR_NCHW
-            N, _, D, H, W = get_const_tuple(data.shape)
-        elif layout == 'NDHWC':
-            tensor_format = 1 # CUDNN_TENSOR_NHWC
-            N, D, H, W, _ = get_const_tuple(data.shape)
-        else:
-            raise ValueError("Unsupported layout %s in cudnn" % layout)
-        CO, CI, KD, KH, KW = get_const_tuple(kernel.shape)
-
-        # handle dilation
-        stride_d, stride_h, stride_w = (strides, strides, strides) if isinstance(strides, int) \
-            else strides
-        if isinstance(padding, (list, tuple)) and len(padding) > 3:
-            raise ValueError("Cudnn doesn't support asymmetric padding.")
-        pf, pt, pl, pk, pb, pr = get_pad_tuple3d(padding, (KD, KH, KW))
-        dilation_d, dilation_h, dilation_w = (dilation, dilation, dilation) if \
-            isinstance(dilation, int) else dilation
-
-        OD = (D + pf + pk - KD) // stride_d + 1
-        OH = (H + pt + pb - KH) // stride_h + 1
-        OW = (W + pl + pr - KW) // stride_w + 1
-        cfg.add_flop(2 * N * OD * OH * OW * CO * CI * ((KD - 1) * dilation_d + 1) *\
-                    ((KH - 1) * dilation_h + 1) * ((KW - 1) * dilation_w + 1))
-
-        return cudnn.conv_forward(data,
-                                  kernel,
-                                  [pf, pt, pl],  # cudnn padding pt, pl on both sides of input
-                                  [stride_d, stride_h, stride_w],
-                                  [dilation_d, dilation_h, dilation_w],
-                                  conv_mode=1,
-                                  tensor_format=tensor_format,
-                                  algo=-1,         # let CUDNN choose the best algo
-                                  conv_dtype=data.dtype)
-
-    if layout == 'NCDHW':
-        return nn.conv3d_ncdhw(data, kernel, strides, padding, dilation, out_dtype)
-    raise ValueError("not support this layout {} yet".format(layout))
+    return nn.conv3d_ncdhw(data, kernel, strides, padding, dilation, out_dtype)
 
 
-@autotvm.register_topi_schedule(generic.schedule_conv3d_ncdhw, ["cuda", "gpu"],
-                                ["direct"])
-def schedule_conv3d_ncdhw_cuda(cfg, outs):
+@autotvm.register_topi_schedule("conv3d_ncdhw.cuda")
+def schedule_conv3d_ncdhw(cfg, outs):
     """TOPI schedule callback of conv3d for cuda gpu
 
     Parameters
@@ -126,24 +78,49 @@ def schedule_conv3d_ncdhw_cuda(cfg, outs):
     s: Schedule
         The computation schedule for conv2d.
     """
-    target = tvm.target.Target.current()
-    if 'cudnn' in target.libs:
-        return generic.schedule_extern(outs)
-
     outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
     s = tvm.create_schedule([x.op for x in outs])
 
     def _callback(op):
         if op.tag == 'conv3d_ncdhw':
-            schedule_direct_3d_cuda(cfg, s, op.output(0))
+            schedule_direct_conv3d_cuda(cfg, s, op.output(0), "NCDHW",
+                                        "conv3d_ncdhw.cuda")
 
     traverse_inline(s, outs[0].op, _callback)
     return s
 
 
-@autotvm.register_topi_schedule(generic.schedule_conv3d_ndhwc, ["cuda", "gpu"],
-                                ["direct"])
-def schedule_conv3d_ndhwc_cuda(cfg, outs):
+@autotvm.register_topi_compute("conv3d_ndhwc.cuda")
+def conv3d_ndhwc(cfg, data, kernel, strides, padding, dilation, out_dtype='float32'):
+    """Conv3d operator in NDHWC layout for cuda backend.
+
+    Parameters
+    ----------
+    Input : tvm.Tensor
+        5-D with shape [batch, in_depth, in_height, in_width, in_channel]
+
+    Filter : tvm.Tensor
+        5-D with shape [filter_depth, filter_height, filter_width, in_channel, num_filter]
+
+    stride : int or a list/tuple of three ints
+        Stride size, or [stride_depth, stride_height, stride_width]
+
+    padding : int or str
+        Padding size, or ['VALID', 'SAME']
+
+    dilation: int or a list/tuple of three ints
+        dilation size, or [dilation_depth, dilation_height, dilation_width]
+
+    Returns
+    -------
+    Output : tvm.Tensor
+        5-D with shape [batch, out_depth, out_height, out_width, out_channel]
+    """
+    return nn.conv3d_ndhwc(data, kernel, strides, padding, dilation, out_dtype)
+
+
+@autotvm.register_topi_schedule("conv3d_ndhwc.cuda")
+def schedule_conv3d_ndhwc(cfg, outs):
     """TOPI schedule callback of conv3d for cuda gpu
 
     Parameters
@@ -160,16 +137,104 @@ def schedule_conv3d_ndhwc_cuda(cfg, outs):
     s: Schedule
         The computation schedule for conv2d.
     """
-    target = tvm.target.Target.current()
-    if 'cudnn' in target.libs:
-        return generic.schedule_extern(outs)
-
     outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
     s = tvm.create_schedule([x.op for x in outs])
 
     def _callback(op):
         if op.tag == 'conv3d_ndhwc':
-            schedule_direct_3d_cuda(cfg, s, op.output(0))
+            schedule_direct_conv3d_cuda(cfg, s, op.output(0), "NDHWC",
+                                        "conv3d_ndhwc.cuda")
 
     traverse_inline(s, outs[0].op, _callback)
     return s
+
+
+@autotvm.register_topi_compute("conv3d_cudnn.cuda")
+def conv3d_cudnn(cfg, data, kernel, strides, padding, dilation, layout='NCDHW',
+                 out_dtype='float32'):
+    """Conv3D operator for cuda backend.
+
+    Parameters
+    ----------
+    cfg: ConfigEntity
+        The config for this template
+
+    data : tvm.Tensor
+        5-D with shape [batch, in_channel, in_depth, in_height, in_width]
+
+    kernel : tvm.Tensor
+        5-D with shape [num_filter, in_channel, filter_depth, filter_height, filter_width]
+
+    strides : int or a list/tuple of three ints
+        stride size, or [stride_depth, stride_height, stride_width]
+
+    padding : int or a list/tuple of three ints
+        padding size, or [pad_depth, pad_height, pad_width]
+
+    dilation: int or a list/tuple of three ints
+        dilation size, or [dilation_depth, dilation_height, dilation_width]
+
+    layout : str
+        layout of data
+
+    out_dtype: str
+        The output type. This is used for mixed precision.
+
+    Returns
+    -------
+    output : tvm.Tensor
+        5-D with shape [batch, out_channel, out_depth, out_height, out_width]
+    """
+    if layout == 'NCDHW':
+        tensor_format = 0 # CUDNN_TENSOR_NCHW
+        N, _, D, H, W = get_const_tuple(data.shape)
+    elif layout == 'NDHWC':
+        tensor_format = 1 # CUDNN_TENSOR_NHWC
+        N, D, H, W, _ = get_const_tuple(data.shape)
+    else:
+        raise ValueError("Unsupported layout %s in cudnn" % layout)
+    CO, CI, KD, KH, KW = get_const_tuple(kernel.shape)
+
+    # handle dilation
+    stride_d, stride_h, stride_w = (strides, strides, strides) if isinstance(strides, int) \
+        else strides
+    pad_d, pad_h, pad_w = (padding, padding, padding) if isinstance(padding, int) else padding
+    dilation_d, dilation_h, dilation_w = (dilation, dilation, dilation) if \
+        isinstance(dilation, int) else dilation
+
+    OD = (D + 2 * pad_d - KD) // stride_d + 1
+    OH = (H + 2 * pad_h - KH) // stride_h + 1
+    OW = (W + 2 * pad_w - KW) // stride_w + 1
+    cfg.add_flop(2 * N * OD * OH * OW * CO * CI * ((KD - 1) * dilation_d + 1) * \
+                 ((KH - 1) * dilation_h + 1) * ((KW - 1) * dilation_w + 1))
+
+    return cudnn.conv_forward(data,
+                              kernel,
+                              [pad_d, pad_h, pad_w],
+                              [stride_d, stride_h, stride_w],
+                              [dilation_d, dilation_h, dilation_w],
+                              conv_mode=1,
+                              tensor_format=tensor_format,
+                              algo=-1,         # let CUDNN choose the best algo
+                              conv_dtype=dtype)
+
+
+@autotvm.register_topi_schedule("conv3d_cudnn.cuda")
+def schedule_conv3d_cudnn(_, outs):
+    """TOPI schedule callback of conv3d for cuda gpu
+
+    Parameters
+    ----------
+    cfg: ConfigEntity
+        The config for this template
+
+    outs: Array of Tensor
+        The computation graph description of conv2d
+        in the format of an array of tensors.
+
+    Returns
+    -------
+    s: Schedule
+        The computation schedule for conv2d.
+    """
+    return generic.schedule_extern(outs)
index ad48deb..fa6c878 100644 (file)
@@ -20,11 +20,16 @@ import tvm
 from tvm import autotvm
 from ..util import get_const_tuple
 
-def schedule_direct_3d_cuda(cfg, s, conv):
+def schedule_direct_conv3d_cuda(cfg, s, conv, layout, workload_name):
     """schedule optimized for batch size = 1"""
 
     ##### space definition begin #####
-    n, f, d, y, x = s[conv].op.axis
+    if layout == "NCDHW":
+        n, f, d, y, x = s[conv].op.axis
+    elif layout == "NDHWC":
+        n, d, y, x, f = s[conv].op.axis
+    else:
+        raise ValueError("not support this layout {} yet".format(layout))
     rc, rd, ry, rx = s[conv].op.reduce_axis
     cfg.define_split("tile_f", f, num_outputs=4)
     cfg.define_split("tile_d", d, num_outputs=4)
@@ -45,7 +50,7 @@ def schedule_direct_3d_cuda(cfg, s, conv):
     # fallback support
     if cfg.is_fallback:
         ref_log = autotvm.tophub.load_reference_log(
-            target.target_name, target.model, 'conv3d', 'direct')
+            target.target_name, target.model, workload_name)
         cfg.fallback_with_reference_log(ref_log)
     ##### space definition end #####
 
index 33a8c9a..bdec4e1 100644 (file)
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=invalid-name
+# pylint: disable=invalid-name,unused-argument
 """Schedule template of deformable conv2d with cuda backend"""
 import tvm
 from tvm import autotvm
-from .. import nn, generic
+from .. import nn
 from ..util import traverse_inline
 
 
-autotvm.register_topi_compute(nn.deformable_conv2d_nchw, ["cuda", "gpu"], "direct",
-                              nn.deformable_conv2d_nchw.fdefault)
+@autotvm.register_topi_compute("deformable_conv2d_nchw.cuda")
+def deformable_conv2d_nchw(cfg, data, offset, kernel, strides, padding, dilation,
+                           deformable_groups, groups, out_dtype):
+    return nn.deformable_conv2d_nchw(data, offset, kernel, strides, padding, dilation,
+                                     deformable_groups, groups, out_dtype)
 
-
-@autotvm.register_topi_schedule(generic.schedule_deformable_conv2d_nchw, ["cuda", "gpu"], "direct")
-def schedule_deformable_conv2d_nchw_cuda(cfg, outs):
+@autotvm.register_topi_schedule("deformable_conv2d_nchw.cuda")
+def schedule_deformable_conv2d_nchw(cfg, outs):
     """TOPI schedule callback of deformable conv2d for cuda gpu
 
     Parameters
@@ -49,13 +51,13 @@ def schedule_deformable_conv2d_nchw_cuda(cfg, outs):
 
     def _callback(op):
         if op.tag == 'deformable_conv2d_nchw':
-            schedule_direct_cuda(cfg, s, op.output(0))
+            _schedule_direct_cuda(cfg, s, op.output(0))
 
     traverse_inline(s, outs[0].op, _callback)
     return s
 
 
-def schedule_direct_cuda(cfg, s, conv):
+def _schedule_direct_cuda(cfg, s, conv):
     """Schedule template of deformable conv2d"""
     n, f, y, x = s[conv].op.axis
     rc, ry, rx = s[conv].op.reduce_axis
index 1a1af70..93797a4 100644 (file)
@@ -14,7 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=invalid-name, unused-variable
+# pylint: disable=invalid-name, unused-argument
 """Schedule for dense operator"""
 from __future__ import absolute_import as _abs
 import logging
@@ -23,111 +23,60 @@ import tvm.autotvm as autotvm
 from tvm.autotvm.task.space import SplitEntity
 from tvm.contrib import cublas
 from .tensor_intrin import dp4a
-from ..nn.dense import dense, dense_default
+from .. import nn
 from .. import tag
 from .. import generic
 from ..util import traverse_inline, get_const_tuple
 
 logger = logging.getLogger('topi')
 
-
-@autotvm.register_topi_compute(dense, ["cuda", "gpu"], "direct")
-def dense_cuda(cfg, data, weight, bias=None, out_dtype=None):
-    """Dense operator for cuda backend.
-
-    Parameters
-    ----------
-    data : tvm.Tensor
-        2-D with shape [batch, in_dim]
-
-    weight : tvm.Tensor
-        2-D with shape [out_dim, in_dim]
-
-    bias : tvm.Tensor, optional
-        1-D with shape [out_dim]
-
-    Returns
-    -------
-    output : tvm.Tensor
-        2-D with shape [batch, out_dim]
-    """
-    # pylint: disable=unused-argument
+@autotvm.register_topi_compute("dense_cublas.cuda")
+def dense_cublas(cfg, data, weight, bias=None, out_dtype=None):
+    """Dense operator on CUDA with CUBLAS"""
     assert len(data.shape) == 2 and len(weight.shape) == 2, \
         "only support 2-dim dense"
     if bias is not None:
         assert len(bias.shape) == 1
     if out_dtype is None:
         out_dtype = data.dtype
+    assert out_dtype == data.dtype, "Mixed precision not supported."
     batch, in_dim = data.shape
     out_dim, _ = weight.shape
-    target = tvm.target.Target.current()
-    if "cublas" in target.libs:
-        matmul = cublas.matmul(data, weight, False, True, out_dtype)
-        if bias is not None:
-            matmul = tvm.compute((batch, out_dim), \
-                                 lambda i, j: matmul[i, j] + bias[j], \
-                                 tag=tag.BROADCAST)
-        return matmul
-    return dense_default(data, weight, bias, out_dtype)
-
-
-@autotvm.register_topi_schedule(generic.schedule_dense, ["cuda", "gpu"], "direct")
-def schedule_dense(cfg, outs):
-    """Schedule for dense operator.
-
-    Parameters
-    ----------
-    outs: Array of Tensor
-        The computation graph description of dense
-        in the format of an array of tensors.
-
-    Returns
-    -------
-    s: Schedule
-        The computation schedule for dense.
-    """
-    # pylint: disable=unused-argument
-    target = tvm.target.Target.current()
-
-    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
-    if target.target_name == "cuda" and "cublas" in target.libs:
-        return generic.schedule_extern(outs)
+    matmul = cublas.matmul(data, weight, False, True)
+    cfg.add_flop(batch * in_dim * out_dim * 2)
+    if bias is not None:
+        matmul = tvm.compute((batch, out_dim),
+                             lambda i, j: matmul[i, j] + bias[j],
+                             tag=tag.BROADCAST)
+    return matmul
 
-    s = tvm.create_schedule([x.op for x in outs])
 
-    def _schedule(C):
-        A, _ = C.op.input_tensors
-        batch, _ = get_const_tuple(A.shape)
-        if batch < 32:
-            return schedule_dense_small_batch(cfg, s, C)
-        return schedule_dense_large_batch(cfg, s, C)
-
-    scheduled_ops = []
-
-    def traverse(OP):
-        """Internal traverse function"""
-        # inline all one-to-one-mapping operators except the last stage (output)
-        if tag.is_broadcast(OP.tag):
-            if OP not in s.outputs:
-                s[OP].compute_inline()
-            for tensor in OP.input_tensors:
-                if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
-                    traverse(tensor.op)
-        # schedule dense
-        elif OP.tag == 'dense':
-            Dense = OP.output(0)
-            _schedule(Dense)
-        else:
-            raise RuntimeError("Unsupported operator: %s" % OP.tag)
+@autotvm.register_topi_schedule("dense_cublas.cuda")
+def schedule_dense_cublas(_, outs):
+    """Schedule dense operator using CUBLAS"""
+    return generic.schedule_extern(outs)
 
-        scheduled_ops.append(OP)
 
-    traverse(outs[0].op)
-    return s
+@autotvm.register_topi_compute("dense_small_batch.cuda")
+def dense_small_batch(cfg, data, weight, bias=None, out_dtype=None):
+    """Dense operator on CUDA"""
+    return nn.dense(data, weight, bias, out_dtype)
 
 
-def schedule_dense_small_batch(cfg, s, C):
+@autotvm.register_topi_schedule("dense_small_batch.cuda")
+def schedule_dense_small_batch(cfg, outs):
     """Schedule float32/64 dense with small batch size"""
+    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
+    s = tvm.create_schedule([x.op for x in outs])
+
+    def _callback(op):
+        if op.tag == 'dense':
+            _schedule_dense_small_batch(cfg, s, op.output(0))
+
+    traverse_inline(s, outs[0].op, _callback)
+    return s
+
+def _schedule_dense_small_batch(cfg, s, C):
     A, _ = C.op.input_tensors
     _, in_dim = get_const_tuple(A.shape)
     cfg.define_split('tile_k', in_dim, num_outputs=2)
@@ -152,7 +101,28 @@ def schedule_dense_small_batch(cfg, s, C):
     s[C].set_store_predicate(thread_x.var.equal(0))
     s[Out].set_store_predicate(thread_x.var.equal(0))
 
-def schedule_dense_large_batch(cfg, s, C):
+
+@autotvm.register_topi_compute("dense_large_batch.cuda")
+def dense_large_batch(cfg, data, weight, bias=None, out_dtype=None):
+    """Dense operator on CUDA"""
+    return nn.dense(data, weight, bias, out_dtype)
+
+
+@autotvm.register_topi_schedule("dense_large_batch.cuda")
+def schedule_dense_large_batch(cfg, outs):
+    """Schedule float32/64 dense with large batch size"""
+    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
+    s = tvm.create_schedule([x.op for x in outs])
+
+    def _callback(op):
+        if op.tag == 'dense':
+            _schedule_dense_large_batch(cfg, s, op.output(0))
+
+    traverse_inline(s, outs[0].op, _callback)
+    return s
+
+
+def _schedule_dense_large_batch(cfg, s, C):
     """Schedule float32/64 dense with large batch size"""
     A, B = C.op.input_tensors
     batch, in_dim = get_const_tuple(A.shape)
@@ -250,7 +220,8 @@ def schedule_dense_large_batch(cfg, s, C):
     s[BB].bind(tx, tvm.thread_axis("threadIdx.x"))
     s[BB].double_buffer()
 
-@autotvm.register_topi_compute(dense, ['cuda'], ['int8'])
+
+@autotvm.register_topi_compute("dense_int8.cuda")
 def dense_int8(cfg, data, weight, bias=None, out_dtype=None):
     """Dense operator for int8 on CUDA"""
     if out_dtype is None:
@@ -258,16 +229,6 @@ def dense_int8(cfg, data, weight, bias=None, out_dtype=None):
 
     batch, in_dim = get_const_tuple(data.shape)
     out_dim, _ = get_const_tuple(weight.shape)
-
-    target = tvm.target.Target.current()
-    if "cublas" in target.libs:
-        matmul = cublas.matmul(data, weight, False, True, out_dtype)
-        if bias is not None:
-            matmul = tvm.compute((batch, out_dim), \
-                                 lambda i, j: matmul[i, j] + bias[j].astype(out_dtype), \
-                                 tag=tag.BROADCAST)
-        return matmul
-
     k = tvm.reduce_axis((0, in_dim), name='k')
 
     matmul = tvm.compute((batch, out_dim),
@@ -286,15 +247,11 @@ def dense_int8(cfg, data, weight, bias=None, out_dtype=None):
     return matmul
 
 
-@autotvm.register_topi_schedule(generic.schedule_dense, ['cuda', 'gpu'], ['int8'])
+@autotvm.register_topi_schedule("dense_int8.cuda")
 def schedule_dense_int8(cfg, outs):
     """Dense schedule for int8 on CUDA"""
-    s = tvm.create_schedule([x.op for x in outs])
-    target = tvm.target.Target.current()
-
     outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
-    if "cublas" in target.libs:
-        return generic.schedule_extern(outs)
+    s = tvm.create_schedule([x.op for x in outs])
 
     def _callback(op):
         if "dense_int8" in op.tag:
index 05e1117..062f95f 100644 (file)
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=invalid-name
+# pylint: disable=invalid-name, unused-argument
 """Schedule for depthwise_conv2d with auto fusion"""
 import tvm
 from tvm import autotvm
 from ..util import traverse_inline
 from .. import tag
-from .. import generic, nn
+from .. import nn
 
 # register original implementation of depthwise_conv2d_nchw since we don't need to change this part
-autotvm.register_topi_compute(nn.depthwise_conv2d_nchw, ['cuda', 'gpu'], 'direct',
-                              nn.depthwise_conv2d_nchw.fdefault)
+@autotvm.register_topi_compute("depthwise_conv2d_nchw.cuda")
+def depthwise_conv2d_nchw(cfg, data, kernel, strides, padding, dilation, out_dtype):
+    """Compute depthwise_conv2d with NCHW layout."""
+    return nn.depthwise_conv2d_nchw(data, kernel, strides, padding, dilation, out_dtype)
 
-@autotvm.register_topi_schedule(generic.schedule_depthwise_conv2d_nchw, ['cuda', 'gpu'], 'direct')
-def schedule_depthwise_conv2d_nchw_cuda(cfg, outs):
+@autotvm.register_topi_schedule("depthwise_conv2d_nchw.cuda")
+def schedule_depthwise_conv2d_nchw(cfg, outs):
     """Schedule for depthwise_conv2d nchw forward.
 
     Parameters
@@ -66,7 +68,7 @@ def schedule_depthwise_conv2d_nchw_cuda(cfg, outs):
             # fallback support
             if cfg.is_fallback:
                 ref_log = autotvm.tophub.load_reference_log(
-                    target.target_name, target.model, 'depthwise_conv2d_nchw', 'direct')
+                    target.target_name, target.model, 'depthwise_conv2d_nchw.cuda')
                 cfg.fallback_with_reference_log(ref_log)
                 # TODO(lmzheng): A bug here, set unroll_explicit to False as workaround
                 cfg['unroll_explicit'].val = 0
@@ -131,7 +133,6 @@ def schedule_depthwise_conv2d_nchw_cuda(cfg, outs):
     traverse_inline(s, outs[0].op, _callback)
     return s
 
-@generic.schedule_depthwise_conv2d_nhwc.register(["cuda", "gpu"])
 def schedule_depthwise_conv2d_nhwc(outs):
     """Schedule for depthwise_conv2d nhwc forward.
 
index 54e8427..5abf298 100644 (file)
@@ -24,15 +24,163 @@ from .tensor_intrin import dp4a
 from ..nn.pad import pad
 from ..nn.util import get_pad_tuple
 from ..util import traverse_inline, get_const_tuple, get_const_int
-from .. import nn, generic
+from .. import nn
 
 
-autotvm.register_topi_compute(nn.group_conv2d_nchw, ['cuda', 'gpu'], 'direct',
-                              nn.group_conv2d_nchw.fdefault)
+@autotvm.register_topi_compute("group_conv2d_nchw.cuda")
+def group_conv2d_nchw(_, data, kernel, stride, padding, dilation, groups,
+                      out_dtype='float32'):
+    return nn.group_conv2d_nchw(data, kernel, stride, padding, dilation, groups, out_dtype)
 
-@autotvm.register_topi_compute(nn.group_conv2d_nchw, ['cuda', 'gpu'], ['int8'])
-def group_conv2d_nchw_cuda(cfg, data, kernel, stride, padding, dilation, groups,
-                           out_dtype='float32'):
+
+@autotvm.register_topi_schedule("group_conv2d_nchw.cuda")
+def schedule_group_conv2d_nchw(cfg, outs):
+    """TOPI schedule callback of group conv2d for cuda gpu
+
+    Parameters
+    ----------
+    cfg: ConfigEntity
+        The config for this template
+
+    outs: Array of Tensor
+        The computation graph description of conv2d
+        in the format of an array of tensors.
+
+    Returns
+    -------
+    s: Schedule
+        The computation schedule for group conv2d.
+    """
+    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
+    s = tvm.create_schedule([x.op for x in outs])
+
+    def _callback(op):
+        if op.tag == "group_conv2d_nchw":
+            _schedule_group_conv2d_nchw_direct(cfg, s, op.output(0))
+
+    traverse_inline(s, outs[0].op, _callback)
+    return s
+
+
+def _schedule_group_conv2d_nchw_direct(cfg, s, conv):
+    """Schedule group conv2d NCHW direct template"""
+    workload = conv.op.attrs["workload"]
+    groups = get_const_int(workload[6])
+    num_filters = get_const_int(conv.shape[1])
+
+    ##### space definition begin #####
+    n, f, y, x = s[conv].op.axis
+    rc, ry, rx = s[conv].op.reduce_axis
+    cfg.define_split("tile_n", n, num_outputs=4)
+    cfg.define_split("tile_g", cfg.axis(groups), num_outputs=2)
+    cfg.define_split("tile_f", cfg.axis(num_filters // groups), num_outputs=4)
+    cfg.define_split("tile_y", y, num_outputs=4)
+    cfg.define_split("tile_x", x, num_outputs=4)
+    cfg.define_split("tile_rc", rc, num_outputs=2)
+    cfg.define_split("tile_ry", ry, num_outputs=2)
+    cfg.define_split("tile_rx", rx, num_outputs=2)
+    cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
+
+    target = tvm.target.Target.current()
+    if target.target_name in ['nvptx', 'rocm']:
+        cfg.define_knob("unroll_explicit", [1])
+    else:
+        cfg.define_knob("unroll_explicit", [0, 1])
+
+    pad_data, kernel = s[conv].op.input_tensors
+
+    s[pad_data].compute_inline()
+
+    if conv.op in s.outputs:
+        output = conv
+        OL = s.cache_write(conv, 'local')
+    else:
+        output = s.outputs[0].output(0)
+        s[conv].set_scope('local')
+        OL = conv
+
+    # create cache stage
+    AA = s.cache_read(pad_data, 'shared', [OL])
+    WW = s.cache_read(kernel, 'shared', [OL])
+
+    # tile and bind spatial axes
+    n, f, y, x = s[output].op.axis
+    kernel_scope, n = s[output].split(n, nparts=1)
+
+    g, f = s[output].split(f, nparts=groups)
+    bn, vn, tn, ni = cfg["tile_n"].apply(s, output, n)
+    bg, vg = cfg["tile_g"].apply(s, output, g)
+    bf, vf, tf, fi = cfg["tile_f"].apply(s, output, f)
+    by, vy, ty, yi = cfg["tile_y"].apply(s, output, y)
+    bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x)
+
+    s[output].reorder(bn, bg, bf, by, bx, vn, vg, vf, vy, vx, tn, tf, ty, tx, ni, fi, yi, xi)
+    s[output].bind(bn, tvm.thread_axis("blockIdx.z"))
+    s[output].bind(s[output].fuse(bg, bf), tvm.thread_axis("blockIdx.y"))
+    s[output].bind(s[output].fuse(by, bx), tvm.thread_axis("blockIdx.x"))
+    s[output].bind(vn, tvm.thread_axis("vthread"))
+    s[output].bind(vg, tvm.thread_axis("vthread"))
+    s[output].bind(vf, tvm.thread_axis("vthread"))
+    s[output].bind(vy, tvm.thread_axis("vthread"))
+    s[output].bind(vx, tvm.thread_axis("vthread"))
+
+    cfg.define_knob("fuse_yx", [0, 1])  # fuse ty,tx or tn,tf
+    if cfg["fuse_yx"].val:
+        s[output].bind(tn, tvm.thread_axis("threadIdx.z"))
+        s[output].bind(tf, tvm.thread_axis("threadIdx.y"))
+        tyx = s[output].fuse(ty, tx)
+        s[output].bind(tyx, tvm.thread_axis("threadIdx.x"))
+        s[OL].compute_at(s[output], tyx)
+
+        # number of threads
+        n_tz = cfg["tile_n"].size[2]
+        n_ty = cfg["tile_f"].size[2]
+        n_tx = cfg["tile_y"].size[2] * cfg["tile_x"].size[2]
+    else:
+        s[output].bind(s[output].fuse(tn, tf), tvm.thread_axis("threadIdx.z"))
+        s[output].bind(ty, tvm.thread_axis("threadIdx.y"))
+        s[output].bind(tx, tvm.thread_axis("threadIdx.x"))
+        s[OL].compute_at(s[output], tx)
+
+        # number of threads
+        n_tz = cfg["tile_n"].size[2] * cfg["tile_f"].size[2]
+        n_ty = cfg["tile_y"].size[2]
+        n_tx = cfg["tile_x"].size[2]
+
+    # tile reduction axes
+    n, f, y, x = s[OL].op.axis
+    rc, ry, rx = s[OL].op.reduce_axis
+    rco, rci = cfg['tile_rc'].apply(s, OL, rc)
+    ryo, ryi = cfg['tile_rx'].apply(s, OL, ry)
+    rxo, rxi = cfg['tile_ry'].apply(s, OL, rx)
+    s[OL].reorder(rco, ryo, rxo, rci, ryi, rxi, n, f, y, x)
+
+    s[AA].compute_at(s[OL], rxo)
+    s[WW].compute_at(s[OL], rxo)
+
+    # cooperative fetching
+    for load in [AA, WW]:
+        n, f, y, x = s[load].op.axis
+        fused = s[load].fuse(n, f, y, x)
+        fused, tx = s[load].split(fused, factor=n_tx)
+        fused, ty = s[load].split(fused, factor=n_ty)
+        fused, tz = s[load].split(fused, factor=n_tz)
+        s[load].bind(tz, tvm.thread_axis("threadIdx.z"))
+        s[load].bind(ty, tvm.thread_axis("threadIdx.y"))
+        s[load].bind(tx, tvm.thread_axis("threadIdx.x"))
+
+    # unroll
+    s[output].pragma(kernel_scope, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val)
+    s[output].pragma(kernel_scope, 'unroll_explicit', cfg['unroll_explicit'].val)
+
+    N, CO, OH, OW = get_const_tuple(output.shape)
+    _, CI_div_groups, KH, KW = get_const_tuple(kernel.shape)
+    cfg.add_flop(2 * N * OH * OW * CO * CI_div_groups * KH * KW)
+
+
+@autotvm.register_topi_compute("group_conv2d_NCHWc_int8.cuda")
+def group_conv2d_NCHWc_int8(cfg, data, kernel, stride, padding, dilation, groups,
+                            out_dtype='float32'):
     """Group convolution operator for 'group_conv2d_NCHWc_int8'.
 
     Parameters
@@ -154,30 +302,58 @@ def group_conv2d_nchw_cuda(cfg, data, kernel, stride, padding, dilation, groups,
     #
     # Compared with a normal convolution, group convolution only sums
     # input channels from the group that an output channel resides in.
-    conv = tvm.compute(oshape, lambda n, occ, oh, ow, ocb:
-                       tvm.sum(pad_data[n, occ//(oc_chunk//groups)*(ic_chunk//groups)+icc,
-                                        oh*stride_h+kh*dilation_h, ow*stride_w+kw*dilation_w, icb]
-                               .astype('int32') *
-                               packed_kernel[occ, icc,
-                                             kh, kw, ocb, icb]
-                               .astype('int32'),
-                               axis=[icc, kh, kw, icb]))
+    conv = tvm.compute(
+        oshape, lambda n, occ, oh, ow, ocb:
+        tvm.sum(pad_data[n, occ//(oc_chunk//groups)*(ic_chunk//groups)+icc,
+                         oh*stride_h+kh*dilation_h, ow*stride_w+kw*dilation_w, icb]
+                .astype('int32') *
+                packed_kernel[occ, icc, kh, kw, ocb, icb].astype('int32'),
+                axis=[icc, kh, kw, icb]))
 
     # Type conversion
     output = tvm.compute(oshape, lambda *index: conv(*index).astype(out_dtype),
                          tag='group_conv2d_NCHWc_int8')
 
     num_flop = batch * oc_chunk * oc_block * out_height * out_width * \
-        ic_chunk * ic_block * kernel_h * kernel_w * 2 // groups
+               ic_chunk * ic_block * kernel_h * kernel_w * 2 // groups
     cfg.add_flop(num_flop)
 
     return output
 
 
+@autotvm.register_topi_schedule("group_conv2d_NCHWc_int8.cuda")
+def schedule_group_conv2d_NCHWc_int8(cfg, outs):
+    """TOPI schedule callback of group conv2d for cuda gpu
+
+    Parameters
+    ----------
+    cfg: ConfigEntity
+        The config for this template
+
+    outs: Array of Tensor
+        The computation graph description of conv2d
+        in the format of an array of tensors.
+
+    Returns
+    -------
+    s: Schedule
+        The computation schedule for group conv2d.
+    """
+    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
+    s = tvm.create_schedule([x.op for x in outs])
+
+    def _callback(op):
+        if op.tag == "group_conv2d_NCHWc_int8":
+            _schedule_group_conv2d_NCHWc_int8(cfg, s, op.output(0))
+
+    traverse_inline(s, outs[0].op, _callback)
+    return s
+
+
 _dp4a = dp4a('shared', 'shared', 'local')
 
 
-def schedule_group_conv2d_NCHWc_int8(cfg, s, output):
+def _schedule_group_conv2d_NCHWc_int8(cfg, s, output):
     """Schedule group conv2d int8 NCHWc template"""
     workload = output.op.attrs["workload"]
     groups = get_const_int(workload[6])
@@ -198,7 +374,7 @@ def schedule_group_conv2d_NCHWc_int8(cfg, s, output):
         s[packed_kernel].pragma(
             s[packed_kernel].op.axis[0], "debug_skip_region")
     else:
-        if isinstance(packed_kernel.op, tvm.tensor.ComputeOp) and\
+        if isinstance(packed_kernel.op, tvm.tensor.ComputeOp) and \
                 packed_kernel.name == 'packed_kernel':
             # data and kernel are not pre-computed, schedule layout transform here
             schedule_injective_from_existing(s, packed_data)
@@ -319,151 +495,3 @@ def schedule_group_conv2d_NCHWc_int8(cfg, s, output):
     s[output].pragma(kernel_scope, 'unroll_explicit', False)
 
     return s
-
-
-def schedule_group_conv2d_nchw_direct(cfg, s, conv):
-    """Schedule group conv2d NCHW direct template"""
-    workload = conv.op.attrs["workload"]
-    groups = get_const_int(workload[6])
-    num_filters = get_const_int(conv.shape[1])
-
-    ##### space definition begin #####
-    n, f, y, x = s[conv].op.axis
-    rc, ry, rx = s[conv].op.reduce_axis
-    cfg.define_split("tile_n", n, num_outputs=4)
-    cfg.define_split("tile_g", cfg.axis(groups), num_outputs=2)
-    cfg.define_split("tile_f", cfg.axis(num_filters // groups), num_outputs=4)
-    cfg.define_split("tile_y", y, num_outputs=4)
-    cfg.define_split("tile_x", x, num_outputs=4)
-    cfg.define_split("tile_rc", rc, num_outputs=2)
-    cfg.define_split("tile_ry", ry, num_outputs=2)
-    cfg.define_split("tile_rx", rx, num_outputs=2)
-    cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
-
-    target = tvm.target.Target.current()
-    if target.target_name in ['nvptx', 'rocm']:
-        cfg.define_knob("unroll_explicit", [1])
-    else:
-        cfg.define_knob("unroll_explicit", [0, 1])
-
-    pad_data, kernel = s[conv].op.input_tensors
-
-    s[pad_data].compute_inline()
-
-    if conv.op in s.outputs:
-        output = conv
-        OL = s.cache_write(conv, 'local')
-    else:
-        output = s.outputs[0].output(0)
-        s[conv].set_scope('local')
-        OL = conv
-
-    # create cache stage
-    AA = s.cache_read(pad_data, 'shared', [OL])
-    WW = s.cache_read(kernel, 'shared', [OL])
-
-    # tile and bind spatial axes
-    n, f, y, x = s[output].op.axis
-    kernel_scope, n = s[output].split(n, nparts=1)
-
-    g, f = s[output].split(f, nparts=groups)
-    bn, vn, tn, ni = cfg["tile_n"].apply(s, output, n)
-    bg, vg = cfg["tile_g"].apply(s, output, g)
-    bf, vf, tf, fi = cfg["tile_f"].apply(s, output, f)
-    by, vy, ty, yi = cfg["tile_y"].apply(s, output, y)
-    bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x)
-
-    s[output].reorder(bn, bg, bf, by, bx, vn, vg, vf, vy, vx, tn, tf, ty, tx, ni, fi, yi, xi)
-    s[output].bind(bn, tvm.thread_axis("blockIdx.z"))
-    s[output].bind(s[output].fuse(bg, bf), tvm.thread_axis("blockIdx.y"))
-    s[output].bind(s[output].fuse(by, bx), tvm.thread_axis("blockIdx.x"))
-    s[output].bind(vn, tvm.thread_axis("vthread"))
-    s[output].bind(vg, tvm.thread_axis("vthread"))
-    s[output].bind(vf, tvm.thread_axis("vthread"))
-    s[output].bind(vy, tvm.thread_axis("vthread"))
-    s[output].bind(vx, tvm.thread_axis("vthread"))
-
-    cfg.define_knob("fuse_yx", [0, 1])  # fuse ty,tx or tn,tf
-    if cfg["fuse_yx"].val:
-        s[output].bind(tn, tvm.thread_axis("threadIdx.z"))
-        s[output].bind(tf, tvm.thread_axis("threadIdx.y"))
-        tyx = s[output].fuse(ty, tx)
-        s[output].bind(tyx, tvm.thread_axis("threadIdx.x"))
-        s[OL].compute_at(s[output], tyx)
-
-        # number of threads
-        n_tz = cfg["tile_n"].size[2]
-        n_ty = cfg["tile_f"].size[2]
-        n_tx = cfg["tile_y"].size[2] * cfg["tile_x"].size[2]
-    else:
-        s[output].bind(s[output].fuse(tn, tf), tvm.thread_axis("threadIdx.z"))
-        s[output].bind(ty, tvm.thread_axis("threadIdx.y"))
-        s[output].bind(tx, tvm.thread_axis("threadIdx.x"))
-        s[OL].compute_at(s[output], tx)
-
-        # number of threads
-        n_tz = cfg["tile_n"].size[2] * cfg["tile_f"].size[2]
-        n_ty = cfg["tile_y"].size[2]
-        n_tx = cfg["tile_x"].size[2]
-
-    # tile reduction axes
-    n, f, y, x = s[OL].op.axis
-    rc, ry, rx = s[OL].op.reduce_axis
-    rco, rci = cfg['tile_rc'].apply(s, OL, rc)
-    ryo, ryi = cfg['tile_rx'].apply(s, OL, ry)
-    rxo, rxi = cfg['tile_ry'].apply(s, OL, rx)
-    s[OL].reorder(rco, ryo, rxo, rci, ryi, rxi, n, f, y, x)
-
-    s[AA].compute_at(s[OL], rxo)
-    s[WW].compute_at(s[OL], rxo)
-
-    # cooperative fetching
-    for load in [AA, WW]:
-        n, f, y, x = s[load].op.axis
-        fused = s[load].fuse(n, f, y, x)
-        fused, tx = s[load].split(fused, factor=n_tx)
-        fused, ty = s[load].split(fused, factor=n_ty)
-        fused, tz = s[load].split(fused, factor=n_tz)
-        s[load].bind(tz, tvm.thread_axis("threadIdx.z"))
-        s[load].bind(ty, tvm.thread_axis("threadIdx.y"))
-        s[load].bind(tx, tvm.thread_axis("threadIdx.x"))
-
-    # unroll
-    s[output].pragma(kernel_scope, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val)
-    s[output].pragma(kernel_scope, 'unroll_explicit', cfg['unroll_explicit'].val)
-
-    N, CO, OH, OW = get_const_tuple(output.shape)
-    _, CI_div_groups, KH, KW = get_const_tuple(kernel.shape)
-    cfg.add_flop(2 * N * OH * OW * CO * CI_div_groups * KH * KW)
-
-
-@autotvm.register_topi_schedule(generic.schedule_group_conv2d_nchw,
-                                ["cuda", "gpu"], ["int8", "direct"])
-def schedule_conv2d_nchw_cuda(cfg, outs):
-    """TOPI schedule callback of group conv2d for cuda gpu
-
-    Parameters
-    ----------
-    cfg: ConfigEntity
-        The config for this template
-
-    outs: Array of Tensor
-        The computation graph description of conv2d
-        in the format of an array of tensors.
-
-    Returns
-    -------
-    s: Schedule
-        The computation schedule for group conv2d.
-    """
-    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
-    s = tvm.create_schedule([x.op for x in outs])
-
-    def _callback(op):
-        if op.tag == "group_conv2d_NCHWc_int8":
-            schedule_group_conv2d_NCHWc_int8(cfg, s, op.output(0))
-        if op.tag == "group_conv2d_nchw":
-            schedule_group_conv2d_nchw_direct(cfg, s, op.output(0))
-
-    traverse_inline(s, outs[0].op, _callback)
-    return s
index eb7019b..1690407 100644 (file)
 # pylint: disable=invalid-name, unused-variable,
 """Schedule for composition of injective operator"""
 import tvm
-from .. import generic, util
-from ..util import is_empty_shape
+from .. import util
 
-@generic.schedule_injective_from_existing.register(["cuda", "gpu"])
 def schedule_injective_from_existing(sch, out):
     """Schedule for injective op from existing schedule.
 
@@ -67,7 +65,6 @@ def schedule_injective_from_existing(sch, out):
 
     return sch
 
-@generic.schedule_injective.register(["cuda", "gpu"])
 def schedule_injective(outs):
     """Schedule for injective op.
 
@@ -87,7 +84,7 @@ def schedule_injective(outs):
 
     tvm.schedule.AutoInlineInjective(s)
     for out in outs:
-        if not is_empty_shape(out.shape):
+        if not util.is_empty_shape(out.shape):
             schedule_injective_from_existing(s, out)
     return s
 
index 5485859..27a5272 100644 (file)
@@ -22,7 +22,6 @@ import tvm
 
 from tvm import api
 from tvm.intrin import if_then_else
-from topi.vision import non_max_suppression, get_valid_counts
 from .sort import argsort
 from .. import tag
 
@@ -238,8 +237,7 @@ def out_rewrite(data, flag, prefix_sum, valid_count, out):
     return ib.get()
 
 
-@get_valid_counts.register(["cuda", "gpu"])
-def get_valid_counts_gpu(data, score_threshold=0, id_index=0, score_index=1):
+def get_valid_counts(data, score_threshold=0, id_index=0, score_index=1):
     """Get valid count of bounding boxes given a score threshold.
     Also moves valid boxes to the top of input data.
 
@@ -585,11 +583,10 @@ def invalid_to_bottom_ir(data, flag, idx, out):
     return ib.get()
 
 
-@non_max_suppression.register(["cuda", "gpu"])
-def non_max_suppression_gpu(data, valid_count, max_output_size=-1,
-                            iou_threshold=0.5, force_suppress=False, top_k=-1,
-                            coord_start=2, score_index=1, id_index=0,
-                            return_indices=True, invalid_to_bottom=False):
+def non_max_suppression(data, valid_count, max_output_size=-1,
+                        iou_threshold=0.5, force_suppress=False, top_k=-1,
+                        coord_start=2, score_index=1, id_index=0,
+                        return_indices=True, invalid_to_bottom=False):
     """Non-maximum suppression operator for object detection.
 
     Parameters
index 327afa8..4460f7b 100644 (file)
 """scheduler functions for cuda backend"""
 from __future__ import absolute_import as _abs
 
-import tvm
-from .. import generic
 from .. import cpp
 
-@generic.schedule_lrn.register(["cuda"])
 def schedule_lrn(outs):
     """Schedule for LRN
 
@@ -37,6 +34,4 @@ def schedule_lrn(outs):
     sch: Schedule
         The computation schedule for the op.
     """
-    target = tvm.target.Target.current(allow_none=False)
-    cpp_target = cpp.TEST_create_target(target.target_name)
-    return cpp.cuda.schedule_lrn(cpp_target, outs)
+    return cpp.cuda.schedule_lrn(outs)
index 2bf1e6b..2bebd39 100644 (file)
 """Schedule for pooling operators"""
 import tvm
 from .. import tag
-from .. import generic
 from ..util import traverse_inline
 
 
-
-@generic.schedule_adaptive_pool.register(["cuda", "gpu"])
 def schedule_adaptive_pool(outs):
     """Schedule for adaptive_pool.
 
@@ -89,7 +86,6 @@ def schedule_adaptive_pool(outs):
     return s
 
 
-@generic.schedule_pool.register(["cuda", "gpu"])
 def schedule_pool(outs, layout):
     """Schedule for pool.
 
@@ -153,8 +149,7 @@ def schedule_pool(outs, layout):
     return s
 
 
-@generic.schedule_pool_grad.register(['cuda', 'gpu'])
-def schedule_pool_grad_cuda(outs):
+def schedule_pool_grad(outs):
     """Schedule for pool_grad on CUDA
 
     Parameters
index 42b34f0..da55b07 100644 (file)
@@ -17,4 +17,4 @@
 
 # pylint: disable=wildcard-import
 """Faster R-CNN and Mask R-CNN operators"""
-from .proposal import *
+from .proposal import proposal
index 4344226..489c354 100644 (file)
@@ -18,7 +18,7 @@
 """Proposal operator"""
 import math
 import tvm
-from ...vision.rcnn import proposal, generate_anchor, reg_bbox, reg_iou
+from ...vision.rcnn import generate_anchor, reg_bbox, reg_iou
 from ...util import get_const_tuple, get_const_int
 
 
@@ -308,9 +308,8 @@ def prepare_output_ir(sorted_bbox_buf, remove_mask_buf, out_buf):
     return body
 
 
-@proposal.register("cuda")
-def proposal_cuda(cls_prob, bbox_pred, im_info, scales, ratios, feature_stride, threshold,
-                  rpn_pre_nms_top_n, rpn_post_nms_top_n, rpn_min_size, iou_loss):
+def proposal(cls_prob, bbox_pred, im_info, scales, ratios, feature_stride, threshold,
+             rpn_pre_nms_top_n, rpn_post_nms_top_n, rpn_min_size, iou_loss):
     """Proposal operator.
 
     Parameters
index 69c685c..0b9d588 100644 (file)
@@ -19,7 +19,6 @@
 from __future__ import absolute_import as _abs
 import tvm
 from .. import tag
-from .. import generic
 from .injective import schedule_injective_from_existing
 
 def _schedule_reduce(op, sch, is_idx_reduce=False):
@@ -89,7 +88,6 @@ def _schedule_reduce(op, sch, is_idx_reduce=False):
     return sch
 
 
-@generic.schedule_reduce.register(["cuda", "gpu"])
 def schedule_reduce(outs):
     """Schedule for inject->reduce->bcast ops.
 
index 26a1baf..afd11ea 100644 (file)
 # pylint: disable=invalid-name, unused-variable, trailing-whitespace
 """Schedule for softmax operator"""
 import tvm
-from .. import generic
 from .injective import schedule_injective_from_existing
 
-@generic.schedule_softmax.register(["cuda", "gpu"])
+
 def schedule_softmax(outs):
     """Schedule for softmax op.
 
index b32cce7..88ca9d8 100644 (file)
 import tvm
 
 from tvm import api
-from ..sort import argsort, topk
+from .injective import schedule_injective_from_existing
 from ..math import identity
 from ..transform import strided_slice
-from .. import generic
 from .. import tag
 
 def _schedule_sort(outs):
@@ -42,8 +41,7 @@ def _schedule_sort(outs):
     outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
     s = tvm.create_schedule([x.op for x in outs])
     scheduled_ops = []
-    # pylint: disable=import-outside-toplevel
-    from .injective import schedule_injective_from_existing
+
     def traverse(op):
         if tag.is_injective(op.tag):
             schedule_injective_from_existing(s, op.output(0))
@@ -239,8 +237,7 @@ def sort_nms_ir(data, valid_count, output, axis, is_ascend):
 
     return ib.get()
 
-@argsort.register(["cuda", "gpu"])
-def argsort_gpu(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32"):
+def argsort(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32"):
     """Performs sorting along the given axis and returns an array of indicies
     having same shape as an input array that index data in sorted order.
 
@@ -294,7 +291,6 @@ def argsort_gpu(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32"):
                          tag="argsort_gpu")[1]
     return out
 
-@generic.schedule_argsort.register(["cuda", "gpu"])
 def schedule_argsort(outs):
     """Schedule for argsort operator.
 
@@ -311,8 +307,7 @@ def schedule_argsort(outs):
     """
     return _schedule_sort(outs)
 
-@topk.register(["cuda", "gpu"])
-def topk_gpu(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int64"):
+def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int64"):
     """Get the top k elements in an input tensor along the given axis.
 
     Parameters
@@ -389,7 +384,6 @@ def topk_gpu(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int64"
     return output
 
 
-@generic.schedule_topk.register(["cuda", "gpu"])
 def schedule_topk(outs):
     """Schedule for argsort operator.
 
index 10ba7a1..0b3f50b 100644 (file)
@@ -25,9 +25,6 @@ from tvm.intrin import if_then_else, exp
 
 import topi
 
-from topi.vision.ssd import multibox_prior
-from topi.vision.ssd import multibox_detection
-from topi.vision.ssd import multibox_transform_loc
 from ..nms import non_max_suppression
 
 
@@ -112,9 +109,8 @@ def multibox_prior_ir(data, out, sizes, ratios, steps, offsets):
     return body
 
 
-@multibox_prior.register(["cuda", "gpu"])
-def multibox_prior_gpu(data, sizes=(1,), ratios=(1,), steps=(-1, -1),
-                       offsets=(0.5, 0.5), clip=False):
+def multibox_prior(data, sizes=(1,), ratios=(1,), steps=(-1, -1),
+                   offsets=(0.5, 0.5), clip=False):
     """Generate prior(anchor) boxes from data, sizes and ratios.
 
     Parameters
@@ -346,9 +342,8 @@ def transform_loc_ir(loc_pred, anchor, temp_valid_count, temp_cls_id, temp_score
     return ib.get()
 
 
-@multibox_transform_loc.register(["cuda", "gpu"])
-def multibox_transform_loc_gpu(cls_prob, loc_pred, anchor, clip=True, \
-                               threshold=0.01, variances=(0.1, 0.1, 0.2, 0.2)):
+def multibox_transform_loc(cls_prob, loc_pred, anchor, clip=True, \
+                           threshold=0.01, variances=(0.1, 0.1, 0.2, 0.2)):
     """Location transformation for multibox detection
 
     Parameters
@@ -426,9 +421,8 @@ def multibox_transform_loc_gpu(cls_prob, loc_pred, anchor, clip=True, \
     return [out_loc, valid_count]
 
 
-@multibox_detection.register(["cuda", "gpu"])
-def multibox_detection_gpu(cls_prob, loc_pred, anchor, clip=True, threshold=0.01, nms_threshold=0.5,
-                           force_suppress=False, variances=(0.1, 0.1, 0.2, 0.2), nms_topk=-1):
+def multibox_detection(cls_prob, loc_pred, anchor, clip=True, threshold=0.01, nms_threshold=0.5,
+                       force_suppress=False, variances=(0.1, 0.1, 0.2, 0.2), nms_topk=-1):
     """Convert multibox detection predictions.
 
     Parameters
index d456aad..8666c22 100644 (file)
 """Schedule for vision operators"""
 from __future__ import absolute_import as _abs
 import tvm
-from .. import generic
 from .. import cpp
 from .. import tag
 from .pooling import schedule_pool
+from .injective import schedule_injective_from_existing
 
 def _default_schedule(outs):
     """Default schedule for gpu."""
     outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
     s = tvm.create_schedule([x.op for x in outs])
     scheduled_ops = []
-    from .injective import schedule_injective_from_existing
     def traverse(op):
         if tag.is_broadcast(op.tag) or op.tag in ['bbox_score', 'sorted_bbox']:
             schedule_injective_from_existing(s, op.output(0))
@@ -39,7 +38,6 @@ def _default_schedule(outs):
     traverse(outs[0].op)
     return s
 
-@generic.schedule_reorg.register(["cuda", "gpu"])
 def schedule_reorg(outs):
     """Schedule for reorg operator.
     Parameters
@@ -57,7 +55,6 @@ def schedule_reorg(outs):
     cpp_target = cpp.TEST_create_target(target.target_name)
     return cpp.cuda.schedule_injective(cpp_target, outs)
 
-@generic.schedule_nms.register(["cuda", "gpu"])
 def schedule_nms(outs):
     """Schedule for non-maximum suppression
 
@@ -74,7 +71,6 @@ def schedule_nms(outs):
     """
     return _default_schedule(outs)
 
-@generic.schedule_multibox_prior.register(["cuda", "gpu"])
 def schedule_multibox_prior(outs):
     """Schedule for multibox_prior operator.
 
@@ -91,7 +87,6 @@ def schedule_multibox_prior(outs):
     """
     return _default_schedule(outs)
 
-@generic.schedule_multibox_transform_loc.register(["cuda", "gpu"])
 def schedule_multibox_transform_loc(outs):
     """Schedule for multibox_transform_loc
 
@@ -109,7 +104,6 @@ def schedule_multibox_transform_loc(outs):
     """
     return _default_schedule(outs)
 
-@generic.schedule_multibox_detection.register(["cuda", "gpu"])
 def schedule_multibox_detection(outs):
     """Schedule for multibox_detection operator.
 
@@ -126,15 +120,12 @@ def schedule_multibox_detection(outs):
     """
     return _default_schedule(outs)
 
-@generic.schedule_roi_align.register(["cuda", "gpu"])
 def schedule_roi_align(outs):
     return schedule_pool(outs, 'NCHW')
 
-@generic.schedule_roi_pool.register(["cuda", "gpu"])
 def schedule_roi_pool(outs):
     return schedule_pool(outs, 'NCHW')
 
-@generic.schedule_proposal.register(["cuda", "gpu"])
 def schedule_proposal(outs):
     """Schedule for proposal operator.
 
@@ -151,7 +142,6 @@ def schedule_proposal(outs):
     """
     return _default_schedule(outs)
 
-@generic.schedule_get_valid_counts.register(["cuda", "gpu"])
 def schedule_get_valid_counts(outs):
     """Schedule for get_valid_counts operator.
 
index 332c2fd..08bb06c 100644 (file)
@@ -19,6 +19,7 @@
 """Generic convolution schedules"""
 from __future__ import absolute_import as _abs
 import tvm
+from tvm import autotvm
 from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity
 from ..util import get_const_tuple
 
@@ -109,7 +110,8 @@ def fallback_schedule_cpu_1x1_int8(cfg, wkl, int32_lanes, num_int8_elements):
     raise ValueError("cannot decide default schedule for workload: {}".format(wkl))
 
 
-def schedule_conv_NCHWc_cpu_common_int8(s, cfg, data, conv_out, last, int32_lanes=16, intrin=None):
+def schedule_conv_NCHWc_cpu_common_int8(s, cfg, data_vec, kernel_vec, conv_out,
+                                        last, int32_lanes=16, intrin=None):
     """
     Defines the schedule for INT8 for Intel and ARM machines
     Uses the Intel/ARM intrinsics to use INT8 operations
@@ -117,14 +119,39 @@ def schedule_conv_NCHWc_cpu_common_int8(s, cfg, data, conv_out, last, int32_lane
     lower-numerical-precision-deep-learning-inference-and-training
     """
     reg_n, unroll_kw = cfg["tile_ow"].size[-1], cfg["unroll_kw"].val
-    _, _, _, _, ic_bn = get_const_tuple(data.shape)
+    _, _, _, _, ic_bn = get_const_tuple(data_vec.shape)
     _, _, _, _, oc_bn = get_const_tuple(conv_out.shape)
 
-    A = data
-    if isinstance(s[A].op, tvm.tensor.ComputeOp):
-        batch, ic_chunk, ih, iw, _ = s[A].op.axis
-        parallel_axis = s[A].fuse(batch, ic_chunk, ih)
-        s[A].parallel(parallel_axis)
+    # schedule pad
+    if isinstance(s[data_vec].op, tvm.tensor.ComputeOp) \
+            and "pad" in data_vec.op.tag:
+        batch, ic_chunk, ih, iw, ic_block = s[data_vec].op.axis
+        parallel_axis = s[data_vec].fuse(batch, ic_chunk, ih)
+        s[data_vec].parallel(parallel_axis)
+        data_vec = data_vec.op.input_tensors[0]
+
+    if autotvm.GLOBAL_SCOPE.in_tuning:
+        # only in autotuning, input data of conv2d_NCHWc will be 4-D.
+        # skip this part during tuning to make records accurate.
+        # this part will be folded during Relay fold_constant pass.
+        s[data_vec].pragma(s[data_vec].op.axis[0], "debug_skip_region")
+        s[kernel_vec].pragma(s[kernel_vec].op.axis[0], "debug_skip_region")
+    elif isinstance(kernel_vec.op, tvm.tensor.ComputeOp) and \
+            kernel_vec.name == 'kernel_vec':
+        # data and kernel are not pre-computed, schedule layout transform here.
+        # this should only be used by x86 conv2d_nchw, which is for
+        # testing purpose.
+        batch, ic_chunk, ih, ic_block, iw = s[data_vec].op.axis
+        parallel_axis = s[data_vec].fuse(batch, ic_chunk, ih)
+        s[data_vec].parallel(parallel_axis)
+
+        oc_chunk, ic_chunk, oh, ow, ic_block, oc_block = s[kernel_vec].op.axis
+        s[kernel_vec].reorder(oc_chunk, oh, ic_chunk, ow, ic_block, oc_block)
+        oc_bn = cfg["tile_oc"].size[-1]
+        if oc_bn > 1:
+            s[kernel_vec].vectorize(oc_block)
+        parallel_axis = s[kernel_vec].fuse(oc_chunk, oh)
+        s[kernel_vec].parallel(parallel_axis)
 
     # schedule 5-D NCHW[x]c conv
     C, O = conv_out, last
@@ -173,7 +200,8 @@ def schedule_conv_NCHWc_cpu_common_int8(s, cfg, data, conv_out, last, int32_lane
 
     return s
 
-def schedule_conv_NCHWc_cpu_1x1_int8(s, cfg, data, conv_out, last, int32_lanes=16, intrin=None):
+def schedule_conv_NCHWc_cpu_1x1_int8(s, cfg, data_vec, kernel_vec, conv_out,
+                                     last, int32_lanes=16, intrin=None):
     """
     Defines the 1x1 conv schedule for INT8 for Intel and ARM machines
     Uses the Intel/ARM intrinsics to use INT8 operations
@@ -181,15 +209,39 @@ def schedule_conv_NCHWc_cpu_1x1_int8(s, cfg, data, conv_out, last, int32_lanes=1
     lower-numerical-precision-deep-learning-inference-and-training
     """
     oh_factor, ow_factor = cfg["tile_oh"].val, cfg["tile_ow"].size[-1]
-    _, _, _, _, ic_bn = get_const_tuple(data.shape)
+    _, _, _, _, ic_bn = get_const_tuple(data_vec.shape)
     _, _, _, _, oc_bn = get_const_tuple(conv_out.shape)
 
-    # schedule data
-    A = data
-    if isinstance(s[A].op, tvm.tensor.ComputeOp):
-        batch, ic_chunk, ih, iw, ic_block = s[A].op.axis
-        parallel_axis = s[A].fuse(batch, ic_chunk, ih)
-        s[A].parallel(parallel_axis)
+    # schedule pad
+    if isinstance(s[data_vec].op, tvm.tensor.ComputeOp) \
+            and "pad" in data_vec.op.tag:
+        batch, ic_chunk, ih, iw, ic_block = s[data_vec].op.axis
+        parallel_axis = s[data_vec].fuse(batch, ic_chunk, ih)
+        s[data_vec].parallel(parallel_axis)
+        data_vec = data_vec.op.input_tensors[0]
+
+    if autotvm.GLOBAL_SCOPE.in_tuning:
+        # only in autotuning, input data of conv2d_NCHWc will be 4-D.
+        # skip this part during tuning to make records accurate.
+        # this part will be folded during Relay fold_constant pass.
+        s[data_vec].pragma(s[data_vec].op.axis[0], "debug_skip_region")
+        s[kernel_vec].pragma(s[kernel_vec].op.axis[0], "debug_skip_region")
+    elif isinstance(kernel_vec.op, tvm.tensor.ComputeOp) and \
+            kernel_vec.name == 'kernel_vec':
+        # data and kernel are not pre-computed, schedule layout transform here.
+        # this should only be used by x86 conv2d_nchw, which is for
+        # testing purpose.
+        batch, ic_chunk, ih, ic_block, iw = s[data_vec].op.axis
+        parallel_axis = s[data_vec].fuse(batch, ic_chunk, ih)
+        s[data_vec].parallel(parallel_axis)
+
+        oc_chunk, ic_chunk, oh, ow, ic_block, oc_block = s[kernel_vec].op.axis
+        s[kernel_vec].reorder(oc_chunk, oh, ic_chunk, ow, ic_block, oc_block)
+        oc_bn = cfg["tile_oc"].size[-1]
+        if oc_bn > 1:
+            s[kernel_vec].vectorize(oc_block)
+        parallel_axis = s[kernel_vec].fuse(oc_chunk, oh)
+        s[kernel_vec].parallel(parallel_axis)
 
     C, O = conv_out, last
     CC = s.cache_write(C, 'global')
index e895385..977c537 100644 (file)
@@ -21,7 +21,6 @@ from __future__ import absolute_import as _abs
 import tvm
 from .. import cpp
 
-@tvm.target.generic_func
 def schedule_extern(outs):
     """Schedule for an extern op followed by injective operations.
 
index 2aff96f..6f1013c 100644 (file)
@@ -20,7 +20,6 @@ from __future__ import absolute_import as _abs
 
 import tvm
 
-@tvm.target.override_native_generic_func("schedule_injective_from_existing")
 def schedule_injective_from_existing(sch, out):
     """Schedule for injective op from existing schedule.
 
@@ -36,10 +35,9 @@ def schedule_injective_from_existing(sch, out):
     sch: Schedule
          The updated schedule.
     """
-    sch[out].fuse(s[out].op.axis)
+    sch[out].fuse(*sch[out].op.axis)
     return sch
 
-@tvm.target.override_native_generic_func("schedule_injective")
 def schedule_injective(outs):
     """Schedule for injective op.
 
@@ -64,22 +62,5 @@ def schedule_injective(outs):
     schedule_injective_from_existing(s, x)
     return s
 
-@tvm.target.generic_func
-def schedule_concatenate(outs):
-    """Schedule for concatenate op.
-
-    Parameters
-    ----------
-    outs: Array of Tensor
-          The computation graph description of reduce in the format
-          of an array of tensors.
-
-    Returns
-    -------
-    sch: Schedule
-        The computation schedule for the op.
-    """
-    return schedule_injective(outs)
-
 schedule_elemwise = schedule_injective
 schedule_broadcast = schedule_injective
index 8831829..ba50a8b 100644 (file)
@@ -18,7 +18,6 @@
 """Generic nn operators"""
 from __future__ import absolute_import as _abs
 import tvm
-from .. import cpp
 
 def _default_schedule(outs, auto_inline):
     """Default schedule for llvm."""
@@ -34,7 +33,6 @@ def _default_schedule(outs, auto_inline):
     return s
 
 
-@tvm.target.generic_func
 def schedule_conv1d_ncw(outs):
     """Schedule for conv1d_ncw
 
@@ -52,7 +50,6 @@ def schedule_conv1d_ncw(outs):
     return _default_schedule(outs, False)
 
 
-@tvm.target.generic_func
 def schedule_conv1d_nwc(outs):
     """Schedule for conv1d_nwc
 
@@ -70,7 +67,6 @@ def schedule_conv1d_nwc(outs):
     return _default_schedule(outs, False)
 
 
-@tvm.target.generic_func
 def schedule_conv2d_hwcn(outs):
     """Schedule for conv2d_hwcn
 
@@ -88,7 +84,6 @@ def schedule_conv2d_hwcn(outs):
     return _default_schedule(outs, False)
 
 
-@tvm.target.generic_func
 def schedule_conv2d_nchw(outs):
     """Schedule for conv2d_nchw
 
@@ -106,7 +101,6 @@ def schedule_conv2d_nchw(outs):
     return _default_schedule(outs, False)
 
 
-@tvm.target.generic_func
 def schedule_conv2d_nhwc_pack(outs):
     """Schedule for conv2d_nhwc_pack
 
@@ -124,7 +118,6 @@ def schedule_conv2d_nhwc_pack(outs):
     return _default_schedule(outs, False)
 
 
-@tvm.target.generic_func
 def schedule_conv2d_nhwc(outs):
     """Schedule for conv2d_nhwc
 
@@ -142,7 +135,6 @@ def schedule_conv2d_nhwc(outs):
     return _default_schedule(outs, False)
 
 
-@tvm.target.generic_func
 def schedule_conv2d_NCHWc(outs):
     """Schedule for conv2d_NCHW[x]c
 
@@ -161,7 +153,6 @@ def schedule_conv2d_NCHWc(outs):
     return _default_schedule(outs, False)
 
 
-@tvm.target.generic_func
 def schedule_conv2d_NCHWc_int8(outs):
     """Schedule for conv2d_NCHW[x]c_int8
 
@@ -180,7 +171,6 @@ def schedule_conv2d_NCHWc_int8(outs):
     return _default_schedule(outs, False)
 
 
-@tvm.target.generic_func
 def schedule_conv2d_winograd_weight_transform(outs):
     """Schedule for weight transformation of winograd
 
@@ -210,7 +200,6 @@ def schedule_conv2d_winograd_weight_transform(outs):
     return s
 
 
-@tvm.target.generic_func
 def schedule_conv2d_winograd_without_weight_transform(outs):
     """Schedule for winograd without weight transformation
 
@@ -228,7 +217,6 @@ def schedule_conv2d_winograd_without_weight_transform(outs):
     return _default_schedule(outs, False)
 
 
-@tvm.target.generic_func
 def schedule_conv2d_winograd_nnpack_weight_transform(outs):
     """Schedule for weight transformation of winograd
      Parameters
@@ -245,23 +233,7 @@ def schedule_conv2d_winograd_nnpack_weight_transform(outs):
     s = tvm.create_schedule([x.op for x in outs])
     return s
 
-@tvm.target.generic_func
-def schedule_conv2d_winograd_nnpack_without_weight_transform(outs):
-    """Schedule for winograd without weight transformation
-     Parameters
-    ----------
-    outs: Array of Tensor
-          The computation graph description of this operator
-          in the format of an array of tensors.
-     Returns
-    -------
-    sch: Schedule
-        The computation schedule for the op.
-    """
-    return _default_schedule(outs, False)
 
-
-@tvm.target.generic_func
 def schedule_conv3d_ncdhw(outs):
     """Schedule for conv3d_ncdhw
 
@@ -278,7 +250,6 @@ def schedule_conv3d_ncdhw(outs):
     """
     return _default_schedule(outs, False)
 
-@tvm.target.generic_func
 def schedule_conv3d_ndhwc(outs):
     """Schedule for conv3d_ndhwc
 
@@ -295,7 +266,6 @@ def schedule_conv3d_ndhwc(outs):
     """
     return _default_schedule(outs, False)
 
-@tvm.target.generic_func
 def schedule_conv2d_transpose_nchw(outs):
     """Schedule for conv2d_transpose_nchw
 
@@ -313,7 +283,6 @@ def schedule_conv2d_transpose_nchw(outs):
     return _default_schedule(outs, False)
 
 
-@tvm.target.generic_func
 def schedule_conv1d_transpose_ncw(outs):
     """Schedule for conv1d_transpose_ncw
 
@@ -331,7 +300,6 @@ def schedule_conv1d_transpose_ncw(outs):
     return _default_schedule(outs, False)
 
 
-@tvm.target.generic_func
 def schedule_depthwise_conv2d_nchw(outs):
     """Schedule for depthwise_conv2d_nchw
 
@@ -349,7 +317,6 @@ def schedule_depthwise_conv2d_nchw(outs):
     return _default_schedule(outs, False)
 
 
-@tvm.target.generic_func
 def schedule_depthwise_conv2d_nhwc(outs):
     """Schedule for depthwise_conv2d_nhwc
     Parameters
@@ -366,7 +333,6 @@ def schedule_depthwise_conv2d_nhwc(outs):
     return _default_schedule(outs, False)
 
 
-@tvm.target.generic_func
 def schedule_depthwise_conv2d_NCHWc(outs):
     """Schedule for depthwise_conv2d_NCHWc
     Parameters
@@ -383,7 +349,6 @@ def schedule_depthwise_conv2d_NCHWc(outs):
     return _default_schedule(outs, False)
 
 
-@tvm.target.generic_func
 def schedule_group_conv2d_nchw(outs):
     """Schedule for group_conv2d_nchw
 
@@ -401,7 +366,6 @@ def schedule_group_conv2d_nchw(outs):
     return _default_schedule(outs, False)
 
 
-@tvm.target.generic_func
 def schedule_deformable_conv2d_nchw(outs):
     """Schedule for deformable_conv2d_nchw
 
@@ -419,7 +383,6 @@ def schedule_deformable_conv2d_nchw(outs):
     return _default_schedule(outs, False)
 
 
-@tvm.target.generic_func
 def schedule_bitserial_conv2d_nchw(outs):
     """Schedule for bitserial_conv2d_nchw
 
@@ -437,7 +400,6 @@ def schedule_bitserial_conv2d_nchw(outs):
     return _default_schedule(outs, False)
 
 
-@tvm.target.generic_func
 def schedule_bitserial_conv2d_nhwc(outs):
     """Schedule for bitserial_conv2d_nhwc
 
@@ -455,7 +417,6 @@ def schedule_bitserial_conv2d_nhwc(outs):
     return _default_schedule(outs, False)
 
 
-@tvm.target.generic_func
 def schedule_bitserial_dense(outs):
     """Schedule for bitserial_dense
     Parameters
@@ -471,7 +432,6 @@ def schedule_bitserial_dense(outs):
     return _default_schedule(outs, False)
 
 
-@tvm.target.override_native_generic_func("schedule_reduce")
 def schedule_reduce(outs):
     """Schedule for reduction
 
@@ -489,7 +449,6 @@ def schedule_reduce(outs):
     return _default_schedule(outs, True)
 
 
-@tvm.target.override_native_generic_func("schedule_softmax")
 def schedule_softmax(outs):
     """Schedule for softmax
 
@@ -507,7 +466,6 @@ def schedule_softmax(outs):
     return _default_schedule(outs, False)
 
 
-@tvm.target.override_native_generic_func("schedule_dense")
 def schedule_dense(outs):
     """Schedule for dense
 
@@ -525,7 +483,6 @@ def schedule_dense(outs):
     return _default_schedule(outs, False)
 
 
-@tvm.target.override_native_generic_func("schedule_pool")
 def schedule_pool(outs, layout):
     """Schedule for pool
 
@@ -546,7 +503,6 @@ def schedule_pool(outs, layout):
     return _default_schedule(outs, False)
 
 
-@tvm.target.generic_func
 def schedule_pool_grad(outs):
     """Schedule for pool_grad
 
@@ -559,7 +515,6 @@ def schedule_pool_grad(outs):
     return _default_schedule(outs, False)
 
 
-@tvm.target.override_native_generic_func("schedule_adaptive_pool")
 def schedule_adaptive_pool(outs):
     """Schedule for adaptive pool
 
@@ -577,7 +532,6 @@ def schedule_adaptive_pool(outs):
     return _default_schedule(outs, False)
 
 
-@tvm.target.override_native_generic_func("schedule_binarize_pack")
 def schedule_binarize_pack(outs):
     """Schedule for binarize_pack
 
@@ -595,7 +549,6 @@ def schedule_binarize_pack(outs):
     return _default_schedule(outs, False)
 
 
-@tvm.target.override_native_generic_func("schedule_bitpack")
 def schedule_bitpack(outs):
     """Schedule for bitpack
     Parameters
@@ -612,7 +565,6 @@ def schedule_bitpack(outs):
     return _default_schedule(outs, False)
 
 
-@tvm.target.override_native_generic_func("schedule_binary_dense")
 def schedule_binary_dense(outs):
     """Schedule for binary_dense
 
@@ -630,7 +582,6 @@ def schedule_binary_dense(outs):
     return _default_schedule(outs, False)
 
 
-@tvm.target.generic_func
 def schedule_lrn(outs):
     """Schedule for lrn
 
@@ -645,12 +596,9 @@ def schedule_lrn(outs):
     sch: Schedule
         The computation schedule for the op.
     """
-    target = tvm.target.Target.current(allow_none=False)
-    cpp_target = cpp.TEST_create_target(target.target_name)
-    return cpp.generic.default_schedule(cpp_target, outs, False)
+    return _default_schedule(outs, False)
 
 
-@tvm.target.generic_func
 def schedule_sparse_dense(outs):
     """Schedule for sparse_dense
 
@@ -667,7 +615,7 @@ def schedule_sparse_dense(outs):
     """
     return _default_schedule(outs, False)
 
-@tvm.target.generic_func
+
 def schedule_sparse_transpose(outs):
     """Schedule for sparse_transpose
 
@@ -684,8 +632,19 @@ def schedule_sparse_transpose(outs):
     """
     return _default_schedule(outs, False)
 
-@tvm.target.generic_func
+
 def schedule_batch_matmul(outs):
-    target = tvm.target.Target.current(allow_none=False)
-    cpp_target = cpp.TEST_create_target(target.target_name)
-    return cpp.generic.default_schedule(cpp_target, outs, False)
+    """Schedule for batch_matmul
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+          The computation graph description of sparse_transpose
+          in the format of an array of tensors.
+
+    Returns
+    -------
+    sch: Schedule
+        The computation schedule for the op.
+    """
+    return _default_schedule(outs, False)
index 41045e4..69f2366 100644 (file)
 # pylint: disable=invalid-name, no-member
 """Generic search operators"""
 from __future__ import absolute_import as _abs
-import tvm
 from .vision import _default_schedule
 
-@tvm.target.generic_func
 def schedule_argwhere(outs):
     """Schedule for argwhere operator.
 
index 5462f2c..9eca588 100644 (file)
 # pylint: disable=invalid-name, no-member
 """Generic vision operators"""
 from __future__ import absolute_import as _abs
-import tvm
 from .vision import _default_schedule
 
-@tvm.target.generic_func
 def schedule_argsort(outs):
     """Schedule for argsort operator.
 
@@ -37,7 +35,6 @@ def schedule_argsort(outs):
     """
     return _default_schedule(outs, False)
 
-@tvm.target.generic_func
 def schedule_topk(outs):
     """Schedule for topk operator.
 
index 85d9153..d6e80df 100644 (file)
@@ -33,7 +33,6 @@ def _default_schedule(outs, auto_inline):
         s[x].fuse(s[x].op.axis)
     return s
 
-@tvm.target.generic_func
 def schedule_reorg(outs):
     """Schedule for reorg
 
@@ -52,7 +51,6 @@ def schedule_reorg(outs):
     cpp_target = cpp.TEST_create_target(target.target_name)
     return cpp.generic.default_schedule(cpp_target, outs, False)
 
-@tvm.target.generic_func
 def schedule_get_valid_counts(outs):
     """Schedule for get_valid_counts
 
@@ -69,7 +67,6 @@ def schedule_get_valid_counts(outs):
     """
     return _default_schedule(outs, False)
 
-@tvm.target.generic_func
 def schedule_nms(outs):
     """Schedule for non-maximum suppression
 
@@ -86,7 +83,6 @@ def schedule_nms(outs):
     """
     return _default_schedule(outs, False)
 
-@tvm.target.generic_func
 def schedule_multibox_prior(outs):
     """Schedule for multibox_prior
 
@@ -103,7 +99,6 @@ def schedule_multibox_prior(outs):
     """
     return _default_schedule(outs, False)
 
-@tvm.target.generic_func
 def schedule_multibox_transform_loc(outs):
     """Schedule for multibox_transform_loc
 
@@ -121,7 +116,6 @@ def schedule_multibox_transform_loc(outs):
     """
     return _default_schedule(outs, False)
 
-@tvm.target.generic_func
 def schedule_multibox_detection(outs):
     """Schedule for multibox_detection
 
@@ -138,7 +132,6 @@ def schedule_multibox_detection(outs):
     """
     return _default_schedule(outs, False)
 
-@tvm.target.generic_func
 def schedule_roi_align(outs):
     """Schedule for roi_align
 
@@ -155,7 +148,6 @@ def schedule_roi_align(outs):
     """
     return _default_schedule(outs, False)
 
-@tvm.target.generic_func
 def schedule_roi_pool(outs):
     """Schedule for roi_align
 
@@ -172,7 +164,6 @@ def schedule_roi_pool(outs):
     """
     return _default_schedule(outs, False)
 
-@tvm.target.generic_func
 def schedule_proposal(outs):
     """Schedule for proposal operator.
 
index de58428..d4ccf41 100644 (file)
@@ -17,9 +17,7 @@
 # pylint: disable=invalid-name, unused-variable,
 """Schedule for composition of injective operator"""
 import tvm
-from .. import generic
 
-@generic.schedule_injective_from_existing.register(["hls"])
 def schedule_injective_from_existing(sch, out):
     """Schedule for injective op from existing schedule.
 
@@ -40,7 +38,6 @@ def schedule_injective_from_existing(sch, out):
     sch[out].bind(px, tvm.thread_axis("pipeline"))
     return sch
 
-@generic.schedule_injective.register(["hls"])
 def schedule_injective(outs):
     """Schedule for injective op.
 
index d73cb9c..06cf329 100644 (file)
@@ -19,7 +19,6 @@
 from __future__ import absolute_import as _abs
 import tvm
 from .. import tag
-from .. import generic
 
 
 def _schedule_conv2d(outs):
@@ -52,7 +51,6 @@ def _schedule_conv2d(outs):
     return s
 
 
-@generic.schedule_conv2d_nchw.register(["hls"])
 def schedule_conv2d_nchw(outs):
     """Schedule for conv2d_nchw
 
@@ -70,7 +68,6 @@ def schedule_conv2d_nchw(outs):
     return _schedule_conv2d(outs)
 
 
-@generic.schedule_conv2d_nhwc.register(["hls"])
 def schedule_conv2d_nhwc(outs):
     """Schedule for conv2d_nhwc
 
@@ -88,7 +85,6 @@ def schedule_conv2d_nhwc(outs):
     return _schedule_conv2d(outs)
 
 
-@generic.schedule_conv2d_NCHWc.register(["hls"])
 def schedule_conv2d_NCHWc(outs):
     """Schedule for conv2d_NCHW[x]c
 
@@ -106,7 +102,6 @@ def schedule_conv2d_NCHWc(outs):
     return _schedule_conv2d(outs)
 
 
-@generic.schedule_conv2d_transpose_nchw.register(["hls"])
 def schedule_conv2d_transpose_nchw(outs):
     """Schedule for conv2d_transpose_nchw
 
@@ -124,7 +119,6 @@ def schedule_conv2d_transpose_nchw(outs):
     return _schedule_conv2d(outs)
 
 
-@generic.schedule_depthwise_conv2d_nchw.register(["hls"])
 def schedule_depthwise_conv2d_nchw(outs):
     """Schedule for depthwise_conv2d_nchw
 
@@ -142,7 +136,6 @@ def schedule_depthwise_conv2d_nchw(outs):
     return _schedule_conv2d(outs)
 
 
-@generic.schedule_depthwise_conv2d_nhwc.register(["hls"])
 def schedule_depthwise_conv2d_nhwc(outs):
     """Schedule for depthwise_conv2d_nhwc
     Parameters
@@ -158,7 +151,6 @@ def schedule_depthwise_conv2d_nhwc(outs):
     """
     return _schedule_conv2d(outs)
 
-@generic.schedule_bitserial_conv2d_nchw.register(["hls"])
 def schedule_bitserial_conv2d_nchw(outs):
     """Schedule for bitserial_conv2d_nchw
 
@@ -176,7 +168,6 @@ def schedule_bitserial_conv2d_nchw(outs):
     return _schedule_conv2d(outs)
 
 
-@generic.schedule_bitserial_conv2d_nhwc.register(["hls"])
 def schedule_bitserial_conv2d_nhwc(outs):
     """Schedule for bitserial_conv2d_nhwc
 
@@ -194,7 +185,6 @@ def schedule_bitserial_conv2d_nhwc(outs):
     return _schedule_conv2d(outs)
 
 
-@generic.schedule_reduce.register(["hls"])
 def schedule_reduce(outs):
     """Schedule for reduction
 
@@ -241,7 +231,6 @@ def schedule_reduce(outs):
     return s
 
 
-@generic.schedule_softmax.register(["hls"])
 def schedule_softmax(outs):
     """Schedule for softmax
 
@@ -286,7 +275,6 @@ def schedule_softmax(outs):
     return s
 
 
-@generic.schedule_dense.register(["hls"])
 def schedule_dense(outs):
     """Schedule for dense
 
@@ -330,7 +318,6 @@ def schedule_dense(outs):
     return s
 
 
-@generic.schedule_pool.register(["hls"])
 def schedule_pool(outs, layout):
     """Schedule for pool
 
@@ -374,7 +361,6 @@ def schedule_pool(outs, layout):
     return s
 
 
-@generic.schedule_adaptive_pool.register(["hls"])
 def schedule_adaptive_pool(outs):
     """Schedule for adaptive_pool
 
index 5223d2d..5f82fe7 100644 (file)
@@ -20,3 +20,5 @@
 from __future__ import absolute_import as _abs
 
 from .conv2d import *
+from . import conv2d_alter_op
+from .depthwise_conv2d import *
index 65ea590..8993063 100644 (file)
 from __future__ import absolute_import as _abs
 
 import tvm
-
 from tvm import autotvm
 from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity
-from tvm.autotvm.task.topi_integration import deserialize_args
-from tvm.autotvm.task import get_config
-from ..nn.conv2d import conv2d, conv2d_NCHWc, conv2d_alter_layout, conv2d_infer_layout
-from ..nn.util import get_pad_tuple
-from ..nn.depthwise_conv2d import depthwise_conv2d_nchw
-from ..nn import pad
-from .. import tag
-from .. import generic
+
+from .. import nn
 from .. import util
-from ..util import simplify, get_const_tuple
+from ..util import simplify, get_const_tuple, traverse_inline
 
 
 def _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depthwise=False):
@@ -73,17 +66,12 @@ def _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depth
     cfg["block_ow"] = OtherOptionEntity(block_ow)
 
 
-def _create_schedule_template(cfg, data, kernel, strides, padding, dilation, layout):
+def _create_schedule_template(cfg, dshape, kshape, strides, padding, dilation):
     """Create schedule configuration from input arguments"""
-    dshape = get_const_tuple(data.shape)
-    kshape = get_const_tuple(kernel.shape)
-    if layout == 'NCHW':
-        n, ic, h, w = dshape
-        oc, _, kh, kw = kshape
-    else:
-        raise ValueError("Not support this layout {} with "
-                         "schedule template.".format(layout))
-    pt, pl, pb, pr = get_pad_tuple(padding, kernel)
+    n, ic, h, w = dshape
+    oc, _, kh, kw = kshape
+
+    pt, pl, pb, pr = nn.get_pad_tuple(padding, (kh, kw))
     sh, sw = strides if isinstance(strides, (tuple, list)) else (strides, strides)
     oh = (h - kh + pt + pb) // sh + 1
     ow = (w - kw + pl + pr) // sw + 1
@@ -155,112 +143,31 @@ def tile_and_bind3d(s, tensor, z, y, x, z_factor=2, y_factor=None, x_factor=None
     s[tensor].bind(xi, thread_x)
     return xi, thread_z, thread_y, thread_x
 
-# Define template function for autotvm task
-# We define schedule template in this function instead of
-# declaration function since actual input arguments need
-# to be altered by the schedule selected.
-@autotvm.task.register("topi_intel_graphics_conv2d_NCHWc")
-def __topi_nn_conv2d_NCHWc(*args, **kwargs):
-    assert not kwargs, "Do not support kwargs in template function call"
-    data, kernel, strides, padding, dilation, layout, dtype = deserialize_args(args)
-    raw_data_shape = get_const_tuple(data.shape)
-    raw_kernel_shape = get_const_tuple(kernel.shape)
-
-    # get config here
-    cfg = get_config()
-    _create_schedule_template(cfg, data, kernel, strides, padding, dilation, layout)
-    cfg.add_flop(1)
-
-    # change shape with the value in config
-    ic_bn = cfg["tile_ic"].val if hasattr(cfg["tile_ic"], "val") else cfg["tile_ic"].size[-1]
-    oc_bn = cfg["tile_oc"].val if hasattr(cfg["tile_oc"], "val") else cfg["tile_oc"].size[-1]
-
-    new_data_shape = (raw_data_shape[0], raw_data_shape[1] // ic_bn,
-                      raw_data_shape[2], raw_data_shape[3], ic_bn)
-    new_kernel_shape = (raw_kernel_shape[0] // oc_bn, raw_kernel_shape[1] // ic_bn,
-                        raw_kernel_shape[2], raw_kernel_shape[3], ic_bn, oc_bn)
-    new_data = tvm.placeholder(new_data_shape, data.dtype)
-    new_kernel = tvm.placeholder(new_kernel_shape, kernel.dtype)
-
-    C = _decl_cl_spatialpack_NCHWc(cfg, new_data, new_kernel, strides, padding, dilation, dtype)
-    s = _schedule_conv2d_NCHWc(cfg, [C])
-
-    return s, [new_data, new_kernel, C]
-
-@conv2d_alter_layout.register(["intel_graphics"])
-def _alter_conv2d_layout(attrs, inputs, tinfo, F):
-    copy_inputs = list(inputs)
-    new_attrs = {k : attrs[k] for k in attrs.keys()}
-
-    if F.__name__ == 'tvm.relay.op':
-        # Derive channels for frontends (e.g ONNX) that miss "channel" field.
-        new_attrs["channels"] = inputs[1].checked_type.shape[attrs['kernel_layout'].index('O')]
-
-    data, kernel = tinfo[0], tinfo[1]
-    batch_size, in_channel, height, width = get_const_tuple(data.shape)
-
-    groups = attrs.get_int("groups")
-    out_channel = attrs.get_int("channels")
-    padding = attrs.get_int_tuple("padding")
-    strides = attrs.get_int_tuple("strides")
-    dilation = attrs.get_int_tuple("dilation")
-    out_dtype = attrs["out_dtype"]
-
-    layout_name = 'data_layout'
-    layout = attrs[layout_name]
-    kh, kw = attrs.get_int_tuple("kernel_size")
-
-    dtype = data.dtype
-    out_dtype = dtype if out_dtype in ("same", "") else out_dtype
-    is_depthwise = groups == in_channel and groups == out_channel
-
-    # only optimize for NCHW
-    if layout != 'NCHW':
-        return None
-    if groups != 1 and not is_depthwise:
-        return None
-
-    dispatch_ctx = autotvm.task.DispatchContext.current
-    target = tvm.target.Target.current()
-
-    # query schedule and fallback if necessary
-    workload = autotvm.task.args_to_workload(
-        [data, kernel, strides, padding, dilation, out_dtype], depthwise_conv2d_nchw) \
-        if is_depthwise else \
-        autotvm.task.args_to_workload(
-            [data, kernel, strides, padding, dilation, layout, out_dtype], conv2d)
-    if is_depthwise:
-        return None
-    cfg = dispatch_ctx.query(target, workload)
-    if cfg.is_fallback:
-        _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depthwise)
 
-    ic_bn = cfg["tile_ic"].val if hasattr(cfg["tile_ic"], "val") else cfg["tile_ic"].size[-1]
-    oc_bn = cfg["tile_oc"].val if hasattr(cfg["tile_oc"], "val") else cfg["tile_oc"].size[-1]
+def _pack_data(data, kernel, ic_bn, oc_bn):
+    n, _, ih, iw = get_const_tuple(data.shape)
+    oc, ic, kh, kw = get_const_tuple(kernel.shape)
 
-    new_attrs[layout_name] = 'NCHW%dc' % ic_bn
-    new_attrs['out_layout'] = 'NCHW%dc' % oc_bn
+    ic_chunk = ic // ic_bn
+    oc_chunk = oc // oc_bn
 
-    new_data = tvm.placeholder((batch_size, in_channel//ic_bn, height, width, ic_bn),
-                               dtype=data.dtype)
+    data = tvm.compute((n, ic_chunk, ih, iw, ic_bn),
+                       lambda bs, c, h, w, vc: data[bs, c*ic_bn + vc, h, w],
+                       name="data_vec")
 
-    out_channel, _, kh, kw = get_const_tuple(kernel.shape)
-    # (oc, ic, h, w) -> (OC, IC, h, w, ic, oc)
-    new_attrs['kernel_layout'] = 'OIHW%di%do' % (ic_bn, oc_bn)
+    kernel = tvm.compute(
+        (oc_chunk, ic_chunk, kh, kw, ic_bn, oc_bn),
+        lambda occ, icc, k_h, k_w, icb, ocb:
+        kernel[occ * oc_bn + ocb,
+               icc * ic_bn + icb, k_h, k_w],
+        name="kernel_vec")
 
-    # Store altered operator's config
-    new_kernel = tvm.placeholder((out_channel//oc_bn, in_channel//ic_bn, kh, kw, ic_bn, oc_bn),
-                                 dtype=kernel.dtype)
-    new_workload = autotvm.task.args_to_workload(
-        [new_data, new_kernel, strides, padding, dilation, new_attrs[layout_name],
-         new_attrs['out_layout'], out_dtype], conv2d_NCHWc)
+    return data, kernel
 
-    dispatch_ctx.update(target, new_workload, cfg)
-    return F.nn.contrib_conv2d_nchwc(*copy_inputs, **new_attrs)
 
-@autotvm.register_topi_compute(conv2d_NCHWc, 'intel_graphics', 'direct')
-def _decl_conv2d(cfg, data, kernel, strides, padding, dilation,
-                 layout, out_layout, out_dtype='float32'):
+@autotvm.register_topi_compute("conv2d_NCHWc.intel_graphics")
+def conv2d_NCHWc(cfg, data, kernel, strides, padding, dilation, layout,
+                 out_layout, out_dtype='float32'):
     """Conv2D operator for Intel Graphics backend.
 
     Parameters
@@ -285,96 +192,49 @@ def _decl_conv2d(cfg, data, kernel, strides, padding, dilation,
     output : tvm.Tensor
         4-D with shape [batch, out_channel, out_height, out_width]
     """
+    if len(data.shape) == 5:
+        batch, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape)
+        oc_chunk, _, kernel_height, kernel_width, _, oc_bn = get_const_tuple(kernel.shape)
+        in_channel = ic_chunk * ic_bn
+        num_filter = oc_chunk * oc_bn
+    else:
+        batch, in_channel, ih, iw = get_const_tuple(data.shape)
+        num_filter, _, kernel_height, kernel_width = get_const_tuple(kernel.shape)
+
     dh, dw = dilation if isinstance(dilation, (tuple, list)) else (dilation, dilation)
+    pad_top, pad_left, pad_down, pad_right = nn.get_pad_tuple(
+        padding, (kernel_height, kernel_width))
     assert (dh, dw) == (1, 1), "Does not support dilation"
+    if isinstance(strides, (tuple, list)):
+        stride_h, stride_w = strides
+    else:
+        stride_h, stride_w = strides, strides
+
+    data_shape = (batch, in_channel, ih, iw)
+    kernel_shape = (num_filter, in_channel, kernel_height, kernel_width)
+    _create_schedule_template(cfg, data_shape, kernel_shape, strides, padding, dilation)
 
-    n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape)
-    oc_chunk, _, kernel_height, kernel_width, _, oc_bn = get_const_tuple(kernel.shape)
-    in_channel = ic_chunk * ic_bn
-    num_filter = oc_chunk * oc_bn
     if cfg.is_fallback:
-        _get_default_config(cfg, tvm.placeholder((n, in_channel, ih, iw), dtype=data.dtype),
+        _get_default_config(cfg, tvm.placeholder((batch, in_channel, ih, iw), dtype=data.dtype),
                             tvm.placeholder((num_filter, in_channel, kernel_height, kernel_width),
                                             dtype=kernel.dtype),
                             strides, padding, out_dtype)
 
-    return _decl_cl_spatialpack_NCHWc(cfg, data, kernel, strides, padding, dilation, out_dtype)
-
-
-@conv2d_infer_layout.register("intel_graphics")
-def _conv2d_infer_layout(workload, cfg):
-    _, data, kernel, strides, padding, dilation, layout, dtype = workload
-    batch_size, in_channel, in_height, in_width = data[:-1]
-    out_channel, _, k_height, k_width = kernel[:-1]
-    out_height = (in_height + 2 * padding[0] - k_height) // strides[0] + 1
-    out_width = (in_width + 2 * padding[1] - k_width) // strides[1] + 1
-    tile_ic, tile_oc = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
-    in_shape = (batch_size, in_channel // tile_ic, in_height, in_width, tile_ic)
-    in_layout = "NCHW%dc" % tile_ic
-    out_shape = (batch_size, out_channel // tile_oc, out_height, out_width, tile_oc)
-    out_layout = "NCHW%dc" % tile_oc
-    return ((in_shape, in_layout),), ((out_shape, out_layout),)
-
-
-@autotvm.register_topi_schedule(generic.schedule_conv2d_NCHWc, 'intel_graphics', ['direct'])
-def _schedule_conv2d_NCHWc(cfg, outs):
-    """Schedule for conv2d_nchw for Intel Graphics
-
-    Parameters
-    ----------
-    outs: Array of Tensor
-        The computation graph description of conv2d_nchw
-        in the format of an array of tensors.
-
-    Returns
-    -------
-    s: Schedule
-        The computation schedule for conv2d_nchw.
-    """
-    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
-    s = tvm.create_schedule([x.op for x in outs])
-    scheduled_ops = []
-
-    def traverse(op):
-        """inline all one-to-one-mapping operators except the last stage (output)"""
-        if tag.is_injective(op.tag):
-            if op not in s.outputs:
-                s[op].compute_inline()
-            for tensor in op.input_tensors:
-                if tensor.op.input_tensors and tensor.op not in scheduled_ops:
-                    traverse(tensor.op)
-        if "conv" in op.tag:
-            _schedule_cl_spatialpack_NCHWc(cfg, s, op)
-
-        scheduled_ops.append(op)
-
-    traverse(outs[0].op)
-
-    return s
-
-def _decl_cl_spatialpack_NCHWc(cfg, data, kernel, strides, padding, dilation, out_dtype='float16'):
-    batch, in_channel, in_height, in_width, vc = [util.get_const_int(x) for x in data.shape]
-    in_channel *= vc
-    num_filter, channel, kernel_h, kernel_w, ci, co = [util.get_const_int(x) for x in kernel.shape]
-    num_filter *= co
-    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(padding, kernel)
+    ic_bn = cfg["tile_ic"].val if hasattr(cfg["tile_ic"], "val") else cfg["tile_ic"].size[-1]
+    oc_bn = cfg["tile_oc"].val if hasattr(cfg["tile_oc"], "val") else cfg["tile_oc"].size[-1]
 
-    ic_bn = vc
-    assert vc == ci
-
-    if isinstance(strides, (tuple, list)):
-        stride_h, stride_w = strides
-    else:
-        stride_h, stride_w = strides, strides
+    # Pack data if raw 4-D data is provided.
+    if len(data.shape) == 4:
+        data, kernel = _pack_data(data, kernel, ic_bn, oc_bn)
 
     out_channel = num_filter
-    out_height = simplify((in_height - kernel_h + pad_top + pad_down) // stride_h + 1)
-    out_width = simplify((in_width - kernel_w + pad_left + pad_right) // stride_w + 1)
-    oshape = (batch, out_channel // co, out_height, out_width, co)
+    out_height = simplify((ih - kernel_height + pad_top + pad_down) // stride_h + 1)
+    out_width = simplify((iw - kernel_width + pad_left + pad_right) // stride_w + 1)
+    oshape = (batch, out_channel // oc_bn, out_height, out_width, oc_bn)
 
     rc = tvm.reduce_axis((0, in_channel), name='rc')
-    ry = tvm.reduce_axis((0, kernel_h), name='ry')
-    rx = tvm.reduce_axis((0, kernel_w), name='rx')
+    ry = tvm.reduce_axis((0, kernel_height), name='ry')
+    rx = tvm.reduce_axis((0, kernel_width), name='rx')
 
     block_h = cfg["block_oh"].val
     block_w = cfg["block_ow"].val
@@ -388,7 +248,7 @@ def _decl_cl_spatialpack_NCHWc(cfg, data, kernel, strides, padding, dilation, ou
     if out_width % block_w != 0:
         c_w = (out_width // block_w + 1) * block_w
 
-    cshape = (batch, out_channel // co, c_h, c_w, co)
+    cshape = (batch, out_channel // oc_bn, c_h, c_w, oc_bn)
 
     pad_before = [0, 0, pad_top, pad_left, 0]
     pad_after = [0, 0, pad_down + c_h - out_height, pad_right + \
@@ -397,7 +257,7 @@ def _decl_cl_spatialpack_NCHWc(cfg, data, kernel, strides, padding, dilation, ou
              or pad_right + c_w - out_width != 0)
     DOUNPACK = (c_h - out_height != 0 or c_w - out_width != 0)
     if DOPAD:
-        temp = pad(data, pad_before, pad_after, name="pad_temp")
+        temp = nn.pad(data, pad_before, pad_after, name="pad_temp")
     else:
         temp = data
 
@@ -406,33 +266,53 @@ def _decl_cl_spatialpack_NCHWc(cfg, data, kernel, strides, padding, dilation, ou
         lambda nn, ff, yy, xx, ff_v: \
             tvm.sum(
                 temp[nn, rc//ic_bn, yy * stride_h + ry, xx * stride_w + rx, rc%ic_bn]. \
-                        astype(out_dtype) *
+                astype(out_dtype) *
                 kernel[ff, rc//ic_bn, ry, rx, rc%ic_bn, ff_v].astype(out_dtype),
-                axis=[rc, ry, rx]), tag="conv", name='conv')
+                axis=[rc, ry, rx]), tag="conv2d_NCHWc", name='conv2d_NCHWc')
 
     if DOUNPACK:
         output = tvm.compute(
             oshape,
             lambda nn, ff, yy, xx, ff_v:
             conv[nn][ff][yy][xx][ff_v],
-            name='output_unpack', tag="conv_unpack")
+            name='output_unpack', tag="conv2d_NCHWc_unpack")
     else:
         output = conv
 
-
     return output
 
 
+@autotvm.register_topi_schedule("conv2d_NCHWc.intel_graphics")
+def schedule_conv2d_NCHWc(cfg, outs):
+    """Schedule for conv2d_nchw for Intel Graphics
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+        The computation graph description of conv2d_nchw
+        in the format of an array of tensors.
+
+    Returns
+    -------
+    s: Schedule
+        The computation schedule for conv2d_nchw.
+    """
+    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
+    s = tvm.create_schedule([x.op for x in outs])
+
+    def _callback(op):
+        """inline all one-to-one-mapping operators except the last stage (output)"""
+        if "conv2d_NCHWc" in op.tag:
+            _schedule_cl_spatialpack_NCHWc(cfg, s, op)
+
+    traverse_inline(s, outs[0].op, _callback)
+
+    return s
+
+
 def _schedule_cl_spatialpack_NCHWc(cfg, s, op):
     output = op.output(0)
-    conv = op.input_tensors[0]
-    if conv.op.name == "conv":
-        temp = s[conv].op.input_tensors[0]
-        kernel = s[conv].op.input_tensors[1]
-        temp_W = s.cache_read(temp, "warp", [conv])
-        conv_L = s.cache_write(conv, "local")
-        SCHEDULE_OUTPUT = True
-    else:
+    if op.name == "conv2d_NCHWc":
         temp = op.input_tensors[0]
         kernel = op.input_tensors[1]
         temp_W = s.cache_read(temp, "warp", [output])
@@ -443,8 +323,32 @@ def _schedule_cl_spatialpack_NCHWc(cfg, s, op):
             s[output].compute_inline()
             conv = s.outputs[0]
         SCHEDULE_OUTPUT = False
+    else: # conv2d_NCHWc_unpack
+        conv = op.input_tensors[0]
+        temp = s[conv].op.input_tensors[0]
+        kernel = s[conv].op.input_tensors[1]
+        temp_W = s.cache_read(temp, "warp", [conv])
+        conv_L = s.cache_write(conv, "local")
+        SCHEDULE_OUTPUT = True
     kernel_L = s.cache_read(kernel, "local", [conv_L])
 
+    if temp.name == "pad_temp":
+        data = temp.op.input_tensors[0]
+        # TODO(@Laurawly): Do we need to schedule pad op here?
+    else:
+        data = temp
+
+    if autotvm.GLOBAL_SCOPE.in_tuning:
+        # only in autotuning, input data of conv2d_NCHWc will be 4-D.
+        # skip this part during tuning to make records accurate.
+        # this part will be folded during Relay fold_constant pass.
+        s[data].pragma(s[data].op.axis[0], "debug_skip_region")
+        s[kernel].pragma(s[kernel].op.axis[0], "debug_skip_region")
+    elif isinstance(kernel.op, tvm.tensor.ComputeOp) and kernel.name == "kernel_vec":
+        # data and kernel are not pre-computed, schedule layout transform here.
+        # TODO(@Laurawly): Add schedule for data and kernel pack
+        pass
+
     OUTPUT_BLOCK_HEIGHT = cfg["block_oh"].val
     OUTPUT_BLOCK_WIDTH = cfg["block_ow"].val
 
@@ -515,19 +419,7 @@ def _schedule_cl_spatialpack_NCHWc(cfg, s, op):
         tile_and_bind3d(s, out, w, h, vc, 4, 8, 8)
 
 
-def conv_arg_to_workload(data, kernel, strides, padding, layout, out_dtype):
-    """convert argument to workload"""
-    if len(kernel.shape) == 4:
-        raw_kernel = kernel
-    else:  # the input kernel is transformed by alter_op_layout
-        shape = get_const_tuple(kernel.shape)
-        raw_kernel = tvm.placeholder((shape[0] * shape[4], shape[1], shape[2], shape[3]),
-                                     dtype=kernel.dtype)
-    return ('conv2d', ) + autotvm.task.args_to_workload(
-        [data, raw_kernel, strides, padding, layout, out_dtype])
-
-@autotvm.register_topi_compute(conv2d, 'intel_graphics', 'direct')
-def decl_conv2d(cfg, data, kernel, stride, padding, dilation, layout='NCHW', out_dtype='float32'):
+def conv2d_nchw(data, kernel, stride, padding, dilation, out_dtype='float32'):
     """Conv2D operator for Intel Graphics backend.
 
     Parameters
@@ -540,21 +432,18 @@ def decl_conv2d(cfg, data, kernel, stride, padding, dilation, layout='NCHW', out
         stride size, or [stride_height, stride_width]
     padding : int or a list/tuple of two ints
         padding size, or [pad_height, pad_width]
-    layout : str
-        layout of data
     Returns
     -------
     output : tvm.Tensor
         4-D with shape [batch, out_channel, out_height, out_width]
     """
-    assert layout == 'NCHW', "only support NCHW convolution on intel gpu"
     assert data.shape[0].value == 1, "only support batch size=1 convolution on intel gpu"
     assert data.dtype == kernel.dtype, "Do not support inputs with different data types now."
 
-    return _decl_cl_spatialpack(cfg, data, kernel, stride, padding, layout, out_dtype)
+    return _decl_cl_spatialpack(data, kernel, stride, padding, out_dtype)
+
 
-@autotvm.task.register_topi_schedule(generic.schedule_conv2d_nchw, 'intel_graphics', ['direct'])
-def schedule_conv2d_nchw(cfg, outs):
+def schedule_conv2d_nchw(outs):
     """Schedule for conv2d_nchw for Intel Graphics
 
     Parameters
@@ -569,28 +458,20 @@ def schedule_conv2d_nchw(cfg, outs):
     """
     outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
     s = tvm.create_schedule([x.op for x in outs])
-    scheduled_ops = []
 
-    def traverse(op):
+    def _callback(op):
         """inline all one-to-one-mapping operators except the last stage (output)"""
-        if tag.is_broadcast(op.tag):
-            if op not in s.outputs:
-                s[op].compute_inline()
-            for tensor in op.input_tensors:
-                if tensor.op.input_tensors and tensor.op not in scheduled_ops:
-                    traverse(tensor.op)
         if 'conv2d' in op.tag:
-            _schedule_cl_spatialpack(cfg, s, op)
-
-        scheduled_ops.append(op)
+            _schedule_cl_spatialpack(s, op)
 
-    traverse(outs[0].op)
+    traverse_inline(s, outs[0].op, _callback)
     return s
 
-def _decl_cl_spatialpack(cfg, data, kernel, stride, padding, layout, out_dtype='float16'):
+
+def _decl_cl_spatialpack(data, kernel, stride, padding, out_dtype='float16'):
     batch, in_channel, in_height, in_width = [util.get_const_int(x) for x in data.shape]
     num_filter, channel, kernel_h, kernel_w = [util.get_const_int(x) for x in kernel.shape]
-    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(padding, kernel)
+    pad_top, pad_left, pad_down, pad_right = nn.get_pad_tuple(padding, (kernel_h, kernel_w))
 
     if isinstance(stride, (tuple, list)):
         stride_h, stride_w = stride
@@ -606,8 +487,6 @@ def _decl_cl_spatialpack(cfg, data, kernel, stride, padding, layout, out_dtype='
     ry = tvm.reduce_axis((0, kernel_h), name='ry')
     rx = tvm.reduce_axis((0, kernel_w), name='rx')
 
-    block_w = 1
-    block_h = 1
     if stride_h == 2:
         if num_filter + kernel_h == 515:
             block_h = 4
@@ -640,7 +519,7 @@ def _decl_cl_spatialpack(cfg, data, kernel, stride, padding, layout, out_dtype='
 
     pad_before = [0, 0, pad_top, pad_left]
     pad_after = [0, 0, pad_down + c_h - block_h, pad_right + c_w - block_w]
-    temp = pad(data, pad_before, pad_after, name="pad_temp")
+    temp = nn.pad(data, pad_before, pad_after, name="pad_temp")
 
     nv = 16
     if num_filter % nv != 0:
@@ -667,13 +546,12 @@ def _decl_cl_spatialpack(cfg, data, kernel, stride, padding, layout, out_dtype='
         oshape,
         lambda nn, ff, yy, xx:
         conv[nn][ff//nv][yy][xx][ff%nv],
-        name='output_unpack', tag='conv2d',
-        attrs={'workload': conv_arg_to_workload(data, kernel, stride, padding,
-                                                layout, out_dtype)})
+        name='output_unpack', tag='conv2d')
 
     return output
 
-def _schedule_cl_spatialpack(cfg, s, op):
+
+def _schedule_cl_spatialpack(s, op):
     output = op.output(0)
     _, _, out_height, out_width = [util.get_const_int(x) for x in output.shape]
 
@@ -742,7 +620,7 @@ def _schedule_cl_spatialpack(cfg, s, op):
     s[kernel_vec].compute_inline()
 
     # schedule kernel_L
-    if "2_14" in s[conv].op.tag:
+    if OUTPUT_BLOCK_HEIGHT == 2 and OUTPUT_BLOCK_WIDTH == 14:
         s[kernel_L].compute_at(s[conv_L], ry)
     else:
         s[kernel_L].compute_at(s[conv_L], rx)
diff --git a/topi/python/topi/intel_graphics/conv2d_alter_op.py b/topi/python/topi/intel_graphics/conv2d_alter_op.py
new file mode 100644 (file)
index 0000000..e95e59f
--- /dev/null
@@ -0,0 +1,102 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,unused-variable,unused-argument,no-member
+"""Conv2D alter op and legalize functions for x86"""
+
+import tvm
+from tvm import relay
+from tvm import autotvm
+
+from ..util import get_const_tuple
+from ..nn import conv2d_alter_layout, conv2d_infer_layout
+from .conv2d import _get_default_config
+
+
+@conv2d_alter_layout.register(["intel_graphics"])
+def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
+    target = tvm.target.Target.current(allow_none=False)
+    dispatch_ctx = autotvm.task.DispatchContext.current
+    if isinstance(dispatch_ctx, autotvm.task.ApplyGraphBest):
+        cfg = dispatch_ctx.query(target, None)
+        workload = cfg.workload
+    else:
+        _, outs = relay.backend.compile_engine.select_implementation(
+            relay.op.get("nn.conv2d"), attrs, tinfos, out_type, target)
+        workload = autotvm.task.get_workload(outs)
+        if workload is None:
+            # The best implementation is not an AutoTVM template,
+            # we then assume it's not necessary to alter this op.
+            return None
+        cfg = dispatch_ctx.query(target, workload)
+
+    topi_tmpl = workload[0]
+    new_attrs = {k : attrs[k] for k in attrs.keys()}
+
+    padding = attrs.get_int_tuple("padding")
+    strides = attrs.get_int_tuple("strides")
+    dilation = attrs.get_int_tuple("dilation")
+    data_layout = attrs["data_layout"]
+    kernel_layout = attrs["kernel_layout"]
+    data_tensor, kernel_tensor = tinfos
+    data_dtype = data_tensor.dtype
+    kernel_dtype = kernel_tensor.dtype
+    out_dtype = out_type.dtype
+
+    if topi_tmpl == "conv2d_NCHWc.intel_graphics":
+        assert data_layout == "NCHW" and kernel_layout == "OIHW"
+        if cfg.is_fallback:
+            _get_default_config(cfg, data_tensor, kernel_tensor, strides, padding,
+                                out_dtype, False)
+        batch_size, in_channel, height, width = get_const_tuple(data_tensor.shape)
+        out_channel, _, kh, kw = get_const_tuple(kernel_tensor.shape)
+        ic_bn = cfg["tile_ic"].val if hasattr(cfg["tile_ic"], "val") else cfg["tile_ic"].size[-1]
+        oc_bn = cfg["tile_oc"].val if hasattr(cfg["tile_oc"], "val") else cfg["tile_oc"].size[-1]
+
+        # update new attrs
+        new_attrs['channels'] = out_channel
+        new_attrs['data_layout'] = 'NCHW%dc' % ic_bn
+        # (oc, ic, h, w) -> (OC, IC, h, w, ic, oc)
+        new_attrs['kernel_layout'] = 'OIHW%di%do' % (ic_bn, oc_bn)
+        new_attrs['out_layout'] = 'NCHW%dc' % oc_bn
+
+        # Store altered operator's config
+        new_data = tvm.placeholder((batch_size, in_channel//ic_bn, height, width, ic_bn),
+                                   dtype=data_dtype)
+        new_kernel = tvm.placeholder((out_channel//oc_bn, in_channel//ic_bn,
+                                      kh, kw, ic_bn, oc_bn), dtype=kernel_dtype)
+        new_workload = autotvm.task.args_to_workload(
+            [new_data, new_kernel, strides, padding, dilation, new_attrs["data_layout"],
+             new_attrs["out_layout"], out_dtype], "conv2d_NCHWc.intel_graphics")
+        dispatch_ctx.update(target, new_workload, cfg)
+        return relay.nn.contrib_conv2d_nchwc(*inputs, **new_attrs)
+
+    return None
+
+
+@conv2d_infer_layout.register("intel_graphics")
+def _conv2d_infer_layout(workload, cfg):
+    _, data, kernel, strides, padding, dilation, layout, dtype = workload
+    batch_size, in_channel, in_height, in_width = data[1]
+    out_channel, _, k_height, k_width = kernel[1]
+    out_height = (in_height + 2 * padding[0] - k_height) // strides[0] + 1
+    out_width = (in_width + 2 * padding[1] - k_width) // strides[1] + 1
+    tile_ic, tile_oc = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
+    in_shape = (batch_size, in_channel // tile_ic, in_height, in_width, tile_ic)
+    in_layout = "NCHW%dc" % tile_ic
+    out_shape = (batch_size, out_channel // tile_oc, out_height, out_width, tile_oc)
+    out_layout = "NCHW%dc" % tile_oc
+    return ((in_shape, in_layout),), ((out_shape, out_layout),)
index 97b7376..17f1943 100644 (file)
@@ -20,16 +20,17 @@ import tvm
 from tvm import autotvm
 from ..util import traverse_inline
 from .. import tag
-from .. import generic, nn
+from .. import nn
 from ..nn.depthwise_conv2d import depthwise_conv2d_infer_layout
 
 # register original implementation of depthwise_conv2d_nchw since we don't need to change this part
-autotvm.register_topi_compute(nn.depthwise_conv2d_nchw, ['intel_graphics'], 'direct',
-                              nn.depthwise_conv2d_nchw.fdefault)
+@autotvm.register_topi_compute("depthwise_conv2d_nchw.intel_graphics")
+def depthwise_conv2d_nchw(_, data, kernel, strides, padding, dilation, out_dtype):
+    return nn.depthwise_conv2d_nchw(data, kernel, strides, padding, dilation, out_dtype)
 
-@autotvm.register_topi_schedule(generic.schedule_depthwise_conv2d_nchw, \
-        ['intel_graphics'], 'direct')
-def schedule_depthwise_conv2d_nchw_intel_graphics(cfg, outs):
+
+@autotvm.register_topi_schedule("depthwise_conv2d_nchw.intel_graphics")
+def schedule_depthwise_conv2d_nchw(cfg, outs):
     """Schedule for depthwise_conv2d nchw forward.
 
     Parameters
@@ -68,7 +69,7 @@ def schedule_depthwise_conv2d_nchw_intel_graphics(cfg, outs):
             # fallback support
             if cfg.is_fallback:
                 ref_log = autotvm.tophub.load_reference_log(
-                    target.target_name, target.model, 'depthwise_conv2d_nchw', 'direct')
+                    target.target_name, target.model, 'depthwise_conv2d_nchw.intel_graphics')
                 cfg.fallback_with_reference_log(ref_log)
                 cfg['unroll_explicit'].val = 0
             ##### space definition end #####
@@ -132,7 +133,7 @@ def schedule_depthwise_conv2d_nchw_intel_graphics(cfg, outs):
     traverse_inline(s, outs[0].op, _callback)
     return s
 
-@generic.schedule_depthwise_conv2d_nhwc.register(["intel_graphics"])
+
 def schedule_depthwise_conv2d_nhwc(outs):
     """Schedule for depthwise_conv2d nhwc forward.
 
@@ -331,8 +332,8 @@ def _depthwise_conv2d_infer_layout(workload, _):
         Input shapes and layouts, and output shapes and layouts
     """
     _, data, kernel, strides, padding, _, _ = workload
-    batch_size, in_channel, in_height, in_width = data[:-1]
-    filter_channel, channel_multiplier, k_height, k_width = kernel[:-1]
+    batch_size, in_channel, in_height, in_width = data[1]
+    filter_channel, channel_multiplier, k_height, k_width = kernel[1]
     out_channel = filter_channel * channel_multiplier
     out_height = (in_height + 2 * padding[0] - k_height) // strides[0] + 1
     out_width = (in_width + 2 * padding[1] - k_width) // strides[1] + 1
index 35a86e9..f774e76 100644 (file)
 # pylint: disable=invalid-name,unused-variable,unused-argument,no-else-return
 """conv2d schedule on ARM Mali GPU"""
 import tvm
+from tvm import relay
 from tvm import autotvm
 from tvm.autotvm.task.space import get_factors
 
-from ..generic import schedule_conv2d_nchw, schedule_conv2d_winograd_without_weight_transform
 from ..util import traverse_inline, get_const_int, get_const_tuple
-from ..nn import conv2d, conv2d_winograd_without_weight_transform, \
-    get_pad_tuple, pad, conv2d_alter_layout
+from .. import nn
 from ..nn.winograd_util import winograd_transform_matrices
 
 # reuse some compute declarations from ARM CPU
-from ..arm_cpu.conv2d import _alter_conv2d_layout_arm
 from ..arm_cpu.conv2d_spatial_pack import conv2d_spatial_pack_nchw
 
 
-@autotvm.register_topi_compute(conv2d, 'mali', ['direct'])
-def conv2d_mali(cfg, data, kernel, strides, padding, dilation, layout, out_dtype):
+@autotvm.register_topi_compute("conv2d_nchw_spatial_pack.mali")
+def conv2d_nchw_spatial_pack(cfg, data, kernel, strides, padding, dilation, out_dtype):
     """TOPI compute callback for conv2d
 
     Parameters
@@ -57,9 +55,6 @@ def conv2d_mali(cfg, data, kernel, strides, padding, dilation, layout, out_dtype
     dilation : list of two ints
         [dilation_height, dilation_width]
 
-    layout : str
-        layout of data
-
     out_dtype: str
         The output type. This is used for mixed precision.
 
@@ -68,14 +63,11 @@ def conv2d_mali(cfg, data, kernel, strides, padding, dilation, layout, out_dtype
     output : tvm.Tensor
         4-D with shape [batch, out_channel, out_height, out_width]
     """
-    if layout == 'NCHW':
-        return conv2d_spatial_pack_nchw(cfg, data, kernel, strides, padding,
-                                        dilation, out_dtype, num_tile=3)
-    else:
-        raise ValueError("Unsupported layout {}".format(layout))
+    return conv2d_spatial_pack_nchw(cfg, data, kernel, strides, padding,
+                                    dilation, out_dtype, num_tile=3)
 
-@autotvm.register_topi_schedule(schedule_conv2d_nchw, 'mali', ['direct', 'winograd'])
-def schedule_conv2d_nchw_mali(cfg, outs):
+@autotvm.register_topi_schedule("conv2d_nchw_spatial_pack.mali")
+def schedule_conv2d_nchw_spatial_pack(cfg, outs):
     """TOPI schedule callback for conv2d
 
     Parameters
@@ -113,9 +105,6 @@ def schedule_conv2d_nchw_mali(cfg, outs):
 
             _schedule_spatial_pack(cfg, s, output, conv, data_vec, kernel_vec)
 
-        if 'winograd_conv2d_output' in op.tag:
-            _schedule_winograd(cfg, s, op)
-
     traverse_inline(s, outs[0].op, _callback)
     return s
 
@@ -200,13 +189,27 @@ def _pick_tile_size(data, kernel):
     else:
         return 2
 
-@autotvm.register_topi_compute(conv2d, 'mali', ['winograd'])
-def conv2d_mali_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dtype):
+
+@autotvm.register_topi_compute("conv2d_nchw_winograd.mali")
+def conv2d_nchw_winograd(cfg, data, kernel, strides, padding, dilation, out_dtype):
     tile_size = _pick_tile_size(data, kernel)
-    return _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dtype,
+    return _decl_winograd(cfg, data, kernel, strides, padding, dilation, out_dtype,
                           tile_size)
 
-def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dtype, tile_size):
+
+@autotvm.register_topi_schedule("conv2d_nchw_winograd.mali")
+def schedule_conv2d_nchw_winograd(cfg, outs):
+    s = tvm.create_schedule([x.op for x in outs])
+
+    def _callback(op):
+        if 'winograd_conv2d_output' in op.tag:
+            _schedule_winograd(cfg, s, op)
+
+    traverse_inline(s, outs[0].op, _callback)
+    return s
+
+
+def _decl_winograd(cfg, data, kernel, strides, padding, dilation, out_dtype, tile_size):
     N, CI, IH, IW = get_const_tuple(data.shape)
     if isinstance(dilation, int):
         dilation_h = dilation_w = dilation
@@ -214,9 +217,8 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt
         dilation_h, dilation_w = dilation
 
     if len(kernel.shape) == 4:
-
         if dilation_h != 1 or dilation_w != 1:
-            kernel = dilate(kernel, (1, 1, dilation_h, dilation_w))
+            kernel = nn.dilate(kernel, (1, 1, dilation_h, dilation_w))
         pre_computed = False
         CO, _, KH, KW = get_const_tuple(kernel.shape)
     else:
@@ -226,11 +228,10 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt
         CO *= VC
         KH, KW = H_CAT - tile_size + 1, W_CAT - tile_size + 1
     HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
-    pt, pl, pb, pr = get_pad_tuple(padding, (KH, KW))
+    pt, pl, pb, pr = nn.get_pad_tuple(padding, (KH, KW))
 
-    assert layout == 'NCHW'
     assert KH == 3 and KW == 3 and HSTR == 1 and WSTR == 1
-    data_pad = pad(data, (0, 0, pt, pl), (0, 0, pb, pr), name="data_pad")
+    data_pad = nn.pad(data, (0, 0, pt, pl), (0, 0, pb, pr), name="data_pad")
 
     r = KW
     m = tile_size
@@ -420,34 +421,85 @@ def _schedule_winograd(cfg, s, op):
 
     s[Y].compute_at(s[output], tt)
 
-##### REGISTER TOPI COMPUTE / SCHEDULE FOR WINOGRAD WITH WEIGHT TRANSFORM #####
-@autotvm.register_topi_compute(conv2d_winograd_without_weight_transform, 'mali', ['winograd'])
-def conv2d_winograd_ww(cfg, data, kernel, strides, padding, dilation, layout, out_dtype, tile_size):
-    """TOPI compute callback"""
-    return _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dtype,
-                          tile_size)
 
+##### REGISTER ALTER OP LAYOUT #####
+@nn.conv2d_alter_layout.register(["mali"])
+def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
+    target = tvm.target.Target.current(allow_none=False)
+    dispatch_ctx = autotvm.task.DispatchContext.current
+
+    _, outs = relay.backend.compile_engine.select_implementation(
+        relay.op.get("nn.conv2d"), attrs, tinfos, out_type, target)
+    workload = autotvm.task.get_workload(outs)
+    if workload is None:
+        # The best implementation is not an AutoTVM template,
+        # we then assume it's not necessary to alter this op.
+        return None
+    cfg = dispatch_ctx.query(target, workload)
+    if cfg.is_fallback:  # if is fallback, clear query cache and return None
+        autotvm.task.clear_fallback_cache(target, workload)
+        return None
 
-@autotvm.register_topi_schedule(schedule_conv2d_winograd_without_weight_transform,
-                                'mali', ['winograd'])
-def schedule_conv2d_winograd_without_weight_transform_(cfg, outs):
-    """TOPI schedule callback"""
-    s = tvm.create_schedule([x.op for x in outs])
+    topi_tmpl = workload[0]
+    new_attrs = {k: attrs[k] for k in attrs.keys()}
 
-    def _callback(op):
-        if 'winograd_conv2d_output' in op.tag:
-            _schedule_winograd(cfg, s, op)
+    strides = attrs.get_int_tuple("strides")
+    padding = attrs.get_int_tuple("padding")
+    dilation = attrs.get_int_tuple("dilation")
+    data_layout = attrs["data_layout"]
+    kernel_layout = attrs["kernel_layout"]
+    data, kernel = tinfos
+    out_dtype = out_type.dtype
 
-    traverse_inline(s, outs[0].op, _callback)
-    return s
+    idxd = tvm.indexdiv
 
+    if topi_tmpl == "conv2d_nchw_spatial_pack.mali":
+        assert data_layout == "NCHW" and kernel_layout == "OIHW"
+        N, CI, H, W = get_const_tuple(data.shape)
+        CO, _, KH, KW = get_const_tuple(kernel.shape)
+        VC = cfg['tile_co'].size[-1]
 
-##### REGISTER ALTER OP LAYOUT #####
-@conv2d_alter_layout.register(["mali"])
-def _alter_conv2d_layout(attrs, inputs, tinfos, F):
-    try:
-        return _alter_conv2d_layout_arm(attrs, inputs, tinfos, F)
-    except KeyError:  # to filter out fallback opencl templates
+        new_attrs['kernel_layout'] = 'OIHW%do' % VC
+
+        new_data = data
+        new_kernel = tvm.placeholder((idxd(CO, VC), CI, KH, KW, VC), dtype=kernel.dtype)
+        new_workload = autotvm.task.args_to_workload(
+            [new_data, new_kernel, strides, padding, dilation, out_dtype],
+            "conv2d_nchw_spatial_pack.mali")
+        dispatch_ctx.update(target, new_workload, cfg)
+
+        return relay.nn.conv2d(*inputs, **new_attrs)
+    elif topi_tmpl == "conv2d_nchw_winograd.mali":
+        assert data_layout == "NCHW" and kernel_layout == "OIHW"
+        N, CI, H, W = get_const_tuple(data.shape)
+        CO, _, KH, KW = get_const_tuple(kernel.shape)
+        tile_size = _pick_tile_size(data, kernel)
+        VC = cfg['tile_bna'].val
+
+        weight_expr = inputs[1]
+        weight_expr = relay.nn.contrib_conv2d_winograd_weight_transform(
+            weight_expr, tile_size=tile_size)
+        weight_expr = relay.reshape(weight_expr,
+                                    newshape=(KH + tile_size - 1,
+                                              KW + tile_size - 1,
+                                              idxd(CO, VC), VC, CI))
+        weight_expr = relay.transpose(weight_expr, axes=[0, 1, 2, 4, 3])
+
+        new_attrs['tile_size'] = tile_size
+
+        new_data = data
+        new_kernel = tvm.placeholder((KH + tile_size - 1,
+                                      KW + tile_size -1,
+                                      idxd(CO, VC), CI, VC),
+                                     kernel.dtype)
+        new_workload = autotvm.task.args_to_workload(
+            [new_data, new_kernel, strides, padding, dilation, out_dtype],
+            'conv2d_nchw_winograd.mali')
+        dispatch_ctx.update(target, new_workload, cfg)
+
+        return relay.nn.contrib_conv2d_winograd_without_weight_transform(
+            inputs[0], weight_expr, **new_attrs)
+    else:
         return None
 
 
index 6096a99..3b233e9 100644 (file)
@@ -22,12 +22,18 @@ from __future__ import absolute_import as _abs
 import tvm
 from tvm import autotvm
 
-from .. import generic, nn
+from .. import nn
 from ..util import traverse_inline
 
-autotvm.register_topi_compute(nn.dense, 'mali', 'direct', nn.dense.fdefault)
 
-@autotvm.register_topi_schedule(generic.schedule_dense, 'mali', 'direct')
+
+@autotvm.register_topi_compute('dense.mali')
+def dense(_, data, weight, bias=None, out_dtype=None):
+    """Dense operator on Mali"""
+    return nn.dense(data, weight, bias, out_dtype)
+
+
+@autotvm.register_topi_schedule('dense.mali')
 def schedule_dense(cfg, outs):
     """Schedule for dense operator.
 
@@ -52,11 +58,11 @@ def schedule_dense(cfg, outs):
             vec_size = [1, 2, 4, 8, 16]
             max_unroll = 32
 
-            dense = op.output(0)
+            dense_out = op.output(0)
             output = outs[0]
 
             y, x = s[output].op.axis
-            c = s[dense].op.reduce_axis[0]
+            c = s[dense_out].op.reduce_axis[0]
 
             ##### space definition begin #####
             cfg.define_split('tile_y', y, num_outputs=3)
@@ -66,12 +72,12 @@ def schedule_dense(cfg, outs):
             # fallback support
             if cfg.is_fallback:
                 ref_log = autotvm.tophub.load_reference_log(
-                    'mali', 'rk3399', 'dense', 'direct')
+                    'mali', 'rk3399', 'dense.mali')
                 cfg.fallback_with_reference_log(ref_log)
             ##### space definition end #####
 
-            if dense.op in s.outputs:
-                dense = s.cache_write(output, 'local')
+            if dense_out.op in s.outputs:
+                dense_out = s.cache_write(output, 'local')
 
             by, ty, yi = cfg['tile_y'].apply(s, output, y)
             bx, tx, xi = cfg['tile_x'].apply(s, output, x)
@@ -85,23 +91,25 @@ def schedule_dense(cfg, outs):
                 s[output].unroll(yi)
             if cfg['tile_x'].size[-1] in vec_size:
                 s[output].vectorize(xi)
-            s[dense].compute_at(s[output], tx)
+            s[dense_out].compute_at(s[output], tx)
 
-            k = s[dense].op.reduce_axis[0]
-            y, x = s[dense].op.axis
-            k, k_unroll = cfg['c_unroll'].apply(s, dense, k)
-            s[dense].reorder(k, k_unroll, y, x)
-            s[dense].unroll(k_unroll)
+            k = s[dense_out].op.reduce_axis[0]
+            y, x = s[dense_out].op.axis
+            k, k_unroll = cfg['c_unroll'].apply(s, dense_out, k)
+            s[dense_out].reorder(k, k_unroll, y, x)
+            s[dense_out].unroll(k_unroll)
             if cfg['tile_y'].size[-1] < max_unroll:
-                s[dense].unroll(y)
+                s[dense_out].unroll(y)
             if cfg['tile_x'].size[-1] in vec_size:
-                s[dense].vectorize(x)
+                s[dense_out].vectorize(x)
 
     traverse_inline(s, outs[0].op, _callback)
     return s
 
+
 def fuse_and_bind(s, tensor, axis=None, num_thread=None):
     """ fuse all the axis and bind to GPU threads """
+    # TODO(@comaniac): figure out where this function is used.
     axis = axis or s[tensor].op.axis
     fused = s[tensor].fuse(*axis)
     bx, tx = s[tensor].split(fused, num_thread)
index 274b294..4ff17e5 100644 (file)
 import tvm
 from tvm import autotvm
 
-from ..generic import schedule_depthwise_conv2d_nchw
-from ..nn import depthwise_conv2d_nchw
+from .. import nn
 from ..util import traverse_inline
 
 # register original implementation of depthwise_conv2d_nchw since we don't need to change this part
-autotvm.register_topi_compute(depthwise_conv2d_nchw, 'mali', 'direct',
-                              depthwise_conv2d_nchw.fdefault)
+@autotvm.register_topi_compute("depthwise_conv2d_nchw.mali")
+def depthwise_conv2d_nchw(cfg, data, kernel, strides, padding, dilation, out_dtype):
+    return nn.depthwise_conv2d_nchw(data, kernel, strides, padding, dilation, out_dtype)
+
 
 # register customized schedule for arm cpu.
-@autotvm.register_topi_schedule(schedule_depthwise_conv2d_nchw, 'mali', 'direct')
-def schedule_depthwise_conv2d_nchw_mali(cfg, outs):
+@autotvm.register_topi_schedule("depthwise_conv2d_nchw.mali")
+def schedule_depthwise_conv2d_nchw(cfg, outs):
     """Schedule depthwise conv2d
 
     Parameters
@@ -64,7 +65,7 @@ def schedule_depthwise_conv2d_nchw_mali(cfg, outs):
         # fallback support
         if cfg.is_fallback:
             ref_log = autotvm.tophub.load_reference_log(
-                'mali', 'rk3399', 'depthwise_conv2d_nchw', 'direct')
+                'mali', 'rk3399', 'depthwise_conv2d_nchw.mali')
             cfg.fallback_with_reference_log(ref_log)
         ###### space definition end ######
 
index 7b872ce..d69562c 100644 (file)
@@ -20,7 +20,7 @@ from __future__ import absolute_import as _abs
 import tvm
 from ..util import get_const_tuple
 
-def batch_matmul_default(x, y):
+def batch_matmul(x, y):
     """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
     data in batch.
 
@@ -48,23 +48,3 @@ def batch_matmul_default(x, y):
     return tvm.compute((batch, M, N),
                        lambda b, i, j: tvm.sum(x[b, i, k] * y[b, j, k], axis=k),
                        tag='batch_matmul')
-
-@tvm.target.generic_func
-def batch_matmul(x, y):
-    """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
-    data in batch.
-
-    Parameters
-    ----------
-    x : tvm.Tensor
-        3-D with shape [batch, M, K]
-
-    y : tvm.Tensor
-        3-D with shape [batch, N, K]
-
-    Returns
-    -------
-    output : tvm.Tensor
-        3-D with shape [batch, M, N]
-    """
-    return batch_matmul_default(x, y)
index e1f8f81..f18a5aa 100644 (file)
 """Bitserial Conv2D operators"""
 from __future__ import absolute_import as _abs
 import tvm
-from tvm import autotvm
 from .pad import pad
 from .util import get_pad_tuple
-from .bitserial_util import bitpack, binary_op_multiplier
+from .bitserial_util import bitpack
 from ..util import get_const_tuple
 
-@tvm.target.generic_func
 def bitserial_conv2d_nchw(data, kernel, stride, padding, activation_bits, weight_bits,
                           pack_dtype='uint32', out_dtype='int16', unipolar=True):
     """Bitserial Conv2D operator.
@@ -117,7 +115,6 @@ def bitserial_conv2d_nchw(data, kernel, stride, padding, activation_bits, weight
     return tvm.compute((batch, out_channel, out_height, out_width), _conv,
                        name="Conv2dOutput", tag="bitserial_conv2d_nchw")
 
-@tvm.target.generic_func
 def bitserial_conv2d_nhwc(data, kernel, stride, padding, activation_bits, weight_bits,
                           pack_dtype='uint32', out_dtype='int16', unipolar=True):
     """Bitserial Conv2D operator.
@@ -213,222 +210,6 @@ def bitserial_conv2d_nhwc(data, kernel, stride, padding, activation_bits, weight
 
     return conv
 
-@autotvm.register_topi_compute(bitserial_conv2d_nchw, ['cpu', 'arm_cpu'], 'direct')
-def spatial_pack_nchw(cfg, data, kernel, stride, padding, in_bits, weight_bits,
-                      pack_dtype='uint32', out_dtype='int16', unipolar=True):
-    """ Compute convolution with pack on spatial axes. """
-    assert data.shape[0].value == 1, "spatial pack convolution only support batch size=1"
-    data_q = bitpack(data, in_bits, pack_axis=1, bit_axis=0, pack_type=pack_dtype)
-    # Check if kernel is already bitpacked
-    if len(kernel.shape) == 4:
-        kernel_q = bitpack(kernel, weight_bits, pack_axis=1, bit_axis=0, pack_type=pack_dtype)
-        KB, CO, _, KH, KW = get_const_tuple(kernel_q.shape)
-    else:
-        kernel_vec = kernel
-        OCO, _, KH, KW, KB, VC = get_const_tuple(kernel_vec.shape)
-        CO = OCO * VC
-
-    IB, N, CI, H, W = get_const_tuple(data_q.shape)
-    KB, CO, _, KH, KW = get_const_tuple(kernel_q.shape)
-
-    if isinstance(padding, int) or (isinstance(padding, (tuple, list)) and len(padding) == 2):
-        TPAD, LPAD, DPAD, RPAD = get_pad_tuple(padding, kernel)
-    else:
-        TPAD, LPAD, DPAD, RPAD = padding
-    pad_before = [0, 0, 0, TPAD, LPAD]
-    pad_after = [0, 0, 0, DPAD, RPAD]
-
-    if isinstance(stride, (tuple, list)):
-        HSTR, WSTR = stride
-    else:
-        HSTR, WSTR = stride, stride
-    HCAT, WCAT = KH-1, KW-1
-
-    TH = H + TPAD + DPAD
-    TW = W + LPAD + RPAD
-    OH = (H + TPAD + DPAD - KH) // HSTR + 1
-    OW = (W + LPAD + RPAD - KW) // WSTR + 1
-
-     # ==================== define configuration space ====================
-    n, co, oh, ow = cfg.axis(N), cfg.axis(CO), cfg.axis(OH), cfg.axis(OW)
-    ci, kh, kw = cfg.reduce_axis(CI), cfg.reduce_axis(KH), cfg.reduce_axis(KW)
-    ib, kb = cfg.reduce_axis(in_bits), cfg.reduce_axis(weight_bits)
-
-    co, vc = cfg.define_split('tile_co', co, num_outputs=2,
-                              filter=lambda x: max(x.size[1:]) <= 16)
-    oh, vh = cfg.define_split('tile_oh', oh, num_outputs=2,
-                              filter=lambda x: max(x.size[1:]) <= 16)
-    ow, vw = cfg.define_split('tile_ow', ow, num_outputs=2,
-                              filter=lambda x: max(x.size[1:]) <= 16)
-    cfg.define_annotate('ann_reduce', [ib, kb, kh, kw], policy='try_unroll')
-
-    cfg.define_reorder("reorder_0",
-                       [n, co, oh, ow, vc, vh, vw, kh, kw, kb, ib, ci],
-                       policy='interval_all', interval=(6, 11))
-    # binary ops
-    cfg.add_flop(2 * N * OH * OW * CO * CI * KH * KW * binary_op_multiplier(pack_dtype))
-    # ====================
-
-    VC = cfg["tile_co"].size[-1]
-    VH = cfg["tile_oh"].size[-1]
-    VW = cfg["tile_ow"].size[-1]
-
-    dvshape = (1, TH//(VH*HSTR), TW//(VW*WSTR), CI, VH*HSTR+HCAT, VW*WSTR+WCAT, IB)
-    kvshape = (CO//VC, CI, KH, KW, KB, VC)
-    ovshape = (1, CO//VC, OH//VH, OW//VW, VH, VW, VC)
-    oshape = (1, CO, OH, OW)
-
-    if (TPAD != 0 and RPAD != 0):
-        data_pad = pad(data_q, pad_before, pad_after, name="data_pad")
-    else:
-        data_pad = data_q
-
-    data_vec = tvm.compute(dvshape, lambda n, h, w, ci, vh, vw, b: \
-        data_pad[b][n][ci][h*VH*HSTR+vh][w*VW*WSTR+vw], name='data_vec')
-
-    if len(kernel.shape) == 4:
-        kernel_vec = tvm.compute(kvshape, lambda co, ci, dh, dw, b, vc: \
-            kernel_q[b][co*VC+vc][ci][dh][dw], name='kernel_vec')
-
-    ci = tvm.reduce_axis((0, CI), name='ci')
-    dh = tvm.reduce_axis((0, KH), name='dh')
-    dw = tvm.reduce_axis((0, KW), name='dw')
-    b1 = tvm.reduce_axis((0, IB), name='ib')
-    b2 = tvm.reduce_axis((0, KB), name='kb')
-
-    def _conv(n, co, h, w, vh, vw, vc):
-        b1b2 = (b1+b2).astype(out_dtype)
-        if unipolar:
-            return tvm.sum((tvm.popcount(
-                data_vec[n, h, w, ci, vh*HSTR+dh, vw*WSTR+dw, b1].astype(out_dtype) &
-                kernel_vec[co, ci, dh, dw, b2, vc].astype(out_dtype))  -
-                            tvm.popcount(
-                                data_vec[n, h, w, ci, vh*HSTR+dh, vw*WSTR+dw, b1].astype(out_dtype)
-                                & ~kernel_vec[co, ci, dh, dw, b2, vc]).astype(out_dtype)) << b1b2,
-                           axis=[ci, dh, dw, b1, b2])
-
-        return tvm.sum((tvm.popcount(
-            data_vec[n, h, w, ci, vh*HSTR+dh, vw*WSTR+dw, b1] &
-            kernel_vec[co, ci, dh, dw, b2, vc])).astype(out_dtype) << b1b2,
-                       axis=[ci, dh, dw, b1, b2])
-
-    conv = tvm.compute(ovshape, _conv, name='conv_out')
-    idxd = tvm.indexdiv
-    idxm = tvm.indexmod
-
-    return tvm.compute(
-        oshape, lambda n, co, h, w:
-        conv[n,
-             idxd(co, VC), idxd(h, VH), idxd(w, VW),
-             idxm(h, VH), idxm(w, VW), idxm(co, VC)],
-        name='conv_vec', tag='spatial_bitserial_conv_nchw')
-
-@autotvm.register_topi_compute(bitserial_conv2d_nhwc, 'cpu', 'direct')
-def spatial_pack_nhwc(cfg, data, kernel, stride, padding, in_bits, weight_bits,
-                      pack_dtype='uint32', out_dtype='int16', unipolar=True):
-    """ Compute convolution with pack on spatial axes. """
-    assert data.shape[0].value == 1, "spatial pack convolution only support batch size=1"
-    data_q = bitpack(data, in_bits, pack_axis=3, bit_axis=4, pack_type=pack_dtype)
-    pack_kernel = len(kernel.shape) == 4
-
-    if pack_kernel:
-        kernel_q = bitpack(kernel, weight_bits, pack_axis=2, bit_axis=4, pack_type=pack_dtype)
-    else:
-        kernel_q = kernel
-
-    KH, KW, _, CO, KB = get_const_tuple(kernel_q.shape)
-    N, H, W, CI, IB = get_const_tuple(data_q.shape)
-
-    if isinstance(padding, int) or (isinstance(padding, (tuple, list)) and len(padding) == 2):
-        TPAD, LPAD, DPAD, RPAD = get_pad_tuple(padding, kernel)
-    else:
-        TPAD, LPAD, DPAD, RPAD = padding
-    pad_before = [0, TPAD, LPAD, 0, 0]
-    pad_after = [0, DPAD, RPAD, 0, 0]
-
-    if isinstance(stride, (tuple, list)):
-        HSTR, WSTR = stride
-    else:
-        HSTR, WSTR = stride, stride
-    HCAT, WCAT = KH-1, KW-1
-
-    PAD_H = H + (TPAD + DPAD)
-    PAD_W = W + (LPAD + RPAD)
-    OH = (PAD_H - KH) // HSTR + 1
-    OW = (PAD_W - KW) // WSTR + 1
-    oshape = (1, OH, OW, CO)
-
-    # ==================== define configuration space ====================
-    n, oh, ow, co = cfg.axis(N), cfg.axis(OH), cfg.axis(OW), cfg.axis(CO)
-    ci, kh, kw = cfg.reduce_axis(CI), cfg.reduce_axis(KH), cfg.reduce_axis(KW)
-    ib, kb = cfg.reduce_axis(in_bits), cfg.reduce_axis(weight_bits)
-
-    co, vc = cfg.define_split('tile_co', co, num_outputs=2,
-                              filter=lambda x: max(x.size[1:]) <= 16)
-    oh, vh = cfg.define_split('tile_oh', oh, num_outputs=2,
-                              filter=lambda x: max(x.size[1:]) <= 16)
-    ow, vw = cfg.define_split('tile_ow', ow, num_outputs=2,
-                              filter=lambda x: max(x.size[1:]) <= 16)
-    cfg.define_annotate('ann_reduce', [ib, kb, kh, kw], policy='try_unroll')
-    cfg.define_reorder("reorder_0",
-                       [n, oh, ow, co, vh, vw, kh, kw, kb, ib, vc, ci],
-                       policy='interval_all', interval=(3, 7))
-    # binary ops
-    cfg.add_flop(2 * N * OH * OW * CO * CI * KH * KW * binary_op_multiplier(pack_dtype))
-    # ====================
-
-    VC = cfg["tile_co"].size[-1]
-    VH = cfg["tile_oh"].size[-1]
-    VW = cfg["tile_ow"].size[-1]
-
-    dvshape = (1, PAD_H//(VH*HSTR), PAD_W//(VW*WSTR), VH*HSTR+HCAT, VW*WSTR+WCAT, CI, IB)
-    kvshape = (CO, KH, KW, CI, VC, KB)
-    ovshape = (1, OH, OW, CO, VH, VW, VC)
-    oshape = (1, OH, OW, CO)
-
-    if (DPAD != 0 and RPAD != 0):
-        data_pad = pad(data_q, pad_before, pad_after, name="data_pad")
-    else:
-        data_pad = data_q
-
-    data_vec = tvm.compute(dvshape, lambda n, h, w, vh, vw, ci, b: \
-        data_pad[n][h*VH*HSTR+vh][w*VW*WSTR+vw][ci][b], name='data_vec')
-
-    kernel_vec = tvm.compute(kvshape, lambda co, dh, dw, ci, vc, b: \
-        kernel_q[dh][dw][ci][co*VC+vc][b], name='kernel_vec')
-
-    ci = tvm.reduce_axis((0, CI), name='ci')
-    dh = tvm.reduce_axis((0, KH), name='dh')
-    dw = tvm.reduce_axis((0, KW), name='dw')
-    b1 = tvm.reduce_axis((0, IB), name='ib')
-    b2 = tvm.reduce_axis((0, KB), name='kb')
-
-    def _conv(n, h, w, co, vh, vw, vc):
-        b1b2 = (b1+b2).astype(out_dtype)
-        if unipolar:
-            return tvm.sum(
-                ((tvm.popcount(data_vec[n, h, w, vh*HSTR+dh, vw*WSTR+dw, ci, b1] &
-                               kernel_vec[co, dh, dw, ci, vc, b2]).astype(out_dtype) -
-                  tvm.popcount(data_vec[n, h, w, vh*HSTR+dh, vw*WSTR+dw, ci, b1]&
-                               ~kernel_vec[co, dh, dw, ci, vc, b2]).astype(out_dtype)) << b1b2),
-                axis=[dh, dw, ci, b1, b2])
-
-        return tvm.sum(tvm.popcount(
-            data_vec[n, h, w, vh*HSTR+dh, vw*WSTR+dw, ci, b1] &
-            kernel_vec[co, dh, dw, ci, vc, b2]).astype(out_dtype) << b1b2,
-                       axis=[dh, dw, ci, b1, b2])
-
-    conv = tvm.compute(ovshape, _conv, name='conv')
-
-    idxd = tvm.indexdiv
-    idxm = tvm.indexmod
-    return tvm.compute(
-        oshape, lambda n, h, w, co:
-        conv[n,
-             idxd(h, VH), idxd(w, VW), idxd(co, VC),
-             idxm(h, VH), idxm(w, VW), idxm(co, VC)],
-        name='output_unpack', tag='spatial_bitserial_conv_nhwc')
-
 @tvm.target.generic_func
 def bitserial_conv2d_legalize(attrs, inputs, types):
     """Legalizes Bitserial Conv2D op.
index d77a1b7..fa1b5df 100644 (file)
 """Bitserial Dense operator."""
 from __future__ import absolute_import
 import tvm
-from tvm import autotvm
 from topi.util import get_const_tuple
-from .bitserial_util import bitpack, binary_op_multiplier
+from .bitserial_util import bitpack
 
-@tvm.target.generic_func
 def bitserial_dense(data, weight, data_bits, weight_bits, pack_dtype='uint32',
                     out_dtype='int16', unipolar=True):
     """The default implementation of bitserial dense in topi.
@@ -66,78 +64,3 @@ def bitserial_dense(data, weight, data_bits, weight_bits, pack_dtype='uint32',
     if unipolar:
         return matmul_unipolar
     return matmul
-
-
-@autotvm.register_topi_compute(bitserial_dense, ['cpu'], 'direct')
-def bitserial_dense_default(cfg, data, weight, data_bits, weight_bits, pack_dtype='uint32',
-                            out_dtype='int16', unipolar=True):
-    """Bitserial dense implementation. TODO: Why are these separate
-
-    Parameters
-    ----------
-    data : tvm.Tensor
-        2-D with shape [batch, in_dim]
-    weight : tvm.Tensor
-        2-D with shape [out_dim, in_dim] or
-        3-D with shape [out_dim, weight_bits, in_dim]
-    Returns
-    -------
-    output : tvm.Tensor
-        2-D with shape [batch, out_dim]
-    """
-    data_packed = bitpack(data, data_bits, pack_axis=1, bit_axis=1, pack_type=pack_dtype)
-    if len(weight.shape) == 2:
-        weight_packed = bitpack(weight, weight_bits, pack_axis=1, bit_axis=1, pack_type=pack_dtype)
-    else:
-        weight_packed = weight
-    Y, DB, K = get_const_tuple(data_packed.shape)
-    X, WB, _ = get_const_tuple(weight_packed.shape)
-    ######## Search space
-    x, y = cfg.axis(X), cfg.axis(Y)
-    db, wb, k = cfg.reduce_axis(DB), cfg.reduce_axis(WB), cfg.reduce_axis(K)
-    ko, ki = cfg.define_split('tile_k', k, num_outputs=2)
-    yo, yi = cfg.define_split('tile_y', y, num_outputs=2)
-    xo, xi = cfg.define_split('tile_x', x, num_outputs=2)
-
-    cfg.define_reorder('reorder_0', [yo, xo, ko, yi, wb, db, ki, xi],
-                       policy='candidate', candidate=[
-                           [yo, xo, ko, yi, wb, db, ki, xi],
-                           [yo, xo, yi, ko, wb, db, ki, xi]])
-
-    cfg.define_annotate('ann_reduce', [db, wb], policy='try_unroll')
-    cfg.define_annotate('ann_spatial', [yi, xi], policy='try_unroll_vec')
-
-    ###### Compute rule
-    VX = cfg['tile_x'].size[-1]
-
-    wvshape = (X//VX, WB, VX, K)
-    oshape = (Y, X)
-
-    k = tvm.reduce_axis((0, K), name='k')
-    db = tvm.reduce_axis((0, DB), name='db')
-    wb = tvm.reduce_axis((0, WB), name='wb')
-
-    # Tile data and weights
-    weight_vec = tvm.compute(wvshape, lambda xo, wb, vx, k:
-                             weight_packed[xo*VX+vx][wb][k], name='weight_vec')
-
-    idxdiv = tvm.indexdiv
-    idxmod = tvm.indexmod
-
-    matmul_unipolar = tvm.compute(oshape, lambda i, j: tvm.sum(
-        (tvm.popcount(weight_vec[idxdiv(j, VX), wb, idxmod(j, VX), k] & data_packed[i, db, k]) -
-         tvm.popcount(~weight_vec[idxdiv(j, VX), wb, idxmod(j, VX), k] & data_packed[i, db, k])
-        ).astype(out_dtype)
-        << (db+wb).astype(out_dtype), axis=[wb, db, k]), tag='bitserial_dense_unipolar')
-
-    matmul = tvm.compute(oshape, lambda i, j: tvm.sum(
-        tvm.popcount(weight_vec[idxdiv(j, VX), wb, idxmod(j, VX), k] & data_packed[i, db, k]
-                    ).astype(out_dtype)
-        << (db+wb).astype(out_dtype), axis=[wb, db, k]), tag='bitserial_dense')
-
-    # binary ops
-    cfg.add_flop(2 * Y * X * K * binary_op_multiplier(pack_dtype))
-
-    if unipolar:
-        return matmul_unipolar
-    return matmul
index 98fa2e3..4565fd2 100644 (file)
@@ -23,7 +23,6 @@ from ..util import simplify
 from .util import get_pad_tuple1d
 
 
-@tvm.target.generic_func
 def conv1d(data,
            kernel,
            strides=1,
@@ -101,6 +100,13 @@ def conv1d_ncw(data,
     out_dtype : str
         The output data type. If None then output is same type as input.
     """
+    if out_dtype is None:
+        out_dtype = data.dtype
+    if isinstance(strides, (tuple, list)):
+        strides = strides[0]
+    if isinstance(dilation, (tuple, list)):
+        dilation = dilation[0]
+
     batch, in_channels, data_width = data.shape
     out_channels, _, kernel_size = kernel.shape
 
@@ -158,6 +164,13 @@ def conv1d_nwc(data,
     out_dtype : str
         The output data type. If None then output is same type as input.
     """
+    if out_dtype is None:
+        out_dtype = data.dtype
+    if isinstance(strides, (tuple, list)):
+        strides = strides[0]
+    if isinstance(dilation, (tuple, list)):
+        dilation = dilation[0]
+
     batch, data_width, in_channels = data.shape
     kernel_size, _, out_channels = kernel.shape
 
index 39918e9..8d22424 100644 (file)
@@ -24,7 +24,6 @@ from ..util import simplify
 from .util import get_pad_tuple1d
 
 
-@tvm.target.generic_func
 def conv1d_transpose_ncw(data, kernel, stride, padding, out_dtype):
     """Transposed 1D convolution ncw forward operator.
 
index 52f4b12..a7a75ed 100644 (file)
@@ -23,7 +23,7 @@ import tvm
 
 from .pad import pad
 from .util import get_pad_tuple
-from ..util import simplify, get_const_tuple, get_const_int
+from ..util import simplify, get_const_tuple, get_const_int, tag
 from .winograd_util import winograd_transform_matrices
 
 # workload description of conv2d
@@ -31,7 +31,6 @@ Workload = namedtuple('Workload',
                       ['in_dtype', 'out_dtype', 'height', 'width', 'in_filter', 'groups',
                        'out_filter', 'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride'])
 
-@tvm.target.generic_func
 def conv2d(input, filter, strides, padding, dilation, layout='NCHW', out_dtype=None):
     """Conv2D operator.
 
@@ -96,7 +95,7 @@ def conv2d_legalize(attrs, inputs, types):
 
 
 @tvm.target.generic_func
-def conv2d_alter_layout(attrs, inputs, tinfos, F):
+def conv2d_alter_layout(attrs, inputs, tinfos, out_type):
     """Change Conv2D layout.
 
     Parameters
@@ -107,13 +106,12 @@ def conv2d_alter_layout(attrs, inputs, tinfos, F):
         Grouped input symbols
     tinfos : list
         Input shape and dtype
-    F: symbol
-        The context, can be either relay.op
+    out_type: type
+        The output type
 
     Note
     ----
-    Unlike other TOPI functions, this function operates on both graph level and operator level,
-    so we have to pass 'F' to make it support our two versions of graph IR, Relay.
+    Unlike other TOPI functions, this function operates on both graph level and operator level.
     """
     # not to change by default
     return None
@@ -368,7 +366,6 @@ def conv2d_nhwc(Input, Filter, stride, padding, dilation, out_dtype='float32'):
     return Output
 
 
-@tvm.target.generic_func
 def conv2d_NCHWc(data, kernel, stride, padding, dilation, layout, out_layout, out_dtype='float32'):
     """Conv2D operator for nChw[x]c layout.
 
@@ -408,58 +405,9 @@ def conv2d_NCHWc(data, kernel, stride, padding, dilation, layout, out_layout, ou
         5-D with shape [batch, out_channel_chunk, out_height, out_width, out_channel_block]
     """
 
-    return conv2d_NCHWc_compute(data,
-                                kernel,
-                                stride,
-                                padding,
-                                dilation,
-                                layout,
-                                out_layout,
-                                out_dtype)
-
-
-def conv2d_NCHWc_compute(data, kernel, strides, padding, dilation, layout, out_layout, out_dtype):
-    """Conv2D operator compute for nChw[x]c layout.
-
-    Parameters
-    ----------
-    data : tvm.Tensor
-        5-D with shape [batch, in_channel_chunk, in_height, in_width, in_channel_block]
-
-    kernel : tvm.Tensor
-        6-D with shape
-        [num_filter_chunk, in_channel_chunk, filter_height, filter_width,
-        in_channel_block, num_filter_block]
-
-    stride : int or a list/tuple of two ints
-        stride size, or [stride_height, stride_width]
-
-    padding : int or a list/tuple of 2 or 4 ints
-        padding size, or
-        [pad_height, pad_width] for 2 ints, or
-        [pad_top, pad_left, pad_bottom, pad_right] for 4 ints
-
-    dilation: int or a list/tuple of two ints
-        dilation size, or [dilation_height, dilation_width]
-
-    layout : str
-        Input data layout
-
-    out_layout : str
-        Output data layout
-
-    out_dtype : str
-        output data type
-
-    Returns
-    -------
-    output : tvm.Tensor
-        5-D with shape [batch, out_channel_chunk, out_height, out_width, out_channel_block]
-    """
-
     # layout and out_layout are not used here,
     # we keep them for debug convenience when dumping autotvm workload
-    HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
+    HSTR, WSTR = stride if isinstance(stride, (tuple, list)) else (stride, stride)
     dilation_h, dilation_w = dilation if isinstance(dilation, (tuple, list)) \
         else (dilation, dilation)
 
@@ -516,8 +464,7 @@ def conv2d_NCHWc_compute(data, kernel, strides, padding, dilation, layout, out_l
                        name='conv2d_NCHWc', tag="conv2d_NCHWc")
 
 
-@tvm.target.generic_func
-def conv2d_NCHWc_int8(data, kernel, strides, padding, dilation, layout, out_layout,
+def conv2d_NCHWc_int8(data, kernel, stride, padding, dilation, layout, out_layout,
                       out_dtype='int32'):
     """Conv2D operator for nChw[x]c layout.
 
@@ -557,59 +504,9 @@ def conv2d_NCHWc_int8(data, kernel, strides, padding, dilation, layout, out_layo
         5-D with shape [batch, out_channel_chunk, out_height, out_width, out_channel_block]
     """
 
-    return conv2d_NCHWc_int8_compute(data,
-                                     kernel,
-                                     strides,
-                                     padding,
-                                     dilation,
-                                     layout,
-                                     out_layout,
-                                     out_dtype)
-
-
-def conv2d_NCHWc_int8_compute(data, kernel, strides, padding, dilation, layout, out_layout,
-                              out_dtype='int32'):
-    """Conv2D operator for nChw[x]c layout.
-
-    Parameters
-    ----------
-    data : tvm.Tensor
-        5-D with shape [batch, in_channel_chunk, in_height, in_width, in_channel_block]
-
-    kernel : tvm.Tensor
-        7-D with shape
-        [num_filter_chunk, in_channel_chunk, filter_height, filter_width, in_channel_block/4,
-        num_filter_block, 4]
-
-    stride : int or a list/tuple of two ints
-        stride size, or [stride_height, stride_width]
-
-    padding : int or a list/tuple of 2 or 4 ints
-        padding size, or
-        [pad_height, pad_width] for 2 ints, or
-        [pad_top, pad_left, pad_bottom, pad_right] for 4 ints
-
-    dilation: int or a list/tuple of two ints
-        dilation size, or [dilation_height, dilation_width]
-
-    layout : str
-        Input data layout
-
-    out_layout : str
-        Output data layout
-
-    out_dtype : str
-        output data type
-
-    Returns
-    -------
-    output : tvm.Tensor
-        5-D with shape [batch, out_channel_chunk, out_height, out_width, out_channel_block]
-    """
-
     # layout and out_layout are not used here,
     # we keep them for debug convenience when dumping autotvm workload
-    HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
+    HSTR, WSTR = stride if isinstance(stride, (tuple, list)) else (stride, stride)
     dilation_h, dilation_w = dilation if isinstance(dilation, (tuple, list)) \
         else (dilation, dilation)
 
@@ -724,33 +621,6 @@ def conv2d_winograd_weight_transform(kernel, tile_size):
                                axis=[r_kh, r_kw]), name='transform_weight')
 
 
-@tvm.target.generic_func
-def conv2d_winograd_without_weight_transform(input, filter, strides, padding, dilation,
-                                             layout, out_dtype, tile_size):
-    """Compute convolution in winograd algorithm. The filter is supposed to be transformed
-    in advance.
-
-    Parameters
-    ----------
-    input : tvm.Tensor
-        4-D with shape [batch, in_height, in_width, in_channel]
-    filter : tvm.Tensor
-        4-D with shape [filter_height, filter_width, in_channel, num_filter]
-    strides : int or a list/tuple of two ints
-        Stride size, or [stride_height, stride_width]
-    padding : int or str
-        Padding size, or ['VALID', 'SAME']
-    tile_size: int
-        Tile size of winograd transform. e.g. 2 for F(2x2, 3x3) and 4 for F(4x4, 3x3)
-
-    Returns
-    -------
-    output : tvm.Tensor
-        4-D with shape [batch, out_height, out_width, out_channel]
-    """
-    raise ValueError("missing register for topi.nn.conv2d_winograd_without_weight_transform")
-
-
 def conv2d_winograd_nnpack_weight_transform(kernel, convolution_algorithm, out_dtype):
     """Weight transformation for winograd
      Parameters
@@ -769,32 +639,7 @@ def conv2d_winograd_nnpack_weight_transform(kernel, convolution_algorithm, out_d
     return nnpack.convolution_inference_weight_transform(
         kernel, algorithm=convolution_algorithm, dtype=out_dtype)
 
-@tvm.target.generic_func
-def conv2d_winograd_nnpack_without_weight_transform(
-        input, filter, bias, strides, padding, dilation, layout, out_dtype):
-    """Compute convolution in winograd algorithm. The filter is supposed to be transformed
-    in advance.
-     Parameters
-    ----------
-    input : tvm.Tensor
-        4-D with shape [batch, in_height, in_width, in_channel]
-    filter : tvm.Tensor
-        4-D with shape [num_filter, in_channel, 8, 8]
-    bias : tvm.Tensor
-        1-D with shape [num_filter]
-    strides : int or a list/tuple of two ints
-        Stride size, or [stride_height, stride_width]
-    padding : int or str
-        Padding size, or ['VALID', 'SAME']
-     Returns
-    -------
-    output : tvm.Tensor
-        4-D with shape [batch, out_height, out_width, out_channel]
-    """
-    raise ValueError("missing register for topi.nn.conv2d_winograd_without_weight_transform")
-
 
-@tvm.target.generic_func
 def group_conv2d_nchw(Input, Filter, stride, padding, dilation, groups, out_dtype=None):
     """Group convolution operator in NCHW layout.
 
@@ -871,3 +716,35 @@ def group_conv2d_nchw(Input, Filter, stride, padding, dilation, groups, out_dtyp
                  xx * stride_w + rx * dilation_w].astype(out_dtype) *
             Filter[ff, rc, ry, rx].astype(out_dtype),
             axis=[rc, ry, rx]), tag='group_conv2d_nchw')
+
+
+def unpack_NCHWc_to_nchw(packed_out, out_dtype):
+    """Unpack conv2d_NCHWc output from layout NCHWc to NCHW
+
+     Parameters
+    -----------
+    packed_out : tvm.Tensor
+        The output tensor of conv2d_NCHWc.
+
+    out_dtype : str
+        The output dtype.
+
+    Returns
+    -------
+    unpacked_out : tvm.Tensor
+        The unpacked output tensor in NCHW layout.
+    """
+    n, oc_chunk, oh, ow, oc_bn = get_const_tuple(packed_out.shape)
+
+    idxmod = tvm.indexmod
+    idxdiv = tvm.indexdiv
+
+    oshape = (n, oc_chunk * oc_bn, oh, ow)
+    unpacked_out = \
+        tvm.compute(oshape,
+                    lambda n, c, h, w:
+                    packed_out[n, idxdiv(c, oc_bn), h, w, idxmod(c, oc_bn)]
+                    .astype(out_dtype),
+                    name='output_unpack',
+                    tag=tag.INJECTIVE+",unpack_nchwc")
+    return unpacked_out
index e635f43..db132fc 100644 (file)
@@ -25,7 +25,6 @@ from .util import get_pad_tuple
 from ..util import simplify
 
 
-@tvm.target.generic_func
 def conv2d_transpose_nchw(Input, Filter, strides, padding, out_dtype):
     """Transposed 2D convolution nchw forward operator.
 
index 83c16da..88c7c6a 100644 (file)
@@ -25,46 +25,8 @@ from .util import get_pad_tuple3d
 from ..util import simplify
 
 
-@tvm.target.generic_func
-def conv3d(input, filter, strides, padding, dilation, layout='NCDHW', out_dtype=None):
-    """Conv3D operator.
-
-    Parameters
-    ----------
-    input : tvm.Tensor
-        5-D with shape [batch, in_depth, in_channel, in_height, in_width]
-
-    filter : tvm.Tensor
-        5-D with shape [num_filter, in_channel, filter_depth, filter_height, filter_width]
-
-    strides : int or a list/tuple of three ints
-        stride size, or [stride_depth, stride_height, stride_width]
-
-    padding : int or a list/tuple of three ints
-        padding size, or [pad_depth, pad_height, pad_width]
-
-    dilation: int or a list/tuple of three ints
-        dilation size, or [dilation_depth, dilation_height, dilation_width]
-
-    layout : str
-        layout of data
-
-    Returns
-    -------
-    output : tvm.Tensor
-        5-D with shape [batch, out_depth, out_channel, out_height, out_width]
-    """
-    # search platform specific declaration first
-    # default declaration
-    if layout == 'NCDHW':
-        return conv3d_ncdhw(input, filter, strides, padding, dilation, out_dtype)
-    elif layout == 'NDHWC':
-        return conv3d_ndhwc(input, filter, strides, padding, dilation, out_dtype)
-    raise ValueError("not support this layout {} yet".format(layout))
-
-
 def conv3d_ncdhw(Input, Filter, stride, padding, dilation, out_dtype=None):
-    """Convolution operator in NCDHW layout.
+    """Conv3D operator in NCDHW layout.
 
     Parameters
     ----------
@@ -138,13 +100,13 @@ def conv3d_ndhwc(Input, Filter, stride, padding, dilation, out_dtype='float32'):
     Parameters
     ----------
     Input : tvm.Tensor
-        5-D with shape [batch, in_channel, in_depth, in_height, in_width]
+        5-D with shape [batch, in_depth, in_height, in_width, in_channel]
 
     Filter : tvm.Tensor
-        5-D with shape [num_filter, in_channel, filter_depth, filter_height, filter_width]
+        5-D with shape [filter_depth, filter_height, filter_width, in_channel, num_filter]
 
     stride : int or a list/tuple of three ints
-        Stride size, or [strid_depth, stride_height, stride_width]
+        Stride size, or [stride_depth, stride_height, stride_width]
 
     padding : int or str
         Padding size, or ['VALID', 'SAME']
@@ -155,7 +117,7 @@ def conv3d_ndhwc(Input, Filter, stride, padding, dilation, out_dtype='float32'):
     Returns
     -------
     Output : tvm.Tensor
-        5-D with shape [batch, out_channel, out_depth, out_height, out_width]
+        5-D with shape [batch, out_depth, out_height, out_width, out_channel]
     """
     assert isinstance(stride, int) or len(stride) == 3
     assert isinstance(dilation, int) or len(dilation) == 3
index 2417411..251f68a 100644 (file)
@@ -22,7 +22,6 @@ from .util import get_pad_tuple
 from ..util import get_const_tuple
 from ..cpp.util import bilinear_sample_nchw
 
-@tvm.target.generic_func
 def deformable_conv2d_nchw(data, offset, kernel, strides, padding, dilation, deformable_groups,
                            groups, out_dtype):
     """Deformable conv2D operator in NCHW layout.
index 671b602..fe21e74 100644 (file)
@@ -19,7 +19,7 @@ from __future__ import absolute_import
 import tvm
 from .. import tag
 
-def dense_default(data, weight, bias=None, out_dtype=None):
+def dense(data, weight, bias=None, out_dtype=None):
     """The default implementation of dense in topi.
 
     Parameters
@@ -59,29 +59,3 @@ def dense_default(data, weight, bias=None, out_dtype=None):
                              lambda i, j: matmul[i, j] + bias[j].astype(out_dtype), \
                              tag=tag.BROADCAST)
     return matmul
-
-
-@tvm.target.override_native_generic_func("dense")
-def dense(data, weight, bias=None, out_dtype=None):
-    """Applies a linear transformation: :math:`Y = XW^T + b`.
-
-    Parameters
-    ----------
-    data : tvm.Tensor
-        2-D with shape [batch, in_dim]
-
-    weight : tvm.Tensor
-        2-D with shape [out_dim, in_dim]
-
-    bias : tvm.Tensor, optional
-        1-D with shape [out_dim]
-
-    out_dtype : str
-        The output type. This is used for mixed precision.
-
-    Returns
-    -------
-    output : tvm.Tensor
-        2-D with shape [batch, out_dim]
-    """
-    return dense_default(data, weight, bias, out_dtype)
index f50e357..49aaace 100644 (file)
@@ -47,7 +47,6 @@ def _get_workload(data, kernel, stride, padding, out_dtype):
                     out_channel, kh, kw, HPAD, WPAD, HSTR, WSTR)
 
 
-@tvm.target.generic_func
 def depthwise_conv2d_nchw(Input, Filter, stride, padding, dilation, out_dtype=None):
     """Depthwise convolution nchw forward operator.
 
@@ -121,7 +120,6 @@ def depthwise_conv2d_nchw(Input, Filter, stride, padding, dilation, out_dtype=No
     return Output
 
 
-@tvm.target.generic_func
 def depthwise_conv2d_nhwc(Input, Filter, stride, padding, dilation, out_dtype=None):
     """Depthwise convolution nhwc forward operator.
 
@@ -307,7 +305,6 @@ def depthwise_conv2d_backward_weight_nhwc(Input, Out_grad, oshape, fshape, strid
     return Weight_grad
 
 
-@tvm.target.generic_func
 def depthwise_conv2d_NCHWc(Input, Filter, stride, padding, dilation,
                            layout, out_layout, out_dtype=None):
     """Depthwise convolution NCHW[x]c forward operator.
index de002bf..1b41c7d 100644 (file)
 # pylint: disable=invalid-name
 """TVM operator for local response norm compute."""
 from __future__ import absolute_import
-import tvm
 from .. import cpp
 
-@tvm.target.generic_func
 def lrn(data, size, axis=1, alpha=0.0001, beta=0.75, bias=2):
     """Perform the across channels local response normalisation
     on the input data.
index 584126e..6974ff4 100644 (file)
@@ -22,7 +22,6 @@ import tvm
 from ..util import get_const_tuple
 
 
-@tvm.target.generic_func
 def sparse_dense(data, weight_data, weight_indices, weight_indptr):
     """
     Computes sparse-dense matrix multiplication of `data` and
@@ -105,7 +104,7 @@ def _sparse_dense_bsrmm(data, weight_data, weight_indices, weight_indptr):
         lambda m, n: bsrmm_block[m, idxd(n, bs_r), idxm(n, bs_r)],
         tag="sparse_dense_bsrmm")
 
-@tvm.target.generic_func
+
 def sparse_transpose(sparse_data, sparse_indices, sparse_indptr):
     """
     Transpose a square sparse matrix,
@@ -148,14 +147,15 @@ def sparse_transpose(sparse_data, sparse_indices, sparse_indptr):
         shape=output_shape,
         inputs=[sparse_data, sparse_indices, sparse_indptr],
         fcompute=lambda ins, outs:
-        csr_transpose_ir(ins[0], ins[1], ins[2], outs[0], outs[1], outs[2]),
+        _csr_transpose_ir(ins[0], ins[1], ins[2], outs[0], outs[1], outs[2]),
         tag="sparse_transpose_csr",
         dtype=['float32', 'int32', 'int32'],
         name='out')
 
     return [output_data, output_indices, output_indptr]
 
-def csr_transpose_ir(data, indices, indptr, out_data, out_indices, out_indptr):
+
+def _csr_transpose_ir(data, indices, indptr, out_data, out_indices, out_indptr):
     """define ir for csr_transpose"""
     irb = tvm.ir_builder.create()
 
index aa73e84..f0cdd9a 100644 (file)
@@ -143,7 +143,7 @@ def get_pad_tuple(padding, kernel):
             pad_h = padding[0] * 2
             pad_w = padding[1] * 2
         elif len(padding) == 4:
-            return  padding[0], padding[1], padding[2], padding[3]
+            return padding[0], padding[1], padding[2], padding[3]
         else:
             raise ValueError("Size of padding can only be 2 or 4")
     elif isinstance(padding, int):
index e39d1ad..52ed119 100644 (file)
@@ -18,9 +18,7 @@
 """Schedule for conv2d_nchw with auto fusion"""
 import tvm
 from .. import tag
-from .. import generic
 
-@generic.schedule_conv2d_nchw.register(["opengl"])
 def schedule_conv2d_nchw(outs):
     """Schedule for conv2d_nchw.
 
index c93dfcc..db2c4a6 100644 (file)
@@ -19,9 +19,7 @@
 from __future__ import absolute_import as _abs
 import tvm
 from .. import tag
-from .. import generic
 
-@generic.schedule_dense.register(["opengl"])
 def schedule_dense(outs):
     """Schedule for dense operator.
 
index d3ebc94..28dc87d 100644 (file)
@@ -17,9 +17,7 @@
 # pylint: disable=invalid-name, unused-variable,
 """Schedule for composition of injective operator"""
 import tvm
-from .. import generic
 
-@generic.schedule_injective_from_existing.register(["opengl"])
 def schedule_injective_from_existing(sch, out):
     """Schedule for injective op from existing schedule.
 
@@ -38,7 +36,6 @@ def schedule_injective_from_existing(sch, out):
     sch[out].opengl()
     return sch
 
-@generic.schedule_injective.register(["opengl"])
 def schedule_injective(outs):
     """Schedule for injective op.
 
index 04c7b0c..3226422 100644 (file)
@@ -18,9 +18,7 @@
 """Schedule for pooling operators"""
 import tvm
 from .. import tag
-from .. import generic
 
-@generic.schedule_adaptive_pool.register(["opengl"])
 def schedule_adaptive_pool(outs):
     """Schedule for adaptive pool.
 
@@ -69,7 +67,6 @@ def schedule_adaptive_pool(outs):
     return s
 
 
-@generic.schedule_pool.register(["opengl"])
 def schedule_pool(outs, layout):
     """Schedule for pool.
 
index e343d45..ff218d1 100644 (file)
@@ -17,9 +17,7 @@
 # pylint: disable=invalid-name, unused-variable, trailing-whitespace
 """Schedule for softmax operator"""
 import tvm
-from .. import generic
 
-@generic.schedule_softmax.register(["opengl"])
 def schedule_softmax(outs):
     """Schedule for softmax op.
 
index be29c6f..ce56dc4 100644 (file)
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=invalid-name
+# pylint: disable=invalid-name, unused-argument
 """Compute definition for conv2d with rocm backend"""
-import tvm
 from tvm import autotvm
 from tvm.contrib import miopen
 
-from .. import nn, generic
+from .. import generic
 from ..util import get_const_tuple
-from ..cuda.conv2d import conv2d_cuda, schedule_conv2d_nchw_cuda
 from ..nn.util import get_pad_tuple
 
-@autotvm.register_topi_compute(nn.conv2d, 'rocm', ['direct', 'winograd'])
-def conv2d_rocm(cfg, data, kernel, strides, padding, dilation, layout='NCHW', out_dtype='float32'):
+@autotvm.register_topi_compute("conv2d_nchw_miopen.rocm")
+def conv2d_nchw_miopen(cfg, data, kernel, strides, padding, dilation, out_dtype='float32'):
     """Conv2D operator for rocm backend.
 
     Parameters
@@ -57,39 +55,34 @@ def conv2d_rocm(cfg, data, kernel, strides, padding, dilation, layout='NCHW', ou
         4-D with shape [batch, out_channel, out_height, out_width]
     """
 
-    target = tvm.target.Target.current()
-    if "miopen" in target.libs:
-        assert layout == 'NCHW', "Only NCHW layout is supported."
-        CO, CI, KH, KW = get_const_tuple(kernel.shape)
-        N, _, H, W = get_const_tuple(data.shape)
-
-        # handle dilation
-        stride_h, stride_w = (strides, strides) if isinstance(strides, int) else strides
-        pt, pl, pb, pr = get_pad_tuple(padding, (KH, KW))
-        pad_h, pad_w = pt + pb, pl + pr
-        dilation_h, dilation_w = (dilation, dilation) if isinstance(dilation, int) else dilation
-
-        OH = (H + 2 * pad_h - KH) // stride_h + 1
-        OW = (W + 2 * pad_w - KW) // stride_w + 1
-        cfg.add_flop(2 * N * OH * OW * CO * CI * ((KH - 1) * dilation_h + 1) *\
-                    ((KW - 1) * dilation_w + 1))
-
-        return miopen.conv2d_forward(data,
-                                     kernel,
-                                     stride_h,
-                                     stride_w,
-                                     pad_h,
-                                     pad_w,
-                                     dilation_h,
-                                     dilation_w,
-                                     conv_mode=0,
-                                     data_type=1)
-
-    return conv2d_cuda(cfg, data, kernel, strides, padding, dilation, layout, out_dtype)
-
-
-@autotvm.register_topi_schedule(generic.schedule_conv2d_nchw, 'rocm', ["direct", 'winograd'])
-def schedule_conv2d_nchw_rocm(cfg, outs):
+    CO, CI, KH, KW = get_const_tuple(kernel.shape)
+    N, _, H, W = get_const_tuple(data.shape)
+
+    # handle dilation
+    stride_h, stride_w = (strides, strides) if isinstance(strides, int) else strides
+    pt, pl, pb, pr = get_pad_tuple(padding, (KH, KW))
+    pad_h, pad_w = pt + pb, pl + pr
+    dilation_h, dilation_w = (dilation, dilation) if isinstance(dilation, int) else dilation
+
+    OH = (H + 2 * pad_h - KH) // stride_h + 1
+    OW = (W + 2 * pad_w - KW) // stride_w + 1
+    cfg.add_flop(2 * N * OH * OW * CO * CI * ((KH - 1) * dilation_h + 1) *\
+                ((KW - 1) * dilation_w + 1))
+
+    return miopen.conv2d_forward(data,
+                                 kernel,
+                                 stride_h,
+                                 stride_w,
+                                 pad_h,
+                                 pad_w,
+                                 dilation_h,
+                                 dilation_w,
+                                 conv_mode=0,
+                                 data_type=1)
+
+
+@autotvm.register_topi_schedule("conv2d_nchw_miopen.rocm")
+def schedule_conv2d_nchw_miopen(cfg, outs):
     """TOPI schedule callback of conv2d for rocm
 
     Parameters
@@ -106,8 +99,4 @@ def schedule_conv2d_nchw_rocm(cfg, outs):
     s: Schedule
         The computation schedule for conv2d.
     """
-    target = tvm.target.Target.current()
-    if target and "miopen" in target.libs:
-        return generic.schedule_extern(outs)
-
-    return schedule_conv2d_nchw_cuda(cfg, outs)
+    return generic.schedule_extern(outs)
index f2adeaa..8729a62 100644 (file)
@@ -20,13 +20,12 @@ from __future__ import absolute_import as _abs
 import tvm
 from tvm import autotvm
 from tvm.contrib import rocblas
-import topi
-from ..nn.dense import dense, dense_default
+from .. import generic, nn
 from .. import tag
-from .. import generic
+from ..util import traverse_inline
 
-@autotvm.register_topi_compute(dense, "rocm", "direct")
-def dense_rocm(cfg, data, weight, bias=None, out_dtype=None):
+@autotvm.register_topi_compute('dense.rocm')
+def dense(cfg, data, weight, bias=None, out_dtype=None):
     """Dense operator for rocm backend.
 
     Parameters
@@ -54,21 +53,10 @@ def dense_rocm(cfg, data, weight, bias=None, out_dtype=None):
         assert len(bias.shape) == 1
     if out_dtype is None:
         out_dtype = data.dtype
-    batch, in_dim = data.shape
-    out_dim, _ = weight.shape
-    target = tvm.target.Target.current()
-    if "rocblas" in target.libs:
-        assert out_dtype == data.dtype, "Mixed precision not supported."
-        matmul = rocblas.matmul(data, weight, False, True)
-        if bias is not None:
-            matmul = tvm.compute((batch, out_dim), \
-                                 lambda i, j: matmul[i, j] + bias[j], \
-                                 tag=tag.BROADCAST)
-        return matmul
-    return dense_default(data, weight, bias, out_dtype)
-
-
-@autotvm.register_topi_schedule(generic.schedule_dense, "rocm", "direct")
+    return nn.dense(data, weight, bias, out_dtype)
+
+
+@autotvm.register_topi_schedule('dense.rocm')
 def schedule_dense(cfg, outs):
     """Schedule for dense operator.
 
@@ -83,7 +71,72 @@ def schedule_dense(cfg, outs):
     s: Schedule
         The computation schedule for dense.
     """
-    target = tvm.target.Target.current()
-    if target.target_name == "rocm" and "rocblas" in target.libs:
-        return generic.schedule_extern(outs)
-    return topi.cuda.schedule_dense(cfg, outs)
+    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
+    s = tvm.create_schedule([x.op for x in outs])
+
+    def _callback(op):
+        if op.tag == 'dense':
+            Dense = op.output(0)
+            num_thread = 64
+            k = Dense.op.reduce_axis[0]
+            ko, kf = s[Dense].split(k, factor=num_thread)
+            DenseF = s.rfactor(Dense, kf)
+
+            if Dense.op in s.outputs:
+                Out = Dense
+            else:
+                Out = outs[0].op.output(0)
+                s[Dense].compute_at(s[Out], s[Out].op.axis[1])
+            s[Out].bind(s[Out].op.axis[0], tvm.thread_axis("blockIdx.y"))
+            s[Out].bind(s[Out].op.axis[1], tvm.thread_axis("blockIdx.x"))
+
+            tx = s[Dense].op.reduce_axis[0]
+            thread_x = tvm.thread_axis("threadIdx.x")
+            s[Dense].bind(tx, thread_x)
+            s[DenseF].compute_at(s[Dense], tx)
+            s[Dense].set_store_predicate(thread_x.var.equal(0))
+            s[Out].set_store_predicate(thread_x.var.equal(0))
+
+    traverse_inline(s, outs[0].op, _callback)
+    return s
+
+
+@autotvm.register_topi_compute('dense_rocblas.rocm')
+def dense_rocblas(cfg, data, weight, bias=None, out_dtype=None):
+    """Dense operator for rocm backend with cblas.
+
+    Parameters
+    ----------
+    data : tvm.Tensor
+        2-D with shape [batch, in_dim]
+
+    weight : tvm.Tensor
+        2-D with shape [out_dim, in_dim]
+
+    bias : tvm.Tensor, optional
+        1-D with shape [out_dim]
+
+    out_dtype : str
+        The output type. This is used for mixed precision.
+
+    Returns
+    -------
+    output : tvm.Tensor
+        2-D with shape [batch, out_dim]
+    """
+    assert out_dtype == data.dtype, "Mixed precision not supported."
+    matmul = rocblas.matmul(data, weight, False, True)
+    batch, in_dim = data.shape
+    out_dim, _ = weight.shape
+    cfg.add_flop(batch * in_dim * out_dim * 2)
+    if bias is not None:
+        matmul = tvm.compute((batch, out_dim),
+                             lambda i, j: matmul[i, j] + bias[j],
+                             tag=tag.BROADCAST)
+    return matmul
+
+
+@autotvm.register_topi_schedule('dense_rocblas.rocm')
+def schedule_dense_rocblas(_, outs):
+    """Schedule for dense operator with rocm cblas"""
+    return generic.schedule_extern(outs)
index 8a9c8c3..5f134cb 100644 (file)
 """scheduler for normalization functions on rocm backend"""
 from __future__ import absolute_import as _abs
 
-import tvm
-from .. import generic
 from .. import cpp
 
-@generic.schedule_lrn.register(["rocm", "gpu"])
 def schedule_lrn(outs):
-    target = tvm.target.Target.current(allow_none=False)
-    cpp_target = cpp.TEST_create_target(target.target_name)
-    return cpp.rocm.schedule_lrn(cpp_target, outs)
+    return cpp.rocm.schedule_lrn(outs)
index 22899c4..96a0889 100644 (file)
@@ -20,7 +20,6 @@ import tvm
 from tvm import api
 from .util import get_const_tuple
 
-@tvm.target.generic_func
 def argsort(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32"):
     """Performs sorting along the given axis and returns an array
     of indices having the same shape as an input array that index
@@ -99,7 +98,6 @@ def argsort(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32"):
     return out
 
 
-@tvm.target.generic_func
 def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int64"):
     """Get the top k elements in an input tensor along the given axis.
 
index 91b7dc5..b0f4752 100644 (file)
@@ -53,3 +53,5 @@ from .one_hot import one_hot
 from .depth_to_space import depth_to_space_python
 from .space_to_depth import space_to_depth_python
 from .crop_and_resize_python import crop_and_resize_python
+from .common import get_injective_schedule, get_reduce_schedule, get_broadcast_schedule, \
+    get_elemwise_schedule, get_conv2d_nchw_implement, dispatch
diff --git a/topi/python/topi/testing/common.py b/topi/python/topi/testing/common.py
new file mode 100644 (file)
index 0000000..5817513
--- /dev/null
@@ -0,0 +1,74 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+"""Common utility for topi test"""
+
+import tvm
+import topi
+
+_injective_schedule = {
+    "generic": topi.generic.schedule_injective,
+    "cpu": topi.x86.schedule_injective,
+    "arm_cpu": topi.arm_cpu.schedule_injective,
+    "gpu": topi.cuda.schedule_injective,
+    "hls": topi.hls.schedule_injective,
+    "opengl": topi.opengl.schedule_injective
+}
+
+_reduce_schedule = {
+    "generic": topi.generic.schedule_reduce,
+    "cpu": topi.x86.schedule_reduce,
+    "gpu": topi.cuda.schedule_reduce,
+    "hls": topi.cuda.schedule_reduce
+}
+
+def dispatch(target, dispatch_map):
+    if isinstance(target, str):
+        target = tvm.target.create(target)
+    assert isinstance(target, tvm.target.Target)
+    for key in target.keys:
+        if key in dispatch_map:
+            return dispatch_map[key]
+    return dispatch_map["generic"]
+
+def get_injective_schedule(target):
+    return dispatch(target, _injective_schedule)
+
+def get_reduce_schedule(target):
+    return dispatch(target, _reduce_schedule)
+
+get_broadcast_schedule = get_injective_schedule
+get_elemwise_schedule = get_injective_schedule
+
+_conv2d_nchw_implement = {
+    "generic": (topi.nn.conv2d_nchw, topi.generic.schedule_conv2d_nchw),
+    "cpu": (topi.x86.conv2d_nchw, topi.x86.schedule_conv2d_nchw),
+    "arm_cpu": (topi.arm_cpu.conv2d_nchw_spatial_pack,
+                topi.arm_cpu.schedule_conv2d_nchw_spatial_pack),
+    "gpu": (topi.cuda.conv2d_nchw, topi.cuda.schedule_conv2d_nchw),
+    "mali": (topi.mali.conv2d_nchw_spatial_pack,
+             topi.mali.schedule_conv2d_nchw_spatial_pack),
+    "bifrost": (topi.bifrost.conv2d_nchw_spatial_pack,
+                topi.bifrost.schedule_conv2d_nchw_spatial_pack),
+    "opengl": (topi.nn.conv2d_nchw, topi.opengl.schedule_conv2d_nchw),
+    "intel_graphics": (topi.intel_graphics.conv2d_nchw,
+                       topi.intel_graphics.schedule_conv2d_nchw),
+    "hls": (topi.nn.conv2d_nchw, topi.hls.schedule_conv2d_nchw)
+}
+
+def get_conv2d_nchw_implement(target):
+    return dispatch(target, _conv2d_nchw_implement)
index 5bb36f7..c171f8c 100644 (file)
@@ -116,7 +116,7 @@ def hybrid_get_valid_counts(data, score_threshold, id_index, score_index, one):
                     out_tensor[i, j, k] = -one
     return valid_count, out_tensor
 
-@tvm.target.generic_func
+
 def get_valid_counts(data, score_threshold=0, id_index=0, score_index=1):
     """Get valid count of bounding boxes given a score threshold.
     Also moves valid boxes to the top of input data.
@@ -289,7 +289,6 @@ def hybrid_nms(data, sorted_index, valid_count,
     return output, box_indices
 
 
-@tvm.target.generic_func
 def non_max_suppression(data, valid_count, max_output_size=-1,
                         iou_threshold=0.5, force_suppress=False, top_k=-1,
                         coord_start=2, score_index=1, id_index=0,
index d48c890..5de4998 100644 (file)
@@ -317,7 +317,7 @@ def prepare_output_ir(sorted_bbox_buf, remove_mask_buf, out_buf):
     body = ib.get()
     return body
 
-@tvm.target.generic_func
+
 def proposal(cls_prob, bbox_pred, im_info, scales, ratios, feature_stride, threshold,
              rpn_pre_nms_top_n, rpn_post_nms_top_n, rpn_min_size, iou_loss):
     """Proposal operator.
index a6540b3..a0bc5e2 100644 (file)
@@ -21,7 +21,6 @@ from ...util import get_const_tuple
 from ...cpp.util import bilinear_sample_nchw
 
 
-@tvm.target.generic_func
 def roi_align_nchw(data, rois, pooled_size, spatial_scale, sample_ratio=-1):
     """ROI align operator in NCHW layout.
 
index 53ffe35..f346f58 100644 (file)
@@ -19,7 +19,6 @@
 import tvm
 from ...util import get_const_tuple
 
-@tvm.target.generic_func
 def roi_pool_nchw(data, rois, pooled_size, spatial_scale):
     """ROI pool operator in NCHW layout.
 
index 7adfc73..3ba5e84 100644 (file)
@@ -20,10 +20,8 @@ REORG Operator
 Reorg operator, used in darknet.
 """
 from __future__ import absolute_import as _abs
-import tvm
 from .. import cpp
 
-@tvm.target.generic_func
 def reorg(data, stride):
     """Reorg forward operators.
 
index 8c31f82..4309af4 100644 (file)
@@ -89,7 +89,6 @@ def hybrid_multibox_prior(data, sizes, ratios, steps, offsets):
     return output
 
 
-@tvm.target.generic_func
 def multibox_prior(data, sizes=(1,), ratios=(1,), steps=(-1, -1), offsets=(0.5, 0.5), clip=False):
     """Generate prior(anchor) boxes from data, sizes and ratios.
 
@@ -233,7 +232,6 @@ def hybrid_multibox_transform_loc(cls_prob, loc_pred, anchor,
 
     return out_loc, valid_count
 
-@tvm.target.generic_func
 def multibox_transform_loc(cls_prob, loc_pred, anchor, clip=True, threshold=0.01,
                            variances=(0.1, 0.1, 0.2, 0.2)):
     """Location transformation for multibox detection
@@ -267,7 +265,6 @@ def multibox_transform_loc(cls_prob, loc_pred, anchor, clip=True, threshold=0.01
                                          tvm.const(threshold, "float32"),
                                          tvm.convert(variances))
 
-@tvm.target.generic_func
 def multibox_detection(cls_prob, loc_pred, anchor, clip=True, threshold=0.01, nms_threshold=0.5,
                        force_suppress=False, variances=(0.1, 0.1, 0.2, 0.2), nms_topk=-1):
     """Convert multibox detection predictions.
index d1c728d..ce07c19 100644 (file)
@@ -19,9 +19,9 @@
 """x86 specific declaration and schedules."""
 from __future__ import absolute_import as _abs
 
-from .conv1d import schedule_conv1d_nwc
-from .conv2d import schedule_conv2d, schedule_conv2d_nhwc
-from .conv3d import schedule_conv3d_ndhwc
+from .conv1d import *
+from .conv2d import *
+from .conv3d import *
 from .binarize_pack import schedule_binarize_pack
 from .binary_dense import schedule_binary_dense
 from .nn import *
@@ -29,12 +29,12 @@ from .conv2d_int8 import *
 from .injective import *
 from .reduction import *
 from .pooling import schedule_pool, schedule_adaptive_pool
-from .bitserial_conv2d import schedule_bitserial_conv2d
-from .bitserial_dense import schedule_bitserial_dense
-from .depthwise_conv2d import schedule_depthwise_conv2d_NCHWc
-from .dense import _schedule_dense, _schedule_dense_pack, _schedule_dense_nopack
-from .batch_matmul import schedule_batch_matmul
+from .bitserial_conv2d import *
+from .bitserial_dense import *
+from .depthwise_conv2d import *
+from .dense import *
+from .batch_matmul import *
 from .roi_align import roi_align_nchw
-from .conv2d_transpose import _schedule_conv2d_transpose_nchw
+from .conv2d_transpose import *
 from .sparse import *
 from .conv2d_alter_op import *
index fef6c48..a7cb9e9 100644 (file)
@@ -21,12 +21,12 @@ import tvm
 from tvm import autotvm
 from tvm.autotvm.task.space import SplitEntity
 from tvm.contrib import cblas
-from .. import generic, nn
+from .. import generic
 from ..util import traverse_inline, get_const_tuple, get_max_power2_factor
 
 
-@autotvm.register_topi_compute(nn.batch_matmul, "cpu", "direct")
-def _declaration_batch_matmul_nopack(cfg, x, y):
+@autotvm.register_topi_compute("batch_matmul.x86")
+def batch_matmul(cfg, x, y):
     """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
     data in batch.
 
@@ -43,10 +43,6 @@ def _declaration_batch_matmul_nopack(cfg, x, y):
     output : tvm.Tensor
         3-D with shape [batch, M, N]
     """
-    target = tvm.target.Target.current()
-    if "cblas" in target.libs:
-        return cblas.batch_matmul(x, y, False, True)
-
     assert len(x.shape) == 3 and len(
         y.shape) == 3, "only support 3-dim batch_matmul"
     XB, M, XK = get_const_tuple(x.shape)
@@ -56,7 +52,7 @@ def _declaration_batch_matmul_nopack(cfg, x, y):
     B = XB
     K = XK
     if cfg.is_fallback:
-        _default_batch_matmul_nopack_config(cfg, M, N, K)
+        _default_batch_matmul_config(cfg, M, N, K)
 
     k = tvm.reduce_axis((0, K), name='k')
     C = tvm.compute(
@@ -66,7 +62,7 @@ def _declaration_batch_matmul_nopack(cfg, x, y):
     return C
 
 
-@autotvm.register_topi_schedule(generic.schedule_batch_matmul, "cpu", "direct")
+@autotvm.register_topi_schedule("batch_matmul.x86")
 def schedule_batch_matmul(cfg, outs):
     """Schedule for batch_matmul
 
@@ -83,10 +79,6 @@ def schedule_batch_matmul(cfg, outs):
     sch: Schedule
         The computation schedule for the op.
     """
-    target = tvm.target.Target.current()
-    if "cblas" in target.libs:
-        return generic.schedule_extern(outs)
-
     s = tvm.create_schedule([x.op for x in outs])
 
     def _callback(op):
@@ -131,9 +123,42 @@ def schedule_batch_matmul(cfg, outs):
     return s
 
 
-def _default_batch_matmul_nopack_config(cfg, M, N, K):
+def _default_batch_matmul_config(cfg, M, N, K):
     cfg["tile_k"] = SplitEntity([K // 16, 16])
     x_bn = get_max_power2_factor(N, 8)
     cfg["tile_x"] = SplitEntity([N // x_bn, x_bn])
     y_bn = get_max_power2_factor(M, 8)
     cfg["tile_y"] = SplitEntity([M // y_bn, y_bn])
+
+
+@autotvm.register_topi_compute("batch_matmul_cblas.x86")
+def batch_matmul_cblas(cfg, x, y):
+    """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
+    data in batch.
+
+    Parameters
+    ----------
+    cfg : ConfigSpace
+        Autotvm tuning space config file
+    x : tvm.Tensor
+        3-D with shape [batch, M, K]
+    y : tvm.Tensor
+        3-D with shape [batch, N, K]
+    Returns
+    -------
+    output : tvm.Tensor
+        3-D with shape [batch, M, N]
+    """
+    assert len(x.shape) == 3 and len(
+        y.shape) == 3, "only support 3-dim batch_matmul"
+    XB, M, XK = get_const_tuple(x.shape)
+    YB, N, YK = get_const_tuple(y.shape)
+    assert XB == YB, "batch dimension doesn't match"
+    assert XK == YK, "shapes of x and y is inconsistant"
+    cfg.add_flop(XB * M * N * XK * 2)
+    return cblas.batch_matmul(x, y, False, True)
+
+
+@autotvm.register_topi_schedule("batch_matmul_cblas.x86")
+def schedule_batch_matmul_cblas(_, outs):
+    return generic.schedule_extern(outs)
index ea2bbed..bab91a9 100644 (file)
 """Schedule for binarization and bit-packing."""
 from __future__ import absolute_import as _abs
 import tvm
-from .. import generic
 
 
-@generic.schedule_binarize_pack.register(["cpu"])
 def schedule_binarize_pack(outs):
     """Schedule for binarize_pack.
 
index abf0908..ccf74e7 100644 (file)
 from __future__ import absolute_import as _abs
 import tvm
 from .. import tag
-from .. import generic
 
 
-@generic.schedule_binary_dense.register(["cpu"])
 def schedule_binary_dense(outs):
     """Schedule for binary_dense.
 
index 97d0dc0..2ec5653 100644 (file)
 """Bitserial conv2d schedule on x86"""
 import tvm
 from tvm import autotvm
-from topi.util import get_const_int
-from .. import generic, tag
+from .. import tag
+from ..util import get_const_int, get_const_tuple
+from ..nn.pad import pad
+from ..nn.util import get_pad_tuple
+from ..nn.bitserial_util import bitpack, binary_op_multiplier
+
+@autotvm.register_topi_compute("bitserial_conv2d_nchw.x86")
+def bitserial_conv2d_nchw(cfg, data, kernel, stride, padding, in_bits, weight_bits,
+                          pack_dtype='uint32', out_dtype='int16', unipolar=True):
+    """ Compute convolution with pack on spatial axes. """
+    assert data.shape[0].value == 1, "spatial pack convolution only support batch size=1"
+    data_q = bitpack(data, in_bits, pack_axis=1, bit_axis=0, pack_type=pack_dtype)
+    # Check if kernel is already bitpacked
+    if len(kernel.shape) == 4:
+        kernel_q = bitpack(kernel, weight_bits, pack_axis=1, bit_axis=0, pack_type=pack_dtype)
+        KB, CO, _, KH, KW = get_const_tuple(kernel_q.shape)
+    else:
+        kernel_vec = kernel
+        OCO, _, KH, KW, KB, VC = get_const_tuple(kernel_vec.shape)
+        CO = OCO * VC
+
+    IB, N, CI, H, W = get_const_tuple(data_q.shape)
+    KB, CO, _, KH, KW = get_const_tuple(kernel_q.shape)
+
+    if isinstance(padding, int) or (isinstance(padding, (tuple, list)) and len(padding) == 2):
+        TPAD, LPAD, DPAD, RPAD = get_pad_tuple(padding, kernel)
+    else:
+        TPAD, LPAD, DPAD, RPAD = padding
+    pad_before = [0, 0, 0, TPAD, LPAD]
+    pad_after = [0, 0, 0, DPAD, RPAD]
 
-@autotvm.register_topi_schedule(generic.nn.schedule_bitserial_conv2d_nchw, ['cpu'], 'direct')
-@autotvm.register_topi_schedule(generic.nn.schedule_bitserial_conv2d_nhwc, ['cpu'], 'direct')
-def schedule_bitserial_conv2d(cfg, outs):
+    if isinstance(stride, (tuple, list)):
+        HSTR, WSTR = stride
+    else:
+        HSTR, WSTR = stride, stride
+    HCAT, WCAT = KH-1, KW-1
+
+    TH = H + TPAD + DPAD
+    TW = W + LPAD + RPAD
+    OH = (H + TPAD + DPAD - KH) // HSTR + 1
+    OW = (W + LPAD + RPAD - KW) // WSTR + 1
+
+    # ==================== define configuration space ====================
+    n, co, oh, ow = cfg.axis(N), cfg.axis(CO), cfg.axis(OH), cfg.axis(OW)
+    ci, kh, kw = cfg.reduce_axis(CI), cfg.reduce_axis(KH), cfg.reduce_axis(KW)
+    ib, kb = cfg.reduce_axis(in_bits), cfg.reduce_axis(weight_bits)
+
+    co, vc = cfg.define_split('tile_co', co, num_outputs=2,
+                              filter=lambda x: max(x.size[1:]) <= 16)
+    oh, vh = cfg.define_split('tile_oh', oh, num_outputs=2,
+                              filter=lambda x: max(x.size[1:]) <= 16)
+    ow, vw = cfg.define_split('tile_ow', ow, num_outputs=2,
+                              filter=lambda x: max(x.size[1:]) <= 16)
+    cfg.define_annotate('ann_reduce', [ib, kb, kh, kw], policy='try_unroll')
+
+    cfg.define_reorder("reorder_0",
+                       [n, co, oh, ow, vc, vh, vw, kh, kw, kb, ib, ci],
+                       policy='interval_all', interval=(6, 11))
+    # binary ops
+    cfg.add_flop(2 * N * OH * OW * CO * CI * KH * KW * binary_op_multiplier(pack_dtype))
+    # ====================
+
+    VC = cfg["tile_co"].size[-1]
+    VH = cfg["tile_oh"].size[-1]
+    VW = cfg["tile_ow"].size[-1]
+
+    dvshape = (1, TH//(VH*HSTR), TW//(VW*WSTR), CI, VH*HSTR+HCAT, VW*WSTR+WCAT, IB)
+    kvshape = (CO//VC, CI, KH, KW, KB, VC)
+    ovshape = (1, CO//VC, OH//VH, OW//VW, VH, VW, VC)
+    oshape = (1, CO, OH, OW)
+
+    if (TPAD != 0 and RPAD != 0):
+        data_pad = pad(data_q, pad_before, pad_after, name="data_pad")
+    else:
+        data_pad = data_q
+
+    data_vec = tvm.compute(dvshape, lambda n, h, w, ci, vh, vw, b: \
+        data_pad[b][n][ci][h*VH*HSTR+vh][w*VW*WSTR+vw], name='data_vec')
+
+    if len(kernel.shape) == 4:
+        kernel_vec = tvm.compute(kvshape, lambda co, ci, dh, dw, b, vc: \
+            kernel_q[b][co*VC+vc][ci][dh][dw], name='kernel_vec')
+
+    ci = tvm.reduce_axis((0, CI), name='ci')
+    dh = tvm.reduce_axis((0, KH), name='dh')
+    dw = tvm.reduce_axis((0, KW), name='dw')
+    b1 = tvm.reduce_axis((0, IB), name='ib')
+    b2 = tvm.reduce_axis((0, KB), name='kb')
+
+    def _conv(n, co, h, w, vh, vw, vc):
+        b1b2 = (b1+b2).astype(out_dtype)
+        if unipolar:
+            return tvm.sum((tvm.popcount(
+                data_vec[n, h, w, ci, vh*HSTR+dh, vw*WSTR+dw, b1].astype(out_dtype) &
+                kernel_vec[co, ci, dh, dw, b2, vc].astype(out_dtype))  -
+                            tvm.popcount(
+                                data_vec[n, h, w, ci, vh*HSTR+dh, vw*WSTR+dw, b1].astype(out_dtype)
+                                & ~kernel_vec[co, ci, dh, dw, b2, vc]).astype(out_dtype)) << b1b2,
+                           axis=[ci, dh, dw, b1, b2])
+
+        return tvm.sum((tvm.popcount(
+            data_vec[n, h, w, ci, vh*HSTR+dh, vw*WSTR+dw, b1] &
+            kernel_vec[co, ci, dh, dw, b2, vc])).astype(out_dtype) << b1b2,
+                       axis=[ci, dh, dw, b1, b2])
+
+    conv = tvm.compute(ovshape, _conv, name='conv_out')
+    idxd = tvm.indexdiv
+    idxm = tvm.indexmod
+
+    return tvm.compute(
+        oshape, lambda n, co, h, w:
+        conv[n,
+             idxd(co, VC), idxd(h, VH), idxd(w, VW),
+             idxm(h, VH), idxm(w, VW), idxm(co, VC)],
+        name='conv_vec', tag='spatial_bitserial_conv_nchw')
+
+@autotvm.register_topi_compute("bitserial_conv2d_nhwc.x86")
+def bitserial_conv2d_nhwc(cfg, data, kernel, stride, padding, in_bits, weight_bits,
+                          pack_dtype='uint32', out_dtype='int16', unipolar=True):
+    """ Compute convolution with pack on spatial axes. """
+    assert data.shape[0].value == 1, "spatial pack convolution only support batch size=1"
+    data_q = bitpack(data, in_bits, pack_axis=3, bit_axis=4, pack_type=pack_dtype)
+    pack_kernel = len(kernel.shape) == 4
+
+    if pack_kernel:
+        kernel_q = bitpack(kernel, weight_bits, pack_axis=2, bit_axis=4, pack_type=pack_dtype)
+    else:
+        kernel_q = kernel
+
+    KH, KW, _, CO, KB = get_const_tuple(kernel_q.shape)
+    N, H, W, CI, IB = get_const_tuple(data_q.shape)
+
+    if isinstance(padding, int) or (isinstance(padding, (tuple, list)) and len(padding) == 2):
+        TPAD, LPAD, DPAD, RPAD = get_pad_tuple(padding, kernel)
+    else:
+        TPAD, LPAD, DPAD, RPAD = padding
+    pad_before = [0, TPAD, LPAD, 0, 0]
+    pad_after = [0, DPAD, RPAD, 0, 0]
+
+    if isinstance(stride, (tuple, list)):
+        HSTR, WSTR = stride
+    else:
+        HSTR, WSTR = stride, stride
+    HCAT, WCAT = KH-1, KW-1
+
+    PAD_H = H + (TPAD + DPAD)
+    PAD_W = W + (LPAD + RPAD)
+    OH = (PAD_H - KH) // HSTR + 1
+    OW = (PAD_W - KW) // WSTR + 1
+    oshape = (1, OH, OW, CO)
+
+    # ==================== define configuration space ====================
+    n, oh, ow, co = cfg.axis(N), cfg.axis(OH), cfg.axis(OW), cfg.axis(CO)
+    ci, kh, kw = cfg.reduce_axis(CI), cfg.reduce_axis(KH), cfg.reduce_axis(KW)
+    ib, kb = cfg.reduce_axis(in_bits), cfg.reduce_axis(weight_bits)
+
+    co, vc = cfg.define_split('tile_co', co, num_outputs=2,
+                              filter=lambda x: max(x.size[1:]) <= 16)
+    oh, vh = cfg.define_split('tile_oh', oh, num_outputs=2,
+                              filter=lambda x: max(x.size[1:]) <= 16)
+    ow, vw = cfg.define_split('tile_ow', ow, num_outputs=2,
+                              filter=lambda x: max(x.size[1:]) <= 16)
+    cfg.define_annotate('ann_reduce', [ib, kb, kh, kw], policy='try_unroll')
+    cfg.define_reorder("reorder_0",
+                       [n, oh, ow, co, vh, vw, kh, kw, kb, ib, vc, ci],
+                       policy='interval_all', interval=(3, 7))
+    # binary ops
+    cfg.add_flop(2 * N * OH * OW * CO * CI * KH * KW * binary_op_multiplier(pack_dtype))
+    # ====================
+
+    VC = cfg["tile_co"].size[-1]
+    VH = cfg["tile_oh"].size[-1]
+    VW = cfg["tile_ow"].size[-1]
+
+    dvshape = (1, PAD_H//(VH*HSTR), PAD_W//(VW*WSTR), VH*HSTR+HCAT, VW*WSTR+WCAT, CI, IB)
+    kvshape = (CO, KH, KW, CI, VC, KB)
+    ovshape = (1, OH, OW, CO, VH, VW, VC)
+    oshape = (1, OH, OW, CO)
+
+    if (DPAD != 0 and RPAD != 0):
+        data_pad = pad(data_q, pad_before, pad_after, name="data_pad")
+    else:
+        data_pad = data_q
+
+    data_vec = tvm.compute(dvshape, lambda n, h, w, vh, vw, ci, b: \
+        data_pad[n][h*VH*HSTR+vh][w*VW*WSTR+vw][ci][b], name='data_vec')
+
+    kernel_vec = tvm.compute(kvshape, lambda co, dh, dw, ci, vc, b: \
+        kernel_q[dh][dw][ci][co*VC+vc][b], name='kernel_vec')
+
+    ci = tvm.reduce_axis((0, CI), name='ci')
+    dh = tvm.reduce_axis((0, KH), name='dh')
+    dw = tvm.reduce_axis((0, KW), name='dw')
+    b1 = tvm.reduce_axis((0, IB), name='ib')
+    b2 = tvm.reduce_axis((0, KB), name='kb')
+
+    def _conv(n, h, w, co, vh, vw, vc):
+        b1b2 = (b1+b2).astype(out_dtype)
+        if unipolar:
+            return tvm.sum(
+                ((tvm.popcount(data_vec[n, h, w, vh*HSTR+dh, vw*WSTR+dw, ci, b1] &
+                               kernel_vec[co, dh, dw, ci, vc, b2]).astype(out_dtype) -
+                  tvm.popcount(data_vec[n, h, w, vh*HSTR+dh, vw*WSTR+dw, ci, b1]&
+                               ~kernel_vec[co, dh, dw, ci, vc, b2]).astype(out_dtype)) << b1b2),
+                axis=[dh, dw, ci, b1, b2])
+
+        return tvm.sum(tvm.popcount(
+            data_vec[n, h, w, vh*HSTR+dh, vw*WSTR+dw, ci, b1] &
+            kernel_vec[co, dh, dw, ci, vc, b2]).astype(out_dtype) << b1b2,
+                       axis=[dh, dw, ci, b1, b2])
+
+    conv = tvm.compute(ovshape, _conv, name='conv')
+
+    idxd = tvm.indexdiv
+    idxm = tvm.indexmod
+    return tvm.compute(
+        oshape, lambda n, h, w, co:
+        conv[n,
+             idxd(h, VH), idxd(w, VW), idxd(co, VC),
+             idxm(h, VH), idxm(w, VW), idxm(co, VC)],
+        name='output_unpack', tag='spatial_bitserial_conv_nhwc')
+
+@autotvm.register_topi_schedule("bitserial_conv2d_nchw.x86")
+def schedule_bitserial_conv2d_nchw(cfg, outs):
+    return _schedule_bitserial_conv2d(cfg, outs)
+
+@autotvm.register_topi_schedule("bitserial_conv2d_nhwc.x86")
+def schedule_bitserial_conv2d_nhwc(cfg, outs):
+    return _schedule_bitserial_conv2d(cfg, outs)
+
+def _schedule_bitserial_conv2d(cfg, outs):
     """CPU schedule for bitserial convolutions NCHW and NHWC"""
     s = tvm.create_schedule([x.op for x in outs])
     scheduled_ops = []
index 47b972f..d464cae 100644 (file)
 from __future__ import absolute_import as _abs
 import tvm
 from tvm import autotvm
-from topi.util import get_const_int
+from topi.util import get_const_int, get_const_tuple
 from .. import tag
-from .. import generic
+from ..nn.bitserial_util import bitpack, binary_op_multiplier
 
-@autotvm.register_topi_schedule(generic.nn.schedule_bitserial_dense, ['cpu'], 'direct')
+@autotvm.register_topi_compute('bitserial_dense.x86')
+def bitserial_dense(cfg, data, weight, data_bits, weight_bits, pack_dtype='uint32',
+                    out_dtype='int16', unipolar=True):
+    """Bitserial dense implementation. TODO: Why are these separate
+
+    Parameters
+    ----------
+    data : tvm.Tensor
+        2-D with shape [batch, in_dim]
+    weight : tvm.Tensor
+        2-D with shape [out_dim, in_dim] or
+        3-D with shape [out_dim, weight_bits, in_dim]
+    Returns
+    -------
+    output : tvm.Tensor
+        2-D with shape [batch, out_dim]
+    """
+    data_packed = bitpack(data, data_bits, pack_axis=1, bit_axis=1, pack_type=pack_dtype)
+    if len(weight.shape) == 2:
+        weight_packed = bitpack(weight, weight_bits, pack_axis=1, bit_axis=1, pack_type=pack_dtype)
+    else:
+        weight_packed = weight
+    Y, DB, K = get_const_tuple(data_packed.shape)
+    X, WB, _ = get_const_tuple(weight_packed.shape)
+    ######## Search space
+    x, y = cfg.axis(X), cfg.axis(Y)
+    db, wb, k = cfg.reduce_axis(DB), cfg.reduce_axis(WB), cfg.reduce_axis(K)
+    ko, ki = cfg.define_split('tile_k', k, num_outputs=2)
+    yo, yi = cfg.define_split('tile_y', y, num_outputs=2)
+    xo, xi = cfg.define_split('tile_x', x, num_outputs=2)
+
+    cfg.define_reorder('reorder_0', [yo, xo, ko, yi, wb, db, ki, xi],
+                       policy='candidate', candidate=[
+                           [yo, xo, ko, yi, wb, db, ki, xi],
+                           [yo, xo, yi, ko, wb, db, ki, xi]])
+
+    cfg.define_annotate('ann_reduce', [db, wb], policy='try_unroll')
+    cfg.define_annotate('ann_spatial', [yi, xi], policy='try_unroll_vec')
+
+    ###### Compute rule
+    VX = cfg['tile_x'].size[-1]
+
+    wvshape = (X//VX, WB, VX, K)
+    oshape = (Y, X)
+
+    k = tvm.reduce_axis((0, K), name='k')
+    db = tvm.reduce_axis((0, DB), name='db')
+    wb = tvm.reduce_axis((0, WB), name='wb')
+
+    # Tile data and weights
+    weight_vec = tvm.compute(wvshape, lambda xo, wb, vx, k:
+                             weight_packed[xo*VX+vx][wb][k], name='weight_vec')
+
+    idxdiv = tvm.indexdiv
+    idxmod = tvm.indexmod
+
+    matmul_unipolar = tvm.compute(oshape, lambda i, j: tvm.sum(
+        (tvm.popcount(weight_vec[idxdiv(j, VX), wb, idxmod(j, VX), k] & data_packed[i, db, k]) -
+         tvm.popcount(~weight_vec[idxdiv(j, VX), wb, idxmod(j, VX), k] & data_packed[i, db, k])
+        ).astype(out_dtype)
+        << (db+wb).astype(out_dtype), axis=[wb, db, k]), tag='bitserial_dense_unipolar')
+
+    matmul = tvm.compute(oshape, lambda i, j: tvm.sum(
+        tvm.popcount(weight_vec[idxdiv(j, VX), wb, idxmod(j, VX), k] & data_packed[i, db, k]
+                    ).astype(out_dtype)
+        << (db+wb).astype(out_dtype), axis=[wb, db, k]), tag='bitserial_dense')
+
+    # binary ops
+    cfg.add_flop(2 * Y * X * K * binary_op_multiplier(pack_dtype))
+
+    if unipolar:
+        return matmul_unipolar
+    return matmul
+
+@autotvm.register_topi_schedule('biserial_dense.x86')
 def schedule_bitserial_dense(cfg, outs):
     """Schedule for bitserial_dense.
 
index 95fd159..70c2a68 100644 (file)
 """Conv1D schedule on for Intel CPU"""
 from __future__ import absolute_import as _abs
 import tvm
-from .. import generic, tag
+from .. import tag
 
 
-@generic.schedule_conv1d_ncw.register(["cpu"])
 def schedule_conv1d_ncw(outs):
     """Create schedule for tensors"""
     s = tvm.create_schedule([x.op for x in outs])
@@ -76,7 +75,6 @@ def schedule_conv1d_ncw(outs):
     return s
 
 
-@generic.schedule_conv1d_nwc.register(["cpu"])
 def schedule_conv1d_nwc(outs):
     """Create schedule for tensors"""
     s = tvm.create_schedule([x.op for x in outs])
index 95ce337..2403b01 100644 (file)
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=invalid-name,unused-variable,unused-argument,no-member,import-outside-toplevel
+# pylint: disable=invalid-name,unused-variable,unused-argument,no-member
+# pylint: disable=no-value-for-parameter,import-outside-toplevel
 """Conv2D schedule on x86"""
 
 import logging
-import re
 
 import tvm
 from tvm import autotvm
-from tvm.autotvm.task.topi_integration import deserialize_args
-from tvm.autotvm.task import get_config
-from .. import generic, tag
 from .. import nn
-from ..nn.conv2d import conv2d, conv2d_NCHWc, \
-    conv2d_infer_layout, _get_workload as _get_conv2d_workload
+from ..nn.conv2d import conv2d_infer_layout, _get_workload as _get_conv2d_workload
+from ..nn.conv2d import unpack_NCHWc_to_nchw
 from ..nn.depthwise_conv2d import _get_workload as _get_depthwise_conv2d_workload
-from ..nn.pad import pad
 from ..nn.util import get_pad_tuple
-from ..util import get_const_tuple
-
+from ..util import get_const_tuple, traverse_inline
 from . import conv2d_avx_1x1, conv2d_avx_common
 
 logger = logging.getLogger('topi')
@@ -61,219 +56,30 @@ def _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depth
         else:
             conv2d_avx_common._fallback_schedule(cfg, wkl)
 
-def _create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout):
-    """Create schedule configuration from input arguments"""
-    dshape = get_const_tuple(data.shape)
-    kshape = get_const_tuple(kernel.shape)
-    pat = re.compile(r'NCHW.+(\d+)c')
-    if layout == 'NCHW':
-        n, ic, h, w = dshape
-        oc, _, kh, kw = kshape
-    elif layout == 'NHWC':
-        n, h, w, ic = dshape
-        kh, kw, oc, _ = kshape
-    elif pat.match(layout) is not None:
-        n, ic_chunk, h, w, ic_bn = dshape
-        target = tvm.target.Target.current(allow_none=False)
-        oc_chunk, k_ic_chunk, kh, kw, k_ic_bn, oc_bn = kshape
-        assert ic_chunk == k_ic_chunk
-        assert ic_bn == k_ic_bn
-        ic = ic_chunk*ic_bn
-        oc = oc_chunk*oc_bn
-    else:
-        raise ValueError("Not support this layout {} with "
-                         "schedule template.".format(layout))
-
-    is_kernel_1x1 = kh == 1 and kw == 1
-    pt, pl, pb, pr = get_pad_tuple(padding, (kh, kw))
-    sh, sw = strides if isinstance(strides, (tuple, list)) else (strides, strides)
-    oh = (h - kh + pt + pb) // sh + 1
-    ow = (w - kw + pl + pr) // sw + 1
-
-    # Create schedule config
-    cfg.define_split("tile_ic", ic, num_outputs=2)
-    cfg.define_split("tile_oc", oc, num_outputs=2)
-    cfg.define_split("tile_ow", ow, num_outputs=2, filter=lambda y: y.size[-1] <= 64)
-    if is_kernel_1x1:
-        cfg.define_knob("tile_oh", [1, 2] if oh > 1 else [1])
-    else:
-        cfg.define_knob("unroll_kw", [True, False])
-
-
-@autotvm.register_topi_compute(conv2d, 'cpu', ['direct'])
-def _declaration_conv(cfg, data, kernel, strides, padding, dilation, layout, out_dtype):
-    out_dtype = data.dtype if out_dtype is None else out_dtype
-    strides = strides if isinstance(strides, (tuple, list)) else (strides, strides)
-    dilation = dilation if isinstance(dilation, (tuple, list)) else (dilation, dilation)
-
-    if layout == 'NCHW':
-        _create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout)
-        if cfg.is_fallback:
-            _get_default_config(cfg, data, kernel, strides, padding, out_dtype)
-        return _declaration_conv_impl(cfg, data, kernel, strides,
-                                      padding, dilation, layout, out_dtype)
-
-    # HWOI kernel layout is for NHWC and HWCN
-    kh, kw, _, _ = get_const_tuple(kernel.shape)
-    if layout == 'HWCN':
-        return nn.conv2d_hwcn(data, kernel, strides, padding, dilation, out_dtype)
-    # FIXME - https://github.com/apache/incubator-tvm/issues/4122
-    # _declaration_conv_nhwc_pack expects kernel layout to be HWOI. However, the tests use HWIO
-    # layout. Commenting until we have clarity about the nhwc_pack implementation from the author.
-    # elif layout == 'NHWC' and kh == 1 and kw == 1 and kernel.dtype == "int8":
-    #     if cfg.is_fallback:
-    #         _get_default_config(cfg, data, kernel, strides, padding, out_dtype, False, layout)
-    #     # specialize for INT8 1X1 conv on X86
-    #     return conv2d_avx_1x1._declaration_conv_nhwc_pack(cfg, data, kernel, strides,
-    #                                                       padding, dilation, out_dtype)
-    if layout == 'NHWC':
-        return nn.conv2d_nhwc(data, kernel, strides, padding, dilation, out_dtype)
-    raise ValueError("not support this layout {} yet".format(layout))
-
-
-def _declaration_conv_impl(cfg, data, kernel, strides, padding, dilation, layout, out_dtype):
-    out_dtype = data.dtype if out_dtype is None else out_dtype
-    assert layout == 'NCHW', "only support NCHW convolution for AVX"
-
-    assert isinstance(dilation, int) or len(dilation) == 2
-    if isinstance(dilation, int):
-        dilation_h, dilation_w = dilation
-    else:
-        dilation_h, dilation_w = dilation
-
-    HSTR, WSTR = strides
-    batch_size, in_channel, in_height, in_width = get_const_tuple(data.shape)
-    num_filter, _, kernel_height, kernel_width = get_const_tuple(kernel.shape)
-
-    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(padding, (kernel_height, kernel_width))
-    pad_h = pad_top + pad_down
-    pad_w = pad_left + pad_right
-
-    pad_height = in_height + pad_h
-    pad_width = in_width + pad_w
-
-    dilated_kernel_h = (kernel_height - 1) * dilation_h + 1
-    dilated_kernel_w = (kernel_width - 1) * dilation_w + 1
-    out_height = (in_height + pad_h - dilated_kernel_h) // HSTR + 1
-    out_width = (in_width + pad_w - dilated_kernel_w) // WSTR + 1
-
-    # pack data
-    DOPAD = (pad_h != 0 or pad_w != 0)
-    if DOPAD:
-        data_pad = pad(data, (0, 0, pad_top, pad_left), (0, 0, pad_down, pad_right), \
-            name="data_pad")
-    else:
-        data_pad = data
-
-    # fetch schedule
-    ic_bn, oc_bn = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
-
-    shape = (batch_size, in_channel // ic_bn, pad_height, ic_bn, pad_width)
-    data_vec = tvm.compute(shape,
-                           lambda n, C, h, c, w: data_pad[n, C * ic_bn + c, h, w],
-                           name='data_vec')
-
-    # pack kernel
-    shape = (num_filter//oc_bn, in_channel//ic_bn,
-             kernel_height, kernel_width, ic_bn, oc_bn)
-    kernel_vec = tvm.compute(shape,
-                             lambda CO, CI, h, w, ci, co:
-                             kernel[CO * oc_bn + co, CI * ic_bn + ci, h, w],
-                             name='kernel_vec')
-
-    # convolution
-    oshape = (batch_size, num_filter//oc_bn, out_height, out_width, oc_bn)
-    unpack_shape = (batch_size, num_filter, out_height, out_width)
-
-    ic = tvm.reduce_axis((0, in_channel), name='ic')
-    kh = tvm.reduce_axis((0, kernel_height), name='kh')
-    kw = tvm.reduce_axis((0, kernel_width), name='kw')
-    idxmod = tvm.indexmod
+@conv2d_infer_layout.register("cpu")
+def _conv2d_infer_layout(workload, cfg):
+    _, data, kernel, strides, padding, dilation, layout, _, dtype = workload
+    batch_size, in_channel, in_height, in_width = data[1]
+    out_channel, _, k_height, k_width = kernel[1]
     idxdiv = tvm.indexdiv
 
-    conv = tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block:
-                       tvm.sum(data_vec[n, idxdiv(ic, ic_bn), oh*HSTR+kh*dilation_h,
-                                        idxmod(ic, ic_bn),
-                                        ow*WSTR+kw*dilation_w].astype(out_dtype) *
-                               kernel_vec[oc_chunk, idxdiv(ic, ic_bn), kh, kw,
-                                          idxmod(ic, ic_bn),
-                                          oc_block].astype(out_dtype),
-                               axis=[ic, kh, kw]), name='conv')
-
-    unpack = tvm.compute(unpack_shape,
-                         lambda n, c, h, w: conv[n, idxdiv(c, oc_bn), h, w, idxmod(c, oc_bn)]
-                         .astype(out_dtype),
-                         name='output_unpack',
-                         tag='conv2d_nchw')
-    return unpack
-
-
-@autotvm.register_topi_schedule(generic.schedule_conv2d_nchw, 'cpu', ['direct'])
-def schedule_conv2d(cfg, outs):
-    """Create schedule for tensors"""
-    s = tvm.create_schedule([x.op for x in outs])
-    scheduled_ops = []
-
-    def traverse(op):
-        """Traverse operators from computation graph"""
-        # inline all one-to-one-mapping operators except the last stage (output)
-        if tag.is_broadcast(op.tag):
-            if op not in s.outputs:
-                s[op].compute_inline()
-            for tensor in op.input_tensors:
-                if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
-                    traverse(tensor.op)
-
-        if 'conv2d_nchw' in op.tag:
-            output = op.output(0)
-            conv_out = op.input_tensors[0]
-            kernel_vec = conv_out.op.input_tensors[1]
-            kernel = kernel_vec.op.input_tensors[0]
-            if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
-                s[kernel].compute_inline()
-            data_vec = conv_out.op.input_tensors[0]
-            data = data_vec.op.input_tensors[0]
-            data_pad = None
-            if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
-                data_pad = data
-                data = data_pad.op.input_tensors[0]
-
-            _, _, kh, kw = get_const_tuple(kernel.shape)
-            is_kernel_1x1 = kh == 1 and kw == 1
-            args = [s, cfg, data, data_pad, data_vec, kernel_vec, conv_out, output, outs[0]]
-            if is_kernel_1x1:
-                conv2d_avx_1x1._schedule_conv(*args)
-            else:
-                conv2d_avx_common._schedule_conv(*args)
-
-        scheduled_ops.append(op)
-
-    traverse(outs[0].op)
-    return s
+    pt, pl, pb, pr = get_pad_tuple(padding, (k_height, k_width))
+    out_height = idxdiv(in_height + pt + pb - k_height, strides[0]) + 1
+    out_width = idxdiv(in_width + pl + pr - k_width, strides[1]) + 1
+    tile_ic, tile_oc = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
+    in_shape = (batch_size, idxdiv(in_channel, tile_ic), in_height, in_width, tile_ic)
+    in_layout = "NCHW%dc" % tile_ic
+    out_shape = (batch_size, idxdiv(out_channel, tile_oc), out_height, out_width, tile_oc)
+    out_layout = "NCHW%dc" % tile_oc
+    return ((in_shape, in_layout),), ((out_shape, out_layout),)
 
-@generic.schedule_conv2d_nhwc.register("cpu")
 def schedule_conv2d_nhwc(outs):
-    """Create schedule for tensors"""
+    """Create schedule for conv2d_nhwc"""
+    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
     s = tvm.create_schedule([x.op for x in outs])
     output_op = outs[0].op
-    scheduled_ops = []
-
-    def traverse(op):
-        """Traverse operators from computation graph"""
-        # inline all one-to-one-mapping operators except the last stage (output)
-        if tag.is_broadcast(op.tag):
-            if op not in s.outputs:
-                s[op].compute_inline()
-            else: # inject custom schedule
-                if len(op.axis) == 4: # schedule bias + bn + relu
-                    n, h, w, c = op.axis
-                    fused = s[op].fuse(n, h, w)
-                    s[op].parallel(fused)
-                    s[op].vectorize(c)
-            for tensor in op.input_tensors:
-                if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
-                    traverse(tensor.op)
 
+    def _callback(op):
         if 'conv2d_nhwc' in op.tag:
             conv = op.output(0)
             kernel = op.input_tensors[1]
@@ -291,146 +97,133 @@ def schedule_conv2d_nhwc(outs):
             s[data_pad].parallel(pad_fused)
             C = conv
             n, h, w, c = C.op.axis
-            ry, rx, rc = C.op.reduce_axis
-            n_out, h_out, w_out, c_out = output_op.axis
             s[C].vectorize(c)
-            if op != output_op: # fuse bias + bn + relu into conv
-                s[C].compute_at(s[output_op], c_out)
-            else:
-                fused = s[C].fuse(n, h, w)
-                s[C].parallel(fused)
-
-        scheduled_ops.append(op)
 
-    traverse(output_op)
+            O = output_op.output(0)
+            if len(O.op.axis) == 4: # schedule bias + bn + relu
+                n, h, w, c = O.op.axis
+                fused = s[O].fuse(n, h, w)
+                s[O].parallel(fused)
+                channels = int(O.shape[-1])
+                if channels % 64 == 0:
+                    c, ci = s[O].split(c, 64)
+                    s[O].vectorize(ci)
+                if C != O:
+                    s[C].compute_at(s[O], c)
+
+    traverse_inline(s, output_op, _callback)
     return s
 
+def conv2d_nchw(data, kernel, strides, padding, dilation, out_dtype):
+    layout = "NCHW"
+    packed_out = conv2d_NCHWc(data, kernel, strides, padding, dilation,
+                              layout, layout, out_dtype)
+    return unpack_NCHWc_to_nchw(packed_out, out_dtype)
 
-# Define template function for autotvm task
-# We define schedule template in this function instead of
-# declaration function since actual input arguments need
-# to be altered by the schedule selected.
-@autotvm.task.register("topi_x86_conv2d_NCHWc")
-def _topi_nn_conv2d_NCHWc(*args, **kwargs):
-    assert not kwargs, "Do not support kwargs in template function call"
-    args = deserialize_args(args)
-
-    if len(args) == 7:
-        data, kernel, strides, padding, dilation, origin_layout, dtype = args
-    else:
-        assert len(args) == 8
-        data, kernel, strides, padding, dilation, origin_layout, out_layout, dtype = args
-
-    raw_data_shape = get_const_tuple(data.shape)
-    raw_kernel_shape = get_const_tuple(kernel.shape)
-
-    # get config here
-    cfg = get_config()
-    _create_tuning_space(cfg, data, kernel, strides, padding, dilation, origin_layout)
-
-    idxdiv = tvm.indexdiv
-    idxmod = tvm.indexmod
-    # change shape with the value in config
-    ic_bn, oc_bn, ow_bn = (cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1],
-                           cfg["tile_ow"].size[-1])
-    new_data_shape = (raw_data_shape[0], idxdiv(raw_data_shape[1], ic_bn),
-                      raw_data_shape[2], raw_data_shape[3], ic_bn)
-    data_layout = "NCHW%dc" % ic_bn
-    out_layout = "NCHW%dc" % oc_bn
-    new_kernel_shape = (idxdiv(raw_kernel_shape[0], oc_bn),
-                        idxdiv(raw_kernel_shape[1], ic_bn),
-                        raw_kernel_shape[2], raw_kernel_shape[3], ic_bn, oc_bn)
-    new_data = tvm.placeholder(new_data_shape, data.dtype)
-    new_kernel = tvm.placeholder(new_kernel_shape, kernel.dtype)
+def schedule_conv2d_nchw(outs):
+    """Create schedule for tensors"""
+    return schedule_conv2d_NCHWc(outs)
 
-    C = _declaration_conv_NCHWc(cfg, new_data, new_kernel, strides, padding, dilation,
-                                data_layout, out_layout, dtype)
-    s = _schedule_conv2d_NCHWc(cfg, [C])
-    return s, [new_data, new_kernel, C]
+def _pack_data(cfg, data, kernel):
+    n, _, ih, iw = get_const_tuple(data.shape)
+    oc, ic, kh, kw = get_const_tuple(kernel.shape)
+    ic_bn, oc_bn = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
 
+    ic_chunk = ic // ic_bn
+    oc_chunk = oc // oc_bn
 
-@conv2d_infer_layout.register("cpu")
-def _conv2d_infer_layout(workload, cfg):
-    _, data, kernel, strides, padding, dilation, layout, dtype = workload
-    batch_size, in_channel, in_height, in_width = data[:-1]
-    out_channel, _, k_height, k_width = kernel[:-1]
-    idxdiv = tvm.indexdiv
+    data = tvm.compute((n, ic_chunk, ih, iw, ic_bn),
+                       lambda bs, c, h, w, vc: data[bs, c*ic_bn + vc, h, w],
+                       name="data_vec")
 
-    pt, pl, pb, pr = get_pad_tuple(padding, (k_height, k_width))
-    out_height = idxdiv(in_height + pt + pb - k_height, strides[0]) + 1
-    out_width = idxdiv(in_width + pl + pr - k_width, strides[1]) + 1
-    tile_ic, tile_oc = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
-    in_shape = (batch_size, idxdiv(in_channel, tile_ic), in_height, in_width, tile_ic)
-    in_layout = "NCHW%dc" % tile_ic
-    out_shape = (batch_size, idxdiv(out_channel, tile_oc), out_height, out_width, tile_oc)
-    out_layout = "NCHW%dc" % tile_oc
-    return ((in_shape, in_layout),), ((out_shape, out_layout),)
+    kernel = tvm.compute(
+        (oc_chunk, ic_chunk, kh, kw, ic_bn, oc_bn),
+        lambda occ, icc, k_h, k_w, icb, ocb:
+        kernel[occ * oc_bn + ocb, icc * ic_bn + icb, k_h, k_w],
+        name="kernel_vec")
 
+    return data, kernel
 
-@autotvm.register_topi_compute(conv2d_NCHWc, 'cpu', 'direct')
-def _declaration_conv_NCHWc(cfg, data, kernel, strides,
-                            padding, dilation, layout, out_layout, out_dtype):
+@autotvm.register_topi_compute("conv2d_NCHWc.x86")
+def conv2d_NCHWc(cfg, data, kernel, strides, padding, dilation, layout, out_layout, out_dtype):
+    """Compute conv2d with NCHWc layout."""
     # layout and out_layout are not used here,
     # we keep them for debug convenience when dumping autotvm workload
-    n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape)
-    in_channel = ic_chunk * ic_bn
-    oc_chunk, ic_chunk_group, kernel_height, kernel_width, _, oc_bn = \
+    if len(data.shape) == 5:
+        n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape)
+        oc_chunk, ic_chunk_group, kernel_height, kernel_width, _, oc_bn = \
             get_const_tuple(kernel.shape)
-    num_filter = oc_chunk * oc_bn
+        in_channel = ic_chunk * ic_bn
+        num_filter = oc_chunk * oc_bn
+    else:
+        n, in_channel, ih, iw = get_const_tuple(data.shape)
+        num_filter, _, kernel_height, kernel_width = get_const_tuple(kernel.shape)
+
+    # Define autotvm tuning space
+    is_kernel_1x1 = kernel_height == 1 and kernel_width == 1
+    pt, pl, pb, pr = get_pad_tuple(padding, (kernel_height, kernel_width))
+    sh, sw = strides if isinstance(strides, (tuple, list)) else (strides, strides)
+    oh = (ih - kernel_height + pt + pb) // sh + 1
+    ow = (iw - kernel_width + pl + pr) // sw + 1
+
+    cfg.define_split("tile_ic", in_channel, num_outputs=2)
+    cfg.define_split("tile_oc", num_filter, num_outputs=2)
+    cfg.define_split("tile_ow", ow, num_outputs=2, filter=lambda y: y.size[-1] <= 64)
+    if is_kernel_1x1:
+        cfg.define_knob("tile_oh", [1, 2] if oh > 1 else [1])
+    else:
+        cfg.define_knob("unroll_kw", [True, False])
 
-    # If no config was set, we can fallback to NCHW config.
+    # If no config was set, we can fallback to default config.
     if cfg.is_fallback:
         _get_default_config(cfg, tvm.placeholder((n, in_channel, ih, iw), dtype=data.dtype),
                             tvm.placeholder((num_filter, in_channel, kernel_height, kernel_width),
                                             dtype=kernel.dtype),
                             strides, padding, out_dtype)
 
-    return nn.conv2d_NCHWc_compute(data,
-                                   kernel,
-                                   strides,
-                                   padding,
-                                   dilation,
-                                   layout,
-                                   out_layout,
-                                   out_dtype)
-
-
-@autotvm.register_topi_schedule(generic.schedule_conv2d_NCHWc, 'cpu', ['direct'])
-def _schedule_conv2d_NCHWc(cfg, outs):
+    # Pack data if raw 4-D data is provided.
+    # This can only happen when autotuning.
+    if len(data.shape) == 4:
+        data, kernel = _pack_data(cfg, data, kernel)
+
+    return nn.conv2d_NCHWc(data,
+                           kernel,
+                           strides,
+                           padding,
+                           dilation,
+                           layout,
+                           out_layout,
+                           out_dtype)
+
+@autotvm.register_topi_schedule("conv2d_NCHWc.x86")
+def schedule_conv2d_NCHWc(cfg, outs):
     """Create schedule for tensors"""
+    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
     s = tvm.create_schedule([x.op for x in outs])
-    scheduled_ops = []
-
-    def traverse(op):
-        """Traverse operators from computation graph"""
-        # inline all one-to-one-mapping operators except the last stage (output)
-        if tag.is_broadcast(op.tag):
-            if op not in s.outputs:
-                s[op].compute_inline()
-            for tensor in op.input_tensors:
-                if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
-                    traverse(tensor.op)
 
+    def _callback(op):
         if 'conv2d_NCHWc' in op.tag:
             conv_out = op.output(0)
-            kernel = conv_out.op.input_tensors[1]
+            kernel_vec = conv_out.op.input_tensors[1]
             data_vec = conv_out.op.input_tensors[0]
-            data = data_vec.op.input_tensors[0] \
-                if isinstance(data_vec.op, tvm.tensor.ComputeOp) and "pad" not in data_vec.op.tag \
-                else data_vec
-            if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
-                data_pad = data
-                data = data_pad.op.input_tensors[0]
 
-            args = [s, cfg, data_vec, conv_out, outs[0]]
-            target = tvm.target.Target.current(allow_none=False)
-            _, _, kh, kw, _, _, = get_const_tuple(kernel.shape)
+            args = [s, cfg, data_vec, kernel_vec, conv_out, outs[0]]
+            _, _, kh, kw, _, _, = get_const_tuple(kernel_vec.shape)
             if kh == 1 and kw == 1:
                 conv2d_avx_1x1._schedule_conv_NCHWc(*args)
             else:
                 conv2d_avx_common._schedule_conv_NCHWc(*args)
 
-        scheduled_ops.append(op)
-
-    traverse(outs[0].op)
+    traverse_inline(s, outs[0].op, _callback)
     return s
+
+
+# FIXME - https://github.com/apache/incubator-tvm/issues/4122
+# _declaration_conv_nhwc_pack expects kernel layout to be HWOI. However, the tests use HWIO
+# layout. Commenting until we have clarity about the nhwc_pack implementation from the author.
+# elif layout == 'NHWC' and kh == 1 and kw == 1 and kernel.dtype == "int8":
+#     if cfg.is_fallback:
+#         _get_default_config(cfg, data, kernel, strides, padding, out_dtype, False, layout)
+#     # specialize for INT8 1X1 conv on X86
+#     return conv2d_avx_1x1._declaration_conv_nhwc_pack(cfg, data, kernel, strides,
+#                                                       padding, dilation, out_dtype)
index 8b0c13c..377d815 100644 (file)
@@ -23,117 +23,103 @@ import tvm
 from tvm import relay
 from tvm import autotvm
 from .conv2d import _get_default_config
-from .conv2d_int8 import _is_int8_hw_support, _get_default_config_int8
-from ..util import get_const_tuple, get_shape
-from ..nn import conv2d_legalize
-from ..nn.conv2d import conv2d, conv2d_NCHWc, conv2d_NCHWc_int8, conv2d_alter_layout
-from ..nn.depthwise_conv2d import depthwise_conv2d_NCHWc, depthwise_conv2d_nchw
+from .conv2d_int8 import is_int8_hw_support, _get_default_config_int8
+from ..util import get_const_tuple
+from ..nn import conv2d_legalize, conv2d_alter_layout
 from ..nn.util import get_pad_tuple
 
 logger = logging.getLogger('topi')
 
 @conv2d_alter_layout.register("cpu")
-def _alter_conv2d_layout(attrs, inputs, tinfo, F):
+def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
+    target = tvm.target.Target.current(allow_none=False)
+    dispatch_ctx = autotvm.task.DispatchContext.current
+    if isinstance(dispatch_ctx, autotvm.task.ApplyGraphBest):
+        cfg = dispatch_ctx.query(target, None)
+        workload = cfg.workload
+    else:
+        _, outs = relay.backend.compile_engine.select_implementation(
+            relay.op.get("nn.conv2d"), attrs, tinfos, out_type, target)
+        workload = autotvm.task.get_workload(outs)
+        if workload is None:
+            # The best implementation is not an AutoTVM template,
+            # we then assume it's not necessary to alter this op.
+            return None
+        cfg = dispatch_ctx.query(target, workload)
+
+    topi_tmpl = workload[0]
+    new_attrs = {k : attrs[k] for k in attrs.keys()}
+
     # Parse the attributes.
-    groups = attrs.get_int("groups")
     padding = attrs.get_int_tuple("padding")
     strides = attrs.get_int_tuple("strides")
     dilation = attrs.get_int_tuple("dilation")
-    out_dtype = attrs["out_dtype"]
-    layout_name = 'data_layout'
-    data_layout = attrs[layout_name]
-    kh, kw = attrs.get_int_tuple("kernel_size")
-
-    data_tensor, kernel_tensor = tinfo[0], tinfo[1]
-    if attrs[layout_name] == 'NHWC' and attrs['kernel_layout'] == 'HWIO':
-        batch_size, height, width, in_channel = get_const_tuple(data_tensor.shape)
-        kh, kw, _, out_channel = get_const_tuple(kernel_tensor.shape)
-    elif attrs[layout_name] == 'NCHW' and attrs['kernel_layout'] == 'OIHW':
-        batch_size, in_channel, height, width = get_const_tuple(data_tensor.shape)
-        out_channel, _, kh, kw = get_const_tuple(kernel_tensor.shape)
-    else:
-        return None
-
+    data_layout = attrs["data_layout"]
+    kernel_layout = attrs["kernel_layout"]
+    data_tensor, kernel_tensor = tinfos
     data_dtype = data_tensor.dtype
     kernel_dtype = kernel_tensor.dtype
-    out_dtype = data_dtype if out_dtype in ("same", "") else out_dtype
-
-    # Check if depthwise.
-    kshape = get_shape(kernel_tensor.shape, attrs["kernel_layout"], "OIHW")
-    is_depthwise = groups == kshape[0] and kshape[1] == 1
-
-    # Save the input exprs.
-    copy_inputs = list(inputs)
-
-    # Set the new attrs
-    new_attrs = {k : attrs[k] for k in attrs.keys()}
-    new_attrs['channels'] = out_channel
-
-    # Return if the groups is not 1 and depthwise.
-    if groups != 1 and not is_depthwise:
-        return None
-
-    # Set workload. Config update.
-    dispatch_ctx = autotvm.task.DispatchContext.current
-    target = tvm.target.Target.current()
-
-    if is_depthwise:
-        workload = autotvm.task.args_to_workload(
-            [data_tensor, kernel_tensor, strides, padding, dilation, out_dtype],
-            depthwise_conv2d_nchw)
-    else:
-        workload = autotvm.task.args_to_workload(
-            [data_tensor, kernel_tensor, strides, padding, dilation, data_layout, out_dtype],
-            conv2d)
-
-    cfg = dispatch_ctx.query(target, workload)
-    if cfg.is_fallback:
-        if _is_int8_hw_support(data_dtype, kernel_dtype):
-            _get_default_config_int8(cfg, data_tensor, kernel_tensor, strides, padding, out_dtype,
-                                     is_depthwise, data_layout)
-        else:
-            _get_default_config(cfg, data_tensor, kernel_tensor, strides, padding, out_dtype,
-                                is_depthwise, data_layout)
+    out_dtype = out_type.dtype
+
+    if topi_tmpl == "conv2d_NCHWc.x86":
+        # we only convert conv2d_NCHW to conv2d_NCHWc for x86
+        assert data_layout == "NCHW" and kernel_layout == "OIHW"
+        if cfg.is_fallback:
+            _get_default_config(cfg, data_tensor, kernel_tensor, strides, padding,
+                                out_dtype, False, data_layout)
+        batch_size, in_channel, height, width = get_const_tuple(data_tensor.shape)
+        out_channel, _, kh, kw = get_const_tuple(kernel_tensor.shape)
+        ic_bn, oc_bn = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
 
-    # Get the tiling parameters to set the layout names.
-    ic_bn, oc_bn = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
-    new_attrs[layout_name] = 'NCHW%dc' % ic_bn
-    new_attrs['out_layout'] = 'NCHW%dc' % oc_bn
-    new_data = tvm.placeholder((batch_size, in_channel//ic_bn, height, width, ic_bn),
-                               dtype=data_dtype)
+        # update new attrs
+        new_attrs['channels'] = out_channel
+        new_attrs['data_layout'] = 'NCHW%dc' % ic_bn
+        # (oc, ic, h, w) -> (OC, IC, h, w, ic, oc)
+        new_attrs['kernel_layout'] = 'OIHW%di%do' % (ic_bn, oc_bn)
+        new_attrs['out_layout'] = 'NCHW%dc' % oc_bn
 
-    if is_depthwise and data_layout == 'NCHW' and attrs['kernel_layout'] == 'OIHW':
-        new_attrs['kernel_layout'] = 'OIHW1i%do' % oc_bn
         # Store altered operator's config
-        new_kernel = tvm.placeholder((out_channel//oc_bn, 1, kh, kw, 1, oc_bn), dtype=kernel_dtype)
+        new_data = tvm.placeholder((batch_size, in_channel//ic_bn, height, width, ic_bn),
+                                   dtype=data_dtype)
+        new_kernel = tvm.placeholder((out_channel//oc_bn, in_channel//ic_bn,
+                                      kh, kw, ic_bn, oc_bn), dtype=kernel_tensor.dtype)
         new_workload = autotvm.task.args_to_workload(
-            [new_data, new_kernel, strides, padding, dilation, new_attrs[layout_name],
-             new_attrs['out_layout'], out_dtype], depthwise_conv2d_NCHWc)
+            [new_data, new_kernel, strides, padding, dilation, new_attrs["data_layout"],
+             new_attrs["out_layout"], out_dtype], topi_tmpl)
         dispatch_ctx.update(target, new_workload, cfg)
+        return relay.nn.contrib_conv2d_nchwc(*inputs, **new_attrs)
 
-        return F.nn.contrib_depthwise_conv2d_nchwc(*copy_inputs, **new_attrs)
+    if topi_tmpl == "conv2d_NCHWc_int8.x86":
+        # TODO(@icemelon9, @anijain2305): Need to support data layout NHWC with kernel layout HWIO
+        assert data_layout == "NCHW" and kernel_layout == "OIHW"
+        if cfg.is_fallback:
+            _get_default_config_int8(cfg, data_tensor, kernel_tensor, strides, padding,
+                                     out_dtype, False, data_layout)
 
-    if _is_int8_hw_support(data_dtype, kernel_dtype):
-        # Convert kernel data layout from 4D to 7D
+        batch_size, in_channel, height, width = get_const_tuple(data_tensor.shape)
+        out_channel, channel_multiplier, kh, kw = get_const_tuple(kernel_tensor.shape)
+        ic_bn, oc_bn = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
         n_elems = 4
-        data_expr, kernel_expr = inputs
-        if attrs['kernel_layout'] == 'HWIO':
-            kernel_IHWO = F.transpose(kernel_expr, axes=(2, 0, 1, 3))
-        elif attrs['kernel_layout'] == 'OIHW':
-            kernel_IHWO = F.transpose(kernel_expr, axes=(1, 2, 3, 0))
-        else:
-            return None
-
-        kernel_IHWOo = F.reshape(kernel_IHWO, (in_channel, kh, kw, out_channel//oc_bn, oc_bn))
-        kernel_OHWoI = F.transpose(kernel_IHWOo, axes=(3, 1, 2, 4, 0))
-        kernel_OHWoIi = F.reshape(kernel_OHWoI, (out_channel//oc_bn, kh, kw, oc_bn,
-                                                 in_channel//ic_bn, ic_bn))
-        kernel_OHWoIie = F.reshape(kernel_OHWoIi, (out_channel//oc_bn, kh, kw, oc_bn,
-                                                   in_channel//ic_bn, ic_bn//n_elems, n_elems))
-        kernel_OIHWioe = F.transpose(kernel_OHWoIie, axes=(0, 4, 1, 2, 5, 3, 6))
-        copy_inputs = [data_expr, kernel_OIHWioe]
 
-        # Store altered operator's config. New kernel layout OIHWio4
+        # convert kernel data layout from 4D to 7D
+        data_expr, kernel_expr = inputs
+        kernel_IHWO = relay.transpose(kernel_expr, axes=(1, 2, 3, 0))
+        kernel_IHWOo = relay.reshape(kernel_IHWO, (in_channel, kh, kw, out_channel//oc_bn, oc_bn))
+        kernel_OHWoI = relay.transpose(kernel_IHWOo, axes=(3, 1, 2, 4, 0))
+        kernel_OHWoIi = relay.reshape(kernel_OHWoI, (out_channel//oc_bn, kh, kw, oc_bn,
+                                                     in_channel//ic_bn, ic_bn))
+        kernel_OHWoIie = relay.reshape(kernel_OHWoIi, (out_channel//oc_bn, kh, kw, oc_bn,
+                                                       in_channel//ic_bn, ic_bn//n_elems, n_elems))
+        kernel_OIHWioe = relay.transpose(kernel_OHWoIie, axes=(0, 4, 1, 2, 5, 3, 6))
+
+        # update new attrs
+        new_attrs['channels'] = out_channel
+        new_attrs['data_layout'] = 'NCHW%dc' % ic_bn
+        new_attrs['out_layout'] = 'NCHW%dc' % oc_bn
+
+        # Store altered operator's config.
+        new_data = tvm.placeholder((batch_size, in_channel//ic_bn, height, width, ic_bn),
+                                   dtype=data_dtype)
         new_kernel = tvm.placeholder((out_channel // oc_bn,
                                       in_channel // ic_bn,
                                       kh,
@@ -141,30 +127,41 @@ def _alter_conv2d_layout(attrs, inputs, tinfo, F):
                                       ic_bn // n_elems,
                                       oc_bn,
                                       n_elems), dtype=kernel_dtype)
-
-        new_workload = autotvm.task.args_to_workload([new_data,
-                                                      new_kernel,
-                                                      strides,
-                                                      padding,
-                                                      dilation,
-                                                      new_attrs[layout_name],
-                                                      new_attrs['out_layout'],
-                                                      out_dtype],
-                                                     conv2d_NCHWc_int8)
+        new_workload = autotvm.task.args_to_workload(
+            [new_data, new_kernel, strides, padding, dilation, new_attrs['data_layout'],
+             new_attrs['out_layout'], out_dtype], topi_tmpl)
         dispatch_ctx.update(target, new_workload, cfg)
-        return F.nn.contrib_conv2d_nchwc_int8(*copy_inputs, **new_attrs)
 
-    # (oc, ic, h, w) -> (OC, IC, h, w, ic, oc)
-    new_attrs['kernel_layout'] = 'OIHW%di%do' % (ic_bn, oc_bn)
-    # Store altered operator's config
-    new_kernel = tvm.placeholder((out_channel//oc_bn, in_channel//ic_bn,
-                                  kh, kw, ic_bn, oc_bn), dtype=kernel_tensor.dtype)
-    new_workload = autotvm.task.args_to_workload(
-        [new_data, new_kernel, strides, padding, dilation, new_attrs[layout_name],
-         new_attrs['out_layout'], out_dtype], conv2d_NCHWc)
-    dispatch_ctx.update(target, new_workload, cfg)
+        return relay.nn.contrib_conv2d_nchwc(data_expr, kernel_OIHWioe, **new_attrs)
+
+    if topi_tmpl == "depthwise_conv2d_NCHWc.x86":
+        assert data_layout == "NCHW" and kernel_layout == "OIHW"
+        if cfg.is_fallback:
+            _get_default_config(cfg, data_tensor, kernel_tensor, strides, padding,
+                                out_dtype, True, data_layout)
+
+        batch_size, in_channel, height, width = get_const_tuple(data_tensor.shape)
+        out_channel, channel_multiplier, kh, kw = get_const_tuple(kernel_tensor.shape)
+        ic_bn, oc_bn = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
+        assert channel_multiplier == 1
+
+        # update new attrs
+        new_attrs['channels'] = out_channel
+        new_attrs['data_layout'] = 'NCHW%dc' % ic_bn
+        new_attrs['kernel_layout'] = 'OIHW1i%do' % oc_bn
+        new_attrs['out_layout'] = 'NCHW%dc' % oc_bn
 
-    return F.nn.contrib_conv2d_nchwc(*copy_inputs, **new_attrs)
+        # Store altered operator's config.
+        new_data = tvm.placeholder((batch_size, in_channel//ic_bn, height, width, ic_bn),
+                                   dtype=data_dtype)
+        new_kernel = tvm.placeholder((out_channel//oc_bn, 1, kh, kw, 1, oc_bn), dtype=kernel_dtype)
+        new_workload = autotvm.task.args_to_workload(
+            [new_data, new_kernel, strides, padding, dilation, new_attrs['data_layout'],
+             new_attrs['out_layout'], out_dtype], topi_tmpl)
+        dispatch_ctx.update(target, new_workload, cfg)
+        return relay.nn.contrib_depthwise_conv2d_nchwc(*inputs, **new_attrs)
+
+    return None
 
 
 @conv2d_legalize.register("cpu")
@@ -254,7 +251,7 @@ def _conv2d_legalize(attrs, inputs, arg_types):
     # input channel to be a multiple of 4 and output channels to be a multiple of 16. For input
     # channels, we pad both the inputs and weights input channels. For output channels, we pad the
     # weight and stride_slice the output.
-    if _is_int8_hw_support(data_dtype, kernel_dtype):
+    if is_int8_hw_support(data_dtype, kernel_dtype):
         # Flags to remember if the expr is modified
         ic_modified = False
         oc_modified = False
index 9726f3d..083fff4 100644 (file)
 """1x1 Conv2D schedule on for Intel CPU"""
 from __future__ import absolute_import as _abs
 import tvm
+from tvm import autotvm
 from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity
 
 from ..nn.pad import pad
-from ..nn.util import infer_pad, get_pad_tuple
+from ..nn.util import get_pad_tuple
 from ..generic import conv2d as conv2d_generic
 from ..util import get_const_tuple, simplify
 from .tensor_intrin import dot_16x1x16_uint8_int8_int32
@@ -58,84 +59,41 @@ def _fallback_schedule(cfg, wkl):
     raise ValueError("cannot decide default schedule for workload: {}".format(wkl))
 
 
-def _schedule_conv(s, cfg, data, data_pad, data_vec, kernel_vec, conv_out, output, last):
-    # fetch schedule
-    ic_bn, oc_bn, oh_factor, ow_factor = (cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1],
-                                          cfg["tile_oh"].val, cfg["tile_ow"].size[-1])
-
-    # no stride and padding info here
-    padding = infer_pad(data, data_pad)
-    HPAD, WPAD = padding
-    DOPAD = (HPAD != 0 or WPAD != 0)
-
-    A, W = data, kernel_vec
-    A0, A1 = data_pad, data_vec
-    # schedule data
-    if DOPAD:
-        s[A0].compute_inline()
-    batch, ic_chunk, ih, ic_block, iw = s[A1].op.axis
-    parallel_axis = s[A1].fuse(batch, ic_chunk, ih)
-    s[A1].parallel(parallel_axis)
-
-    # schedule kernel pack
-    oc_chunk, ic_chunk, oh, ow, ic_block, oc_block = s[W].op.axis
-    s[W].reorder(oc_chunk, oh, ic_chunk, ow, ic_block, oc_block)
-    if oc_bn > 1:
-        s[W].vectorize(oc_block)
-    parallel_axis = s[W].fuse(oc_chunk, oh)
-    s[W].parallel(parallel_axis)
-
-    C, O0, O = conv_out, output, last
-    CC = s.cache_write(C, 'global')
-
-    batch, oc_chunk, oh, ow, oc_block = s[C].op.axis
-    oh_outer, oh_inner = s[C].split(oh, factor=oh_factor)
-    s[C].vectorize(oc_block)
-
-    s[CC].compute_at(s[C], oh_outer)
-    _, oc_chunk, oh, ow, oc_block = s[CC].op.axis
-    ic, _, _ = s[CC].op.reduce_axis
-
-    ic_chunk, ic_block = s[CC].split(ic, factor=ic_bn)
-
-    oh_outer, oh_inner = s[CC].split(oh, factor=oh_factor)
-    ow_outer, ow_inner = s[CC].split(ow, factor=ow_factor)
-
-    s[CC].reorder(oc_chunk, oh_outer, ow_outer, ic_chunk, ic_block, oh_inner, ow_inner, oc_block)
-    s[CC].vectorize(oc_block)
-
-    s[CC].unroll(ow_inner)
-    s[CC].unroll(oh_inner)
-
-    if O0 != O:
-        s[O0].compute_inline()
-    batch, oc, oh, ow = s[O].op.axis
-
-    oc_chunk, oc_block = s[O].split(oc, factor=oc_bn)
-    oh_outer, oh_inner = s[O].split(oh, factor=oh_factor)
-    ow_outer, ow_inner = s[O].split(ow, factor=ow_factor)
-    s[O].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block)
-
-    parallel_axis = s[O].fuse(batch, oc_chunk, oh_outer)
-    s[C].compute_at(s[O], parallel_axis)
-    s[O].vectorize(oc_block)
-
-    s[O].parallel(parallel_axis)
-
-    return s
-
-
-def _schedule_conv_NCHWc(s, cfg, data, conv_out, last):
+def _schedule_conv_NCHWc(s, cfg, data_vec, kernel_vec, conv_out, last):
     # fetch schedule
     oh_factor, ow_factor = cfg["tile_oh"].val, cfg["tile_ow"].size[-1]
-    _, _, _, _, ic_bn = get_const_tuple(data.shape)
-
-    # schedule data
-    A = data
-    if isinstance(s[A].op, tvm.tensor.ComputeOp):
-        batch, ic_chunk, ih, iw, ic_block = s[A].op.axis
-        parallel_axis = s[A].fuse(batch, ic_chunk, ih)
-        s[A].parallel(parallel_axis)
+    _, _, _, _, ic_bn = get_const_tuple(data_vec.shape)
+
+    # schedule pad
+    if isinstance(s[data_vec].op, tvm.tensor.ComputeOp) \
+            and "pad" in data_vec.op.tag:
+        batch, ic_chunk, ih, iw, ic_block = s[data_vec].op.axis
+        parallel_axis = s[data_vec].fuse(batch, ic_chunk, ih)
+        s[data_vec].parallel(parallel_axis)
+        data_vec = data_vec.op.input_tensors[0]
+
+    if autotvm.GLOBAL_SCOPE.in_tuning:
+        # only in autotuning, input data of conv2d_NCHWc will be 4-D.
+        # skip this part during tuning to make records accurate.
+        # this part will be folded during Relay fold_constant pass.
+        s[data_vec].pragma(s[data_vec].op.axis[0], "debug_skip_region")
+        s[kernel_vec].pragma(s[kernel_vec].op.axis[0], "debug_skip_region")
+    elif isinstance(kernel_vec.op, tvm.tensor.ComputeOp) and \
+            kernel_vec.name == 'kernel_vec':
+        # data and kernel are not pre-computed, schedule layout transform here.
+        # this should only be used by x86 conv2d_nchw, which is for
+        # testing purpose.
+        batch, ic_chunk, ih, ic_block, iw = s[data_vec].op.axis
+        parallel_axis = s[data_vec].fuse(batch, ic_chunk, ih)
+        s[data_vec].parallel(parallel_axis)
+
+        oc_chunk, ic_chunk, oh, ow, ic_block, oc_block = s[kernel_vec].op.axis
+        s[kernel_vec].reorder(oc_chunk, oh, ic_chunk, ow, ic_block, oc_block)
+        oc_bn = cfg["tile_oc"].size[-1]
+        if oc_bn > 1:
+            s[kernel_vec].vectorize(oc_block)
+        parallel_axis = s[kernel_vec].fuse(oc_chunk, oh)
+        s[kernel_vec].parallel(parallel_axis)
 
     C, O = conv_out, last
     CC = s.cache_write(C, 'global')
@@ -167,22 +125,36 @@ def _schedule_conv_NCHWc(s, cfg, data, conv_out, last):
     s[CC].unroll(oh_inner)
 
     if C != O:
-        batch, oc_chunk, oh, ow, oc_block = s[O].op.axis
-        oh_outer, oh_inner = s[O].split(oh, factor=oh_factor)
-        ow_outer, ow_inner = s[O].split(ow, factor=ow_factor)
-        s[O].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block)
-
-        parallel_axis = s[O].fuse(batch, oc_chunk, oh_outer)
-        s[C].compute_at(s[O], parallel_axis)
-        s[O].vectorize(oc_block)
-        s[O].parallel(parallel_axis)
+        out_ndim = len(s[O].op.axis)
+        if out_ndim == 5:
+            batch, oc_chunk, oh, ow, oc_block = s[O].op.axis
+            oh_outer, oh_inner = s[O].split(oh, factor=oh_factor)
+            ow_outer, ow_inner = s[O].split(ow, factor=ow_factor)
+            s[O].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block)
+
+            parallel_axis = s[O].fuse(batch, oc_chunk, oh_outer)
+            s[C].compute_at(s[O], parallel_axis)
+            s[O].vectorize(oc_block)
+            s[O].parallel(parallel_axis)
+        elif out_ndim == 4:
+            batch, oc, oh, ow = s[O].op.axis
+            oc_chunk, oc_block = s[O].split(oc, factor=oc_bn)
+            oh_outer, oh_inner = s[O].split(oh, factor=oh_factor)
+            ow_outer, ow_inner = s[O].split(ow, factor=ow_factor)
+            s[O].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block)
+            parallel_axis = s[O].fuse(batch, oc_chunk, oh_outer)
+            s[C].compute_at(s[O], parallel_axis)
+            s[O].vectorize(oc_block)
+            s[O].parallel(parallel_axis)
+        else:
+            raise ValueError("Unsupported output ndim: %s" % out_ndim)
 
     return s
 
 
-def _schedule_conv_NCHWc_int8(s, cfg, data, conv_out, last):
-    return conv2d_generic.schedule_conv_NCHWc_cpu_1x1_int8(s, cfg, data, conv_out, last,
-                                                           int32_lanes=16,
+def _schedule_conv_NCHWc_int8(s, cfg, data_vec, kernel_vec, conv_out, last):
+    return conv2d_generic.schedule_conv_NCHWc_cpu_1x1_int8(s, cfg, data_vec, kernel_vec,
+                                                           conv_out, last, int32_lanes=16,
                                                            intrin=dot_16x1x16_uint8_int8_int32())
 
 
index 7c5096d..085d0ae 100644 (file)
@@ -18,9 +18,9 @@
 """Conv2D schedule on for Intel CPU"""
 from __future__ import absolute_import as _abs
 import tvm
+from tvm import autotvm
 from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity
 
-from ..nn.util import infer_pad
 from ..generic import conv2d as conv2d_generic
 from ..util import get_const_tuple
 from .tensor_intrin import dot_16x1x16_uint8_int8_int32
@@ -83,88 +83,42 @@ def _fallback_schedule_int8(cfg, wkl):
     cfg["unroll_kw"] = OtherOptionEntity(False)
 
 
-def _schedule_conv(s, cfg, data, data_pad, data_vec, kernel_vec, conv_out, output, last):
-    # fetch schedule
-    ic_bn, oc_bn, reg_n, unroll_kw = (cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1],
-                                      cfg["tile_ow"].size[-1], cfg["unroll_kw"].val)
-
-    # no stride and padding info here
-    padding = infer_pad(data, data_pad)
-    HPAD, WPAD = padding
-    DOPAD = (HPAD != 0 or WPAD != 0)
-
-    A, W = data, kernel_vec
-    A0, A1 = data_pad, data_vec
-
-    # schedule data
-    if DOPAD:
-        s[A0].compute_inline()
-    batch, ic_chunk, ih, ic_block, iw = s[A1].op.axis
-    parallel_axis = s[A1].fuse(batch, ic_chunk, ih)
-    s[A1].parallel(parallel_axis)
-
-    # schedule kernel pack
-    oc_chunk, ic_chunk, oh, ow, ic_block, oc_block = s[W].op.axis
-    s[W].reorder(oc_chunk, oh, ic_chunk, ow, ic_block, oc_block)
-    if oc_bn > 1:
-        s[W].vectorize(oc_block)
-    parallel_axis = s[W].fuse(oc_chunk, oh)
-    s[W].parallel(parallel_axis)
-
-    # schedule conv
-    C, O0, O = conv_out, output, last
-    CC = s.cache_write(C, 'global')
-
-    _, oc_chunk, oh, ow, oc_block = s[C].op.axis
-    ow_chunk, ow_block = s[C].split(ow, factor=reg_n)
-    s[C].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
-    s[C].fuse(oc_chunk, oh)
-    s[C].vectorize(oc_block)
-
-    s[CC].compute_at(s[C], ow_chunk)
-    _, oc_chunk, oh, ow, oc_block = s[CC].op.axis
-    ic, kh, kw = s[CC].op.reduce_axis
-
-    ow_chunk, ow_block = s[CC].split(ow, factor=reg_n)
-    ic_chunk, ic_block = s[CC].split(ic, factor=ic_bn)
-
-    if unroll_kw:
-        s[CC].reorder(oc_chunk, oh, ow_chunk, ic_chunk, kh, ic_block, kw, ow_block, oc_block)
-        s[CC].unroll(kw)
-    else:
-        s[CC].reorder(oc_chunk, oh, ow_chunk, ic_chunk, kh, kw, ic_block, ow_block, oc_block)
-
-    s[CC].fuse(oc_chunk, oh)
-    s[CC].vectorize(oc_block)
-    s[CC].unroll(ow_block)
-
-    if O0 != O:
-        s[O0].compute_inline()
-
-    batch, oc, oh, ow = s[O].op.axis
-    ow_chunk, ow_block = s[O].split(ow, factor=reg_n)
-    oc_chunk, oc_block = s[O].split(oc, factor=oc_bn)
-    s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
-    parallel_axis = s[O].fuse(batch, oc_chunk, oh)
-    s[C].compute_at(s[O], parallel_axis)
-    s[O].vectorize(oc_block)
-
-    s[O].parallel(parallel_axis)
-
-    return s
-
-
-def _schedule_conv_NCHWc(s, cfg, data, conv_out, last):
+def _schedule_conv_NCHWc(s, cfg, data_vec, kernel_vec, conv_out, last):
     # fetch schedule
     reg_n, unroll_kw = cfg["tile_ow"].size[-1], cfg["unroll_kw"].val
-    _, _, _, _, ic_bn = get_const_tuple(data.shape)
+    _, _, _, _, ic_bn = get_const_tuple(data_vec.shape)
+
+    # schedule pad
+    if isinstance(s[data_vec].op, tvm.tensor.ComputeOp) \
+            and "pad" in data_vec.op.tag:
+        batch, ic_chunk, ih, iw, ic_block = s[data_vec].op.axis
+        parallel_axis = s[data_vec].fuse(batch, ic_chunk, ih)
+        s[data_vec].parallel(parallel_axis)
+        data_vec = data_vec.op.input_tensors[0]
+
+    if autotvm.GLOBAL_SCOPE.in_tuning:
+        # only in autotuning, input data of conv2d_NCHWc will be 4-D.
+        # skip this part during tuning to make records accurate.
+        # this part will be folded during Relay fold_constant pass.
+        s[data_vec].pragma(s[data_vec].op.axis[0], "debug_skip_region")
+        s[kernel_vec].pragma(s[kernel_vec].op.axis[0], "debug_skip_region")
+    elif isinstance(kernel_vec.op, tvm.tensor.ComputeOp) and \
+            kernel_vec.name == 'kernel_vec':
+        # data and kernel are not pre-computed, schedule layout transform here.
+        # this should only be used by x86 conv2d_nchw, which is for
+        # testing purpose.
+        batch, ic_chunk, ih, ic_block, iw = s[data_vec].op.axis
+        parallel_axis = s[data_vec].fuse(batch, ic_chunk, ih)
+        s[data_vec].parallel(parallel_axis)
+
+        oc_chunk, ic_chunk, oh, ow, ic_block, oc_block = s[kernel_vec].op.axis
+        s[kernel_vec].reorder(oc_chunk, oh, ic_chunk, ow, ic_block, oc_block)
+        oc_bn = cfg["tile_oc"].size[-1]
+        if oc_bn > 1:
+            s[kernel_vec].vectorize(oc_block)
+        parallel_axis = s[kernel_vec].fuse(oc_chunk, oh)
+        s[kernel_vec].parallel(parallel_axis)
 
-    # schedule data
-    A = data
-    if isinstance(s[A].op, tvm.tensor.ComputeOp):
-        batch, ic_chunk, ih, iw, ic_block = s[A].op.axis
-        parallel_axis = s[A].fuse(batch, ic_chunk, ih)
-        s[A].parallel(parallel_axis)
 
     # schedule 5-D NCHW[x]c conv
     C, O = conv_out, last
@@ -195,18 +149,31 @@ def _schedule_conv_NCHWc(s, cfg, data, conv_out, last):
     s[CC].unroll(ow_block)
 
     if C != O:
-        batch, oc_chunk, oh, ow, oc_block = s[O].op.axis
-        ow_chunk, ow_block = s[O].split(ow, factor=reg_n)
-        s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
-        parallel_axis = s[O].fuse(batch, oc_chunk, oh)
-        s[C].compute_at(s[O], parallel_axis)
-        s[O].vectorize(oc_block)
-        s[O].parallel(parallel_axis)
+        out_ndim = len(s[O].op.axis)
+        if out_ndim == 5:
+            batch, oc_chunk, oh, ow, oc_block = s[O].op.axis
+            ow_chunk, ow_block = s[O].split(ow, factor=reg_n)
+            s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
+            parallel_axis = s[O].fuse(batch, oc_chunk, oh)
+            s[C].compute_at(s[O], parallel_axis)
+            s[O].vectorize(oc_block)
+            s[O].parallel(parallel_axis)
+        elif out_ndim == 4:
+            batch, oc, oh, ow = s[O].op.axis
+            ow_chunk, ow_block = s[O].split(ow, factor=reg_n)
+            oc_chunk, oc_block = s[O].split(oc, factor=oc_bn)
+            s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
+            parallel_axis = s[O].fuse(batch, oc_chunk, oh)
+            s[C].compute_at(s[O], parallel_axis)
+            s[O].vectorize(oc_block)
+            s[O].parallel(parallel_axis)
+        else:
+            raise ValueError("Unsupported output ndim: %s" % out_ndim)
 
     return s
 
 
-def _schedule_conv_NCHWc_int8(s, cfg, data, conv_out, last):
-    return conv2d_generic.schedule_conv_NCHWc_cpu_common_int8(s, cfg, data, conv_out, last,
-                                                              int32_lanes=16,
+def _schedule_conv_NCHWc_int8(s, cfg, data_vec, kernel_vec, conv_out, last):
+    return conv2d_generic.schedule_conv_NCHWc_cpu_common_int8(s, cfg, data_vec, kernel_vec,
+                                                              conv_out, last, int32_lanes=16,
                                                               intrin=dot_16x1x16_uint8_int8_int32())
index 20712d2..64fe92b 100644 (file)
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=invalid-name,unused-variable,unused-argument,no-member, import-outside-toplevel
+# pylint: disable=invalid-name,unused-variable,unused-argument,no-member
+# pylint: disable=no-value-for-parameter,import-outside-toplevel
 """Conv2D int8 schedule on x86"""
 
-import re
 import tvm
 from tvm import autotvm
-from tvm.autotvm.task import get_config
-from tvm.autotvm.task.topi_integration import deserialize_args
 from ..nn.conv2d import _get_workload as _get_conv2d_workload
-from .. import generic, tag
+from .. import tag
 from ..generic import conv2d as conv2d_generic
 from ..nn.util import get_pad_tuple
-from ..util import get_const_tuple
-from ..nn.conv2d import conv2d_NCHWc_int8
+from ..nn.conv2d import unpack_NCHWc_to_nchw
 from ..nn.depthwise_conv2d import _get_workload as _get_depthwise_conv2d_workload
+from ..util import get_const_tuple, traverse_inline
 from .. import nn
 from . import conv2d_avx_1x1, conv2d_avx_common
 
@@ -53,7 +51,7 @@ def _get_default_config_int8(cfg, data, kernel, strides, padding, out_dtype, is_
                 cfg, wkl, int32_lanes=16, num_int8_elements=4)
 
 
-def _is_int8_hw_support(data_dtype, kernel_dtype):
+def is_int8_hw_support(data_dtype, kernel_dtype):
     """
     Checks to ensure that we can use Intel DLBoost instructions
     1) The datatypes are correct.
@@ -76,150 +74,123 @@ def _is_int8_hw_support(data_dtype, kernel_dtype):
     return is_dtype_support and is_llvm_support and is_target_support
 
 
-def _create_tuning_space_int8(cfg, data, kernel, strides, padding, dilation, layout):
-    """Create schedule configuration from input arguments"""
-    dshape = get_const_tuple(data.shape)
-    kshape = get_const_tuple(kernel.shape)
-    pat = re.compile(r'NCHW.+(\d+)c')
-    if layout == 'NCHW':
-        n, ic, h, w = dshape
-        oc, _, kh, kw = kshape
-    elif layout == 'NHWC':
-        n, h, w, ic = dshape
-        kh, kw, oc, _ = kshape
-    elif pat.match(layout) is not None:
-        n, ic_chunk, h, w, ic_bn = dshape
-        target = tvm.target.Target.current(allow_none=False)
-        oc_chunk, k_ic, kh, kw, k_ic_f, oc_bn, k_ic_s = kshape
-        ic = ic_chunk * ic_bn
-        assert ic == k_ic * k_ic_f * k_ic_s
-        oc = oc_chunk*oc_bn
+def conv2d_nchw_int8(data, kernel, strides, padding, dilation, out_dtype):
+    """Compute conv2d with NCHW layout and int8 dtype"""
+    layout = "NCHW"
+    packed_out = conv2d_NCHWc_int8(data, kernel, strides, padding, dilation,
+                                   layout, layout, out_dtype)
+    return unpack_NCHWc_to_nchw(packed_out, out_dtype)
+
+
+def schedule_conv2d_nchw_int8(outs):
+    """Create the schedule for conv2d_nchw_int8"""
+    return schedule_conv2d_NCHWc_int8(outs)
+
+
+def _pack_data(cfg, data, kernel):
+    n_elems = 4
+    n, _, ih, iw = get_const_tuple(data.shape)
+    oc, ic, kh, kw = get_const_tuple(kernel.shape)
+    ic_bn, oc_bn = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
+
+    ic_chunk = ic // ic_bn
+    oc_chunk = oc // oc_bn
+
+    data = tvm.compute((n, ic_chunk, ih, iw, ic_bn),
+                       lambda bs, c, h, w, vc: data[bs, c*ic_bn + vc, h, w],
+                       name="data_vec")
+
+    kernel = tvm.compute(
+        (oc_chunk, ic_chunk, kh, kw, ic_bn//n_elems, oc_bn, n_elems),
+        lambda occ, icc, k_h, k_w, icbc, ocb, icbb:
+        kernel[occ * oc_bn + ocb,
+               icc * ic_bn + icbc * ic_bn//n_elems + icbb, k_h, k_w],
+        name="kernel_vec")
+
+    return data, kernel
+
+
+@autotvm.register_topi_compute("conv2d_NCHWc_int8.x86")
+def conv2d_NCHWc_int8(cfg, data, kernel, strides, padding,
+                      dilation, layout, out_layout, out_dtype):
+    """Compute conv2d with NCHWc layout and int8 dtype"""
+    if len(data.shape) == 5:
+        n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape)
+        in_channel = ic_chunk * ic_bn
+        oc_chunk, ic_chunk_group, kernel_height, kernel_width, _, oc_bn, _ \
+            = get_const_tuple(kernel.shape)
+        num_filter = oc_chunk * oc_bn
     else:
-        raise ValueError("Not support this layout {} with "
-                         "schedule template.".format(layout))
+        n, in_channel, ih, iw = get_const_tuple(data.shape)
+        num_filter, _, kernel_height, kernel_width = \
+            get_const_tuple(kernel.shape)
 
-    is_kernel_1x1 = kh == 1 and kw == 1
-    pt, pl, pb, pr = get_pad_tuple(padding, kernel)
+    # Define autotvm tuning space
+    is_kernel_1x1 = kernel_height == 1 and kernel_width == 1
+    pt, pl, pb, pr = get_pad_tuple(padding, (kernel_height, kernel_width))
     sh, sw = strides if isinstance(strides, (tuple, list)) else (strides, strides)
-    oh = (h - kh + pt + pb) // sh + 1
-    ow = (w - kw + pl + pr) // sw + 1
+    oh = (ih - kernel_height + pt + pb) // sh + 1
+    ow = (iw - kernel_width + pl + pr) // sw + 1
 
-    # Create schedule config
-    cfg.define_split('tile_ic', ic, num_outputs=2, filter=lambda y: y.size[-1] % 4 == 0)
-    cfg.define_split('tile_oc', oc, num_outputs=2, filter=lambda y: y.size[-1] % 16 == 0)
+    cfg.define_split('tile_ic', in_channel, num_outputs=2,
+                     filter=lambda y: y.size[-1] % 4 == 0)
+    cfg.define_split('tile_oc', num_filter, num_outputs=2,
+                     filter=lambda y: y.size[-1] % 16 == 0)
     cfg.define_split("tile_ow", ow, num_outputs=2, filter=lambda y: y.size[-1] <= 64)
     if is_kernel_1x1:
         cfg.define_knob("tile_oh", [1, 2] if oh > 1 else [1])
     else:
         cfg.define_knob("unroll_kw", [True, False])
 
-
-# Define template function for autotvm task
-# We define schedule template in this function instead of
-# declaration function since actual input arguments need
-# to be altered by the schedule selected.
-@autotvm.task.register("topi_x86_conv2d_NCHWc_int8")
-def _topi_nn_conv2d_NCHWc_int8(*args, **kwargs):
-    assert not kwargs, "Do not support kwargs in template function call"
-    args = deserialize_args(args)
-
-    if len(args) == 7:
-        data, kernel, strides, padding, dilation, origin_layout, dtype = args
-    else:
-        assert len(args) == 8
-        data, kernel, strides, padding, dilation, origin_layout, out_layout, dtype = args
-
-    raw_data_shape = get_const_tuple(data.shape)
-    raw_kernel_shape = get_const_tuple(kernel.shape)
-
-    # get config here
-    cfg = get_config()
-    _create_tuning_space_int8(cfg, data, kernel, strides, padding, dilation, origin_layout)
-
-    # change shape with the value in config
-    ic_bn, oc_bn, ow_bn = (cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1],
-                           cfg["tile_ow"].size[-1])
-
-    data_layout = "NCHW%dc" % ic_bn
-    out_layout = "NCHW%dc" % oc_bn
-
-    # Set up the new shape for data and kernel
-    new_data_shape = (raw_data_shape[0], raw_data_shape[1] // ic_bn,
-                      raw_data_shape[2], raw_data_shape[3], ic_bn)
-    n_elems = 4
-    new_kernel_shape = (raw_kernel_shape[0] // oc_bn,
-                        raw_kernel_shape[1] // ic_bn,
-                        raw_kernel_shape[2],
-                        raw_kernel_shape[3],
-                        ic_bn // n_elems,
-                        oc_bn,
-                        n_elems)
-
-    new_data = tvm.placeholder(new_data_shape, data.dtype)
-    new_kernel = tvm.placeholder(new_kernel_shape, kernel.dtype)
-
-    C = _declaration_conv_NCHWc_int8(cfg, new_data, new_kernel, strides, padding, dilation,
-                                     data_layout, out_layout, dtype)
-    s = _schedule_conv2d_NCHWc_int8(cfg, [C])
-    return s, [new_data, new_kernel, C]
-
-
-@autotvm.register_topi_compute(conv2d_NCHWc_int8, 'cpu', 'direct')
-def _declaration_conv_NCHWc_int8(cfg, data, kernel, strides,
-                                 padding, dilation, layout, out_layout, out_dtype):
-    return nn.conv2d_NCHWc_int8_compute(data,
-                                        kernel,
-                                        strides,
-                                        padding,
-                                        dilation,
-                                        layout,
-                                        out_layout,
-                                        out_dtype)
-
-
-@autotvm.register_topi_schedule(generic.schedule_conv2d_NCHWc_int8, 'cpu', ['direct'])
-def _schedule_conv2d_NCHWc_int8(cfg, outs):
+    # If no config was set, we can fallback to default config.
+    if cfg.is_fallback:
+        _get_default_config_int8(
+            cfg, tvm.placeholder((n, in_channel, ih, iw), dtype=data.dtype),
+            tvm.placeholder((num_filter, in_channel, kernel_height, kernel_width),
+                            dtype=kernel.dtype),
+            strides, padding, out_dtype)
+
+    # Pack data if raw 4-D data is provided.
+    # This can only happen when autotuning.
+    if len(data.shape) == 4:
+        data, kernel = _pack_data(cfg, data, kernel)
+
+    return nn.conv2d_NCHWc_int8(data,
+                                kernel,
+                                strides,
+                                padding,
+                                dilation,
+                                layout,
+                                out_layout,
+                                out_dtype)
+
+
+@autotvm.register_topi_schedule("conv2d_NCHWc_int8.x86")
+def schedule_conv2d_NCHWc_int8(cfg, outs):
     """Create schedule for tensors"""
     s = tvm.create_schedule([x.op for x in outs])
-    scheduled_ops = []
 
-    def traverse(op):
+    def _callback(op):
         """Traverse operators from computation graph"""
-        # inline all one-to-one-mapping operators except the last stage (output)
-        if tag.is_broadcast(op.tag):
-            if op not in s.outputs:
-                s[op].compute_inline()
-            for tensor in op.input_tensors:
-                if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
-                    traverse(tensor.op)
-
         if 'conv2d_NCHWc_int8' in op.tag:
             conv_out = op.output(0)
-            kernel = conv_out.op.input_tensors[1]
+            kernel_vec = conv_out.op.input_tensors[1]
             data_vec = conv_out.op.input_tensors[0]
-            data = data_vec.op.input_tensors[0] \
-                if isinstance(data_vec.op, tvm.tensor.ComputeOp) and "pad" not in data_vec.op.tag \
-                else data_vec
-            if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
-                data_pad = data
-                data = data_pad.op.input_tensors[0]
 
-            args = [s, cfg, data_vec, conv_out, outs[0]]
-            target = tvm.target.Target.current(allow_none=False)
+            args = [s, cfg, data_vec, kernel_vec, conv_out, outs[0]]
             # int8 conv kernel is 7-dim
-            _, _, kh, kw, _, _, _ = get_const_tuple(kernel.shape)
+            _, _, kh, kw, _, _, _ = get_const_tuple(kernel_vec.shape)
             if kh == 1 and kw == 1:
                 conv2d_avx_1x1._schedule_conv_NCHWc_int8(*args)
             else:
                 conv2d_avx_common._schedule_conv_NCHWc_int8(*args)
 
-        scheduled_ops.append(op)
-
-    traverse(outs[0].op)
+    traverse_inline(s, outs[0].op, _callback)
     return s
 
-@autotvm.register_topi_schedule(generic.schedule_conv2d_nhwc_pack, 'cpu', ['direct'])
-def schedule_conv2d_nhwc_pack(cfg, outs):
+
+@autotvm.register_topi_schedule("conv2d_nhwc_pack_int8.x86")
+def schedule_conv2d_nhwc_pack_int8(cfg, outs):
     """Create schedule for tensors"""
     s = tvm.create_schedule([x.op for x in outs])
     output_op = outs[0].op
index 27fc0af..71f47d6 100644 (file)
 # pylint: disable=invalid-name,unused-variable,unused-argument,no-member
 """Conv2D Transpose schedule on x86"""
 import tvm
-from tvm import autotvm
-from .. import generic
-from ..util import get_const_tuple, traverse_inline
-from ..nn import conv2d_transpose_nchw_preprocess, conv2d_transpose_nchw
-from . import conv2d_avx_1x1, conv2d_avx_common
-from .conv2d import _declaration_conv_impl, \
-    _create_tuning_space as _create_tuning_space_conv2d, \
-    _get_default_config as _get_default_config_conv2d
+from ..util import traverse_inline
+from .. import nn
+from .conv2d import conv2d_nchw, schedule_conv2d_nchw
 
-
-@autotvm.register_topi_compute(conv2d_transpose_nchw, 'cpu', ['direct'])
-def _conv2d_transpose_nchw(cfg, data, kernel, strides, padding, out_dtype):
+def conv2d_transpose_nchw(data, kernel, strides, padding, out_dtype):
     data_pad, kernel_transform = \
-        conv2d_transpose_nchw_preprocess(data, kernel, strides, padding, out_dtype)
-    # reuse conv2d implementation
-    _create_tuning_space_conv2d(cfg, data_pad, kernel_transform, strides=(1, 1), \
-                                padding=(0, 0), dilation=(1, 1), layout="NCHW")
-    if cfg.is_fallback:
-        _get_default_config_conv2d(cfg, data_pad, kernel_transform, strides=(1, 1), \
-                                   padding=(0, 0), out_dtype=out_dtype, layout='NCHW')
-    return _declaration_conv_impl(cfg, data_pad, kernel_transform, strides=(1, 1), \
-                                  padding=(0, 0), dilation=(1, 1), layout="NCHW", \
-                                  out_dtype=out_dtype)
-
+        nn.conv2d_transpose_nchw_preprocess(data, kernel, strides, padding, out_dtype)
+    # reuse conv2d_nchw implementation
+    return conv2d_nchw(data_pad, kernel_transform, strides=(1, 1),
+                       padding=(0, 0), dilation=(1, 1), out_dtype=out_dtype)
 
-@autotvm.register_topi_schedule(generic.schedule_conv2d_transpose_nchw, 'cpu', ['direct'])
-def _schedule_conv2d_transpose_nchw(cfg, outs):
+def schedule_conv2d_transpose_nchw(outs):
     """Create schedule for tensors"""
     outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
-    s = tvm.create_schedule([x.op for x in outs])
-
+    s = schedule_conv2d_nchw(outs)
     def _callback(op):
-        # reuse conv2d schedule
-        if 'conv2d_nchw' in op.tag:
-            output = op.output(0)
+        if 'unpack_nchwc' in op.tag:
             conv_out = op.input_tensors[0]
             # retrieve data
             data_vec = conv_out.op.input_tensors[0]
             data_pad = data_vec.op.input_tensors[0]
             data_dilate = data_pad.op.input_tensors[0]
             s[data_dilate].compute_inline()
+            s[data_pad].compute_inline()
             # retrieve kernel
             kernel_vec = conv_out.op.input_tensors[1]
             kernel_transform = kernel_vec.op.input_tensors[0]
             s[kernel_transform].compute_inline()
-            # call conv2d schedule
-            _, _, kh, kw = get_const_tuple(kernel_transform.shape)
-            is_kernel_1x1 = kh == 1 and kw == 1
-            args = [s, cfg, data_dilate, data_pad, data_vec, kernel_vec, conv_out, output, outs[0]]
-            if is_kernel_1x1:
-                conv2d_avx_1x1._schedule_conv(*args)
-            else:
-                conv2d_avx_common._schedule_conv(*args)
 
     traverse_inline(s, outs[0].op, _callback)
     return s
index 4a6664e..1e15650 100644 (file)
@@ -21,9 +21,7 @@ from collections import namedtuple
 import tvm
 from tvm import autotvm
 from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity
-from .. import generic
 from ..util import traverse_inline
-from ..nn.conv3d import conv3d, conv3d_ncdhw
 from ..nn.util import get_pad_tuple3d, infer_pad3d
 from ..nn.pad import pad
 from ..util import get_const_tuple, simplify, get_const_int
@@ -35,9 +33,8 @@ Workload3D = namedtuple('Workload',
                          'hkernel', 'wkernel', 'dpad', 'hpad', 'wpad',
                          'dstride', 'hstride', 'wstride'])
 
-@autotvm.register_topi_compute(conv3d, 'cpu', ['direct'])
-def _declaration_conv3d(cfg, data, kernel, strides, padding, dilation,
-                        layout, out_dtype):
+@autotvm.register_topi_compute("conv3d_ndhwc.x86")
+def conv3d_ndhwc(cfg, data, kernel, strides, padding, dilation, out_dtype):
     """3D convolution forward operator.
 
     Parameters
@@ -59,30 +56,24 @@ def _declaration_conv3d(cfg, data, kernel, strides, padding, dilation,
     dilation: int or a list/tuple of three ints
         dilation size, or [dilation_depth, dilation_height, dilation_width]
 
-    layout : str
-        layout of data
-
     Returns
     -------
     output : tvm.Tensor
         5-D with shape [batch, out_depth, out_height, out_width, out_channel] for NDHWC layout
         5-D with shape [batch, out_channel, out_depth, out_height, out_width] for NCDHW layout
     """
+    layout = "NDHWC"
     out_dtype = data.dtype if out_dtype is None else out_dtype
     strides = strides if isinstance(strides, (tuple, list)) else (strides, strides, strides)
     dilation = dilation if isinstance(dilation, (tuple, list)) else (dilation, dilation, dilation)
 
-    if layout == 'NDHWC':
-        _create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout)
-        if cfg.is_fallback:
-            _get_default_config(cfg, data, kernel, strides, padding, out_dtype, layout)
-        return _conv3d_ndhwc(cfg, data, kernel, strides, padding, dilation, layout, out_dtype)
-    elif layout == 'NCDHW':
-        return conv3d_ncdhw(data, kernel, strides, padding, dilation, out_dtype)
-    raise ValueError("Layout {} is not supported".format(layout))
+    _create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout)
+    if cfg.is_fallback:
+        _get_default_config(cfg, data, kernel, strides, padding, out_dtype, layout)
+    return _conv3d_ndhwc(cfg, data, kernel, strides, padding, dilation, out_dtype)
 
 
-@autotvm.register_topi_schedule(generic.schedule_conv3d_ndhwc, 'cpu', ['direct'])
+@autotvm.register_topi_schedule("conv3d_ndhwc.x86")
 def schedule_conv3d_ndhwc(cfg, outs):
     """TOPI schedule callback for conv3d
     Parameters
@@ -120,7 +111,7 @@ def schedule_conv3d_ndhwc(cfg, outs):
     return s
 
 
-def _conv3d_ndhwc(cfg, data, kernel, strides, padding, dilation, layout, out_dtype):
+def _conv3d_ndhwc(cfg, data, kernel, strides, padding, dilation, out_dtype):
     out_dtype = data.dtype if out_dtype is None else out_dtype
 
     assert isinstance(dilation, int) or len(dilation) == 3
index c6c3d5e..ea89cf4 100644 (file)
@@ -23,147 +23,9 @@ from tvm.autotvm.task.space import SplitEntity
 from tvm.contrib import cblas
 
 from .util import get_fp32_len
-from .. import generic, tag, nn
+from .. import generic, tag
 from ..util import traverse_inline, get_const_tuple
 
-@autotvm.register_topi_compute(nn.dense, "cpu", "direct")
-def _declaration_dense(cfg, data, weight, bias=None, out_dtype=None):
-    target = tvm.target.Target.current()
-    if "cblas" in target.libs:
-        C = cblas.matmul(data, weight, False, True)
-        if bias is not None:
-            C = tvm.compute(C.shape, lambda i, j: C[i, j] + bias[j],
-                            tag=tag.BROADCAST)
-        return C
-
-    M, _ = get_const_tuple(data.shape)
-    # Always use dense_nopack for dynamic input.
-    # This is a temporary for CV models.
-    # TODO(kevinthesun): use kernel dispatcher instead.
-    if isinstance(M, tvm.expr.Var):
-        return _declaration_dense_nopack(cfg, data, weight, bias, out_dtype)
-
-    # For small batch sizes, don't pack weight into cache-friendly layout
-    # because of overhead in packing and limited reuse from batch dimension
-    # TODO(icemelon9): use a more systematic way to determine which schedule to use
-    if M <= 16:
-        return _declaration_dense_nopack(cfg, data, weight, bias, out_dtype)
-    return _declaration_dense_pack(cfg, data, weight, bias, out_dtype)
-
-
-# Declare dense compute with packing weight into cache-friendly layout
-@autotvm.register_topi_compute(nn.dense, "cpu", "direct_pack")
-def _declaration_dense_pack(cfg, data, weight, bias=None, out_dtype=None):
-    if out_dtype is None:
-        out_dtype = data.dtype
-    M, K = get_const_tuple(data.shape) # batch, in_dim
-    N, _ = get_const_tuple(weight.shape) # out_dim
-    # create tuning space
-    cfg.define_split("tile_y", 32 if isinstance(M, tvm.expr.Var) else M, num_outputs=3)
-    cfg.define_split("tile_x", 32 if isinstance(N, tvm.expr.Var) else N, num_outputs=3)
-    cfg.define_split("tile_k", 32 if isinstance(K, tvm.expr.Var) else K, num_outputs=2)
-    if cfg.is_fallback:
-        _default_dense_pack_config(cfg, M, N, K)
-
-    packw_bn = cfg["tile_x"].size[-1]
-    packw_shape = (N // packw_bn, K, packw_bn)
-    packw = tvm.compute(packw_shape,
-                        lambda z, y, x: weight[z * packw_bn + x, y], name="packed_weight")
-
-    idxdiv = tvm.indexdiv
-    idxmod = tvm.indexmod
-    k = tvm.reduce_axis((0, K), name="k")
-    C = tvm.compute((M, N),
-                    lambda y, x: tvm.sum(
-                        data[y, k].astype(out_dtype) *
-                        packw[idxdiv(x, packw_bn), k, idxmod(x, packw_bn)].astype(out_dtype),
-                        axis=k),
-                    tag="dense_pack")
-    if bias is not None:
-        C = tvm.compute((M, N), lambda i, j: C[i, j] + bias[j].astype(out_dtype),
-                        tag=tag.BROADCAST)
-    return C
-
-
-# Declare dense compute without packing weight
-@autotvm.register_topi_compute(nn.dense, "cpu", "direct_nopack")
-def _declaration_dense_nopack(cfg, data, weight, bias=None, out_dtype=None):
-    if out_dtype is None:
-        out_dtype = data.dtype
-    M, K = get_const_tuple(data.shape)
-    N, _ = get_const_tuple(weight.shape)
-    # create tuning space
-    cfg.define_split("tile_y", 32 if isinstance(M, tvm.expr.Var) else M, num_outputs=2)
-    cfg.define_split("tile_x", 32 if isinstance(N, tvm.expr.Var) else N, num_outputs=2)
-    cfg.define_split("tile_k", 32 if isinstance(K, tvm.expr.Var) else K, num_outputs=2)
-    if cfg.is_fallback:
-        _default_dense_nopack_config(cfg, M, N, K)
-
-    vec = cfg["tile_k"].size[-1]
-    k = tvm.reduce_axis((0, K // vec), "k")
-    CC = tvm.compute((M, N, vec),
-                     lambda z, y, x: tvm.sum(
-                         data[z, k * vec + x].astype(out_dtype) *
-                         weight[y, k * vec + x].astype(out_dtype), axis=k))
-
-    kk = tvm.reduce_axis((0, vec), "kk")
-    C = tvm.compute((M, N),
-                    lambda y, x: tvm.sum(CC[y, x, kk], axis=kk),
-                    tag="dense_nopack")
-    if bias is not None:
-        C = tvm.compute((M, N), lambda i, j: C[i, j] + bias[j].astype(out_dtype),
-                        tag=tag.BROADCAST)
-
-    return C
-
-
-@autotvm.register_topi_schedule(generic.schedule_dense, "cpu", "direct")
-def _schedule_dense(cfg, outs):
-    target = tvm.target.Target.current()
-    if "cblas" in target.libs:
-        return generic.schedule_extern(outs)
-
-    s = tvm.create_schedule([x.op for x in outs])
-
-    def _callback(op):
-        if "dense_pack" in op.tag:
-            _schedule_dense_pack_template(cfg, s, op.output(0))
-        elif 'dense_nopack' in op.tag:
-            _schedule_dense_nopack_template(cfg, s, op.output(0))
-    traverse_inline(s, outs[0].op, _callback)
-    return s
-
-
-@autotvm.register_topi_schedule(generic.schedule_dense, "cpu", "direct_pack")
-def _schedule_dense_pack(cfg, outs):
-    target = tvm.target.Target.current()
-    if "cblas" in target.libs:
-        return generic.schedule_extern(outs)
-
-    s = tvm.create_schedule([x.op for x in outs])
-
-    def _callback(op):
-        if "dense_pack" in op.tag:
-            _schedule_dense_pack_template(cfg, s, op.output(0))
-    traverse_inline(s, outs[0].op, _callback)
-    return s
-
-
-@autotvm.register_topi_schedule(generic.schedule_dense, "cpu", "direct_nopack")
-def _schedule_dense_nopack(cfg, outs):
-    target = tvm.target.Target.current()
-    if "cblas" in target.libs:
-        return generic.schedule_extern(outs)
-
-    s = tvm.create_schedule([x.op for x in outs])
-
-    def _callback(op):
-        if 'dense_nopack' in op.tag:
-            _schedule_dense_nopack_template(cfg, s, op.output(0))
-    traverse_inline(s, outs[0].op, _callback)
-    return s
-
-
 def _schedule_dense_pack_template(cfg, s, C):
     A, packedB = s[C].op.input_tensors
 
@@ -270,3 +132,106 @@ def _default_dense_nopack_config(cfg, M, N, K):
     cfg["tile_k"] = SplitEntity([K // tilek_bn, tilek_bn])
     cfg["tile_x"] = SplitEntity([N, 1])
     cfg["tile_y"] = SplitEntity([1, M])
+
+@autotvm.register_topi_compute("dense_nopack.x86")
+def dense_nopack(cfg, data, weight, bias=None, out_dtype=None):
+    """Compute dense without packing"""
+    if out_dtype is None:
+        out_dtype = data.dtype
+    M, K = get_const_tuple(data.shape)
+    N, _ = get_const_tuple(weight.shape)
+    # create tuning space
+    cfg.define_split("tile_y", 32 if isinstance(M, tvm.expr.Var) else M, num_outputs=2)
+    cfg.define_split("tile_x", 32 if isinstance(N, tvm.expr.Var) else N, num_outputs=2)
+    cfg.define_split("tile_k", 32 if isinstance(K, tvm.expr.Var) else K, num_outputs=2)
+    if cfg.is_fallback:
+        _default_dense_nopack_config(cfg, M, N, K)
+
+    vec = cfg["tile_k"].size[-1]
+    k = tvm.reduce_axis((0, K // vec), "k")
+    CC = tvm.compute((M, N, vec),
+                     lambda z, y, x: tvm.sum(
+                         data[z, k * vec + x].astype(out_dtype) *
+                         weight[y, k * vec + x].astype(out_dtype), axis=k))
+
+    kk = tvm.reduce_axis((0, vec), "kk")
+    C = tvm.compute((M, N),
+                    lambda y, x: tvm.sum(CC[y, x, kk], axis=kk),
+                    tag="dense_nopack")
+    if bias is not None:
+        C = tvm.compute((M, N), lambda i, j: C[i, j] + bias[j].astype(out_dtype),
+                        tag=tag.BROADCAST)
+    return C
+
+
+@autotvm.register_topi_schedule("dense_nopack.x86")
+def schedule_dense_nopack(cfg, outs):
+    """Create the schedule for dense_nopack"""
+    s = tvm.create_schedule([x.op for x in outs])
+
+    def _callback(op):
+        if 'dense_nopack' in op.tag:
+            _schedule_dense_nopack_template(cfg, s, op.output(0))
+    traverse_inline(s, outs[0].op, _callback)
+    return s
+
+@autotvm.register_topi_compute("dense_pack.x86")
+def dense_pack(cfg, data, weight, bias=None, out_dtype=None):
+    """Compute dense with packing"""
+    if out_dtype is None:
+        out_dtype = data.dtype
+    M, K = get_const_tuple(data.shape) # batch, in_dim
+    N, _ = get_const_tuple(weight.shape) # out_dim
+    # create tuning space
+    cfg.define_split("tile_y", M, num_outputs=3)
+    cfg.define_split("tile_x", N, num_outputs=3)
+    cfg.define_split("tile_k", K, num_outputs=2)
+    if cfg.is_fallback:
+        _default_dense_pack_config(cfg, M, N, K)
+
+    packw_bn = cfg["tile_x"].size[-1]
+    packw_shape = (N // packw_bn, K, packw_bn)
+    packw = tvm.compute(packw_shape,
+                        lambda z, y, x: weight[z * packw_bn + x, y], name="packed_weight")
+
+    idxdiv = tvm.indexdiv
+    idxmod = tvm.indexmod
+    k = tvm.reduce_axis((0, K), name="k")
+    C = tvm.compute((M, N),
+                    lambda y, x: tvm.sum(
+                        data[y, k].astype(out_dtype) *
+                        packw[idxdiv(x, packw_bn), k, idxmod(x, packw_bn)].astype(out_dtype),
+                        axis=k),
+                    tag="dense_pack")
+    if bias is not None:
+        C = tvm.compute((M, N), lambda i, j: C[i, j] + bias[j].astype(out_dtype),
+                        tag=tag.BROADCAST)
+    return C
+
+@autotvm.register_topi_schedule("dense_pack.x86")
+def schedule_dense_pack(cfg, outs):
+    """Create the schedule for dense_pack"""
+    s = tvm.create_schedule([x.op for x in outs])
+
+    def _callback(op):
+        if "dense_pack" in op.tag:
+            _schedule_dense_pack_template(cfg, s, op.output(0))
+    traverse_inline(s, outs[0].op, _callback)
+    return s
+
+@autotvm.register_topi_compute("dense_cblas.x86")
+def dense_cblas(cfg, data, weight, bias=None, out_dtype=None):
+    """Compute dense using cblas library"""
+    M, K = get_const_tuple(data.shape)
+    N, _ = get_const_tuple(weight.shape)
+    cfg.add_flop(M * K * N * 2)
+    C = cblas.matmul(data, weight, False, True)
+    if bias is not None:
+        C = tvm.compute(C.shape, lambda i, j: C[i, j] + bias[j].astype(out_dtype),
+                        tag=tag.BROADCAST)
+    return C
+
+@autotvm.register_topi_schedule("dense_cblas.x86")
+def schedule_dense_cblas(_, outs):
+    """Create schedule for dense_cblas"""
+    return generic.schedule_extern(outs)
index 385537b..2aa5e74 100644 (file)
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=invalid-name,unused-variable,unused-argument,no-member
+# pylint: disable=no-value-for-parameter
 """Depthwise Conv2D schedule on x86"""
 import tvm
 from tvm import autotvm
-from tvm.autotvm.task import get_config
 from tvm.autotvm.task.space import SplitEntity
-from tvm.autotvm.task.topi_integration import deserialize_args
-from .. import generic, tag
-from ..generic import schedule_depthwise_conv2d_nchw
 from ..nn.pad import pad
 from ..util import get_const_tuple
 from ..nn.util import get_pad_tuple
-from ..nn.depthwise_conv2d import depthwise_conv2d_nchw, depthwise_conv2d_NCHWc, \
-    _get_workload, depthwise_conv2d_infer_layout
-
+from ..nn.depthwise_conv2d import _get_workload, depthwise_conv2d_infer_layout
+from ..nn.conv2d import unpack_NCHWc_to_nchw
+from ..util import traverse_inline
 from .util import get_fp32_len
 
 def _fallback_schedule(cfg, wkl):
@@ -70,20 +67,57 @@ def _fallback_schedule(cfg, wkl):
     cfg["tile_oc"] = SplitEntity([wkl.out_filter // oc_bn, oc_bn])
     cfg["tile_ow"] = SplitEntity([out_width // reg_n, reg_n])
 
+def depthwise_conv2d_nchw(data, kernel, strides, padding, dilation, out_dtype):
+    """Compute depthwise conv2d with NCHW layout."""
+    layout = "NCHW"
+    packed_out = depthwise_conv2d_NCHWc(data, kernel, strides, padding, dilation,
+                                        layout, layout, out_dtype)
+    return unpack_NCHWc_to_nchw(packed_out, out_dtype)
+
+def schedule_depthwise_conv2d_nchw(outs):
+    """Create schedule for depthwise_conv2d_nchw."""
+    return schedule_depthwise_conv2d_NCHWc(outs)
+
+def _pack_data(cfg, data, kernel):
+    n, ic, ih, iw = get_const_tuple(data.shape)
+    filters, cm, kh, kw = get_const_tuple(kernel.shape)
+    oc = filters * cm
+    ic_bn, oc_bn = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
+
+    ic_chunk = ic // ic_bn
+    oc_chunk = oc // oc_bn
+
+    data = tvm.compute((n, ic_chunk, ih, iw, ic_bn),
+                       lambda bs, c, h, w, vc: data[bs, c*ic_bn + vc, h, w],
+                       name="data_vec")
 
-autotvm.register_topi_compute(depthwise_conv2d_nchw, 'cpu', 'direct',
-                              depthwise_conv2d_nchw.fdefault)
-autotvm.register_topi_schedule(schedule_depthwise_conv2d_nchw, 'cpu', 'direct',
-                               schedule_depthwise_conv2d_nchw.fdefault)
+    kernel = tvm.compute(
+        (oc_chunk, 1, kh, kw, 1, oc_bn),
+        lambda occ, icc, k_h, k_w, icb, ocb:
+        kernel[(occ * oc_bn + ocb) // cm,
+               (occ * oc_bn + ocb) % cm, k_h, k_w],
+        name="kernel_vec")
 
+    return data, kernel
 
-@autotvm.register_topi_compute(depthwise_conv2d_NCHWc, 'cpu', 'direct')
-def _depthwise_conv2d_NCHWc_cpu(cfg, data, kernel, strides, padding, dilation,
-                                layout, out_layout, out_dtype=None):
+@autotvm.register_topi_compute("depthwise_conv2d_NCHWc.x86")
+def depthwise_conv2d_NCHWc(cfg, data, kernel, strides, padding, dilation,
+                           layout, out_layout, out_dtype=None):
+    """Compute depthwise conv2d with NCHWc layout"""
     out_dtype = data.dtype if out_dtype is None else out_dtype
-    batch, in_channel_chunk, in_height, in_width, in_channel_block = get_const_tuple(data.shape)
-    out_channel_chunk, _, filter_height, filter_width, __, out_channel_block \
-        = get_const_tuple(kernel.shape)
+
+    if len(data.shape) == 5:
+        batch, in_channel_chunk, in_height, in_width, in_channel_block = get_const_tuple(data.shape)
+        out_channel_chunk, cm_chunk, filter_height, filter_width, cm_block, out_channel_block \
+            = get_const_tuple(kernel.shape)
+        in_channel = in_channel_chunk * in_channel_block
+        out_channel = out_channel_chunk * out_channel_block
+        channel_multiplier = cm_chunk * cm_block
+        assert channel_multiplier * in_channel == out_channel
+    else:
+        batch, in_channel, in_height, in_width = get_const_tuple(data.shape)
+        out_channel, channel_multiplier, filter_height, filter_width = get_const_tuple(kernel.shape)
+    assert channel_multiplier == 1
 
     strides = strides if isinstance(strides, (tuple, list)) else (strides, strides)
     HSTR, WSTR = strides
@@ -92,21 +126,30 @@ def _depthwise_conv2d_NCHWc_cpu(cfg, data, kernel, strides, padding, dilation,
     dh, dw = dilation if isinstance(dilation, (tuple, list)) else (dilation, dilation)
     assert (dh, dw) == (1, 1), "Does not support dilation"
 
-    in_channel = in_channel_chunk * in_channel_block
-    out_channel = out_channel_chunk * out_channel_block
-    channel_multiplier = out_channel // in_channel
-
     out_height = (in_height - filter_height + pad_top + pad_down) // HSTR + 1
     out_width = (in_width - filter_width + pad_left + pad_right) // WSTR + 1
 
+    cfg.define_split("tile_ic", in_channel, num_outputs=2)
+    cfg.define_split("tile_oc", out_channel, num_outputs=2)
+    cfg.define_split("tile_ow", out_width, num_outputs=2, filter=lambda y: y.size[-1] <= 64)
+
     # get workload and related schedule config
-    wkl = _get_workload(tvm.placeholder((batch, in_channel, in_height, in_width), dtype=data.dtype),
-                        tvm.placeholder((out_channel, in_channel, filter_height, filter_width),
-                                        dtype=kernel.dtype),
-                        strides, padding, out_dtype)
+    wkl = _get_workload(
+        tvm.placeholder((batch, in_channel, in_height, in_width), dtype=data.dtype),
+        tvm.placeholder((out_channel, channel_multiplier, filter_height, filter_width),
+                        dtype=kernel.dtype),
+        strides, padding, out_dtype)
     if cfg.is_fallback:
         _fallback_schedule(cfg, wkl)
 
+    # Pack data if raw 4-D data is provided.
+    # This can only happen when autotuning.
+    if len(data.shape) == 4:
+        data, kernel = _pack_data(cfg, data, kernel)
+        _, _, _, _, in_channel_block = get_const_tuple(data.shape)
+        out_channel_chunk, _, _, _, _, out_channel_block \
+            = get_const_tuple(kernel.shape)
+
     # padding stage
     DOPAD = (pad_top != 0 or pad_left != 0 or pad_down != 0 or pad_right != 0)
     if DOPAD:
@@ -136,38 +179,39 @@ def _depthwise_conv2d_NCHWc_cpu(cfg, data, kernel, strides, padding, dilation,
         name='DepthwiseConv2d', tag="depthwise_conv2d_NCHWc")
     return Output
 
-
-@autotvm.register_topi_schedule(generic.schedule_depthwise_conv2d_NCHWc, 'cpu', ['direct'])
+@autotvm.register_topi_schedule("depthwise_conv2d_NCHWc.x86")
 def schedule_depthwise_conv2d_NCHWc(cfg, outs):
     """CPU schedule for depthwise conv2d in NCHW[x]c layout"""
+    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
     s = tvm.create_schedule([x.op for x in outs])
-    scheduled_ops = []
-    def traverse(op):
+
+    def _callback(op):
         """Traverse operators from computation graph"""
-        # inline all one-to-one-mapping operators except the last stage (output)
-        if tag.is_broadcast(op.tag):
-            if op not in s.outputs:
-                s[op].compute_inline()
-            for tensor in op.input_tensors:
-                if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
-                    traverse(tensor.op)
         if 'depthwise_conv2d_NCHWc' in op.tag:
             conv_out = op.output(0)
             data = conv_out.op.input_tensors[0]
             kernel = conv_out.op.input_tensors[1]
             _schedule_depthwise_conv2d_NCHWc_impl(s, cfg, data, kernel, conv_out, outs[0])
-        scheduled_ops.append(op)
-    traverse(outs[0].op)
+
+    traverse_inline(s, outs[0].op, _callback)
     return s
 
-def _schedule_depthwise_conv2d_NCHWc_impl(s, cfg, data, kernel, conv_out, output):
-    tile_ow = cfg["tile_ow"].size[-1]
-    # schedule data
-    A = data
-    if isinstance(s[A].op, tvm.tensor.ComputeOp):
-        batch, ic_chunk, ih, iw, ic_block = s[A].op.axis
-        p = s[A].fuse(ic_chunk, ih)
-        s[A].parallel(p)
+def _schedule_depthwise_conv2d_NCHWc_impl(s, cfg, data_vec, kernel_vec, conv_out, output):
+    tile_ow, oc_bn = cfg["tile_ow"].size[-1], cfg["tile_oc"].size[-1]
+    # schedule pad
+    if isinstance(s[data_vec].op, tvm.tensor.ComputeOp) \
+            and "pad" in data_vec.op.tag:
+        batch, ic_chunk, ih, iw, ic_block = s[data_vec].op.axis
+        parallel_axis = s[data_vec].fuse(batch, ic_chunk, ih)
+        s[data_vec].parallel(parallel_axis)
+        data_vec = data_vec.op.input_tensors[0]
+
+    if autotvm.GLOBAL_SCOPE.in_tuning:
+        # only in autotuning, input data of conv2d_NCHWc will be 4-D.
+        # skip this part during tuning to make recrods accurate.
+        # this part will be folded during Relay fold_constant pass.
+        s[data_vec].pragma(s[data_vec].op.axis[0], "debug_skip_region")
+        s[kernel_vec].pragma(s[kernel_vec].op.axis[0], "debug_skip_region")
 
     C, O = conv_out, output
     CC = s.cache_write(C, 'global')
@@ -187,55 +231,34 @@ def _schedule_depthwise_conv2d_NCHWc_impl(s, cfg, data, kernel, conv_out, output
     s[CC].unroll(ow_block)
 
     if C != O:
-        batch, oc_chunk, oh, ow, oc_block = s[O].op.axis
-        ow_chunk, ow_block = s[O].split(ow, factor=tile_ow)
-        s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
-        parallel_axis = s[O].fuse(oc_chunk, oh)
-        s[C].compute_at(s[O], parallel_axis)
-        s[O].vectorize(oc_block)
-        s[O].parallel(parallel_axis)
-    return s
-
-
-@autotvm.task.register("topi_x86_depthwise_conv2d_NCHWc_from_nchw")
-def _topi_nn_depthwise_conv2d_NCHWc(*args, **kwargs):
-    assert not kwargs, "Do not support kwargs in template function call"
-    data, kernel, strides, padding, dilation, dtype = deserialize_args(args)
-
-    batch, in_channel, height, width = get_const_tuple(data.shape)
-    filter_channel, channel_multiplier, kh, kw = get_const_tuple(kernel.shape)
-    pt, pl, pb, pr = get_pad_tuple(padding, kernel)
-    sh, sw = strides if isinstance(strides, (tuple, list)) else (strides, strides)
-    out_height = (height - kh + pt + pb) // sh + 1
-    out_width = (width - kw + pl + pr) // sw + 1
-    out_channel = filter_channel * channel_multiplier
-
-    # get config here
-    cfg = get_config()
-    cfg.define_split("tile_ic", in_channel, num_outputs=2)
-    cfg.define_split("tile_oc", out_channel, num_outputs=2)
-    cfg.define_split("tile_ow", out_width, num_outputs=2, filter=lambda y: y.size[-1] <= 64)
+        out_ndim = len(s[O].op.axis)
+        if out_ndim == 5:
+            batch, oc_chunk, oh, ow, oc_block = s[O].op.axis
+            ow_chunk, ow_block = s[O].split(ow, factor=tile_ow)
+            s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
+            parallel_axis = s[O].fuse(oc_chunk, oh)
+            s[C].compute_at(s[O], parallel_axis)
+            s[O].vectorize(oc_block)
+            s[O].parallel(parallel_axis)
+        elif out_ndim == 4:
+            batch, oc, oh, ow = s[O].op.axis
+            ow_chunk, ow_block = s[O].split(ow, factor=tile_ow)
+            oc_chunk, oc_block = s[O].split(oc, factor=oc_bn)
+            s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
+            parallel_axis = s[O].fuse(oc_chunk, oh)
+            s[C].compute_at(s[O], parallel_axis)
+            s[O].vectorize(oc_block)
+            s[O].parallel(parallel_axis)
+        else:
+            raise ValueError("Unsupported output ndim: %s" % out_ndim)
 
-    # change shape with the value in config
-    ic_bn, oc_bn = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
-    new_data_shape = (batch, in_channel // ic_bn, height, width, ic_bn)
-    new_kernel_shape = (out_channel // oc_bn, 1, kh, kw, 1, oc_bn)
-    new_data = tvm.placeholder(new_data_shape, data.dtype)
-    new_kernel = tvm.placeholder(new_kernel_shape, kernel.dtype)
-
-    data_layout = "NCHW%dc" % ic_bn
-    out_layout = "NCHW%dc" % oc_bn
-
-    C = _depthwise_conv2d_NCHWc_cpu(cfg, new_data, new_kernel, strides, padding, dilation,
-                                    data_layout, out_layout, dtype)
-    s = schedule_depthwise_conv2d_NCHWc(cfg, [C])
-    return s, [new_data, new_kernel, C]
+    return s
 
 @depthwise_conv2d_infer_layout.register("cpu")
 def _depthwise_conv2d_infer_layout(workload, cfg):
     _, data, kernel, strides, padding, dilation, dtype = workload
-    batch_size, in_channel, in_height, in_width = data[:-1]
-    filter_channel, channel_multiplier, k_height, k_width = kernel[:-1]
+    batch_size, in_channel, in_height, in_width = data[1]
+    filter_channel, channel_multiplier, k_height, k_width = kernel[1]
     out_channel = filter_channel * channel_multiplier
     out_height = (in_height + 2 * padding[0] - k_height) // strides[0] + 1
     out_width = (in_width + 2 * padding[1] - k_width) // strides[1] + 1
index d6bb762..375827b 100644 (file)
 """x86 declaration and schedules."""
 from __future__ import absolute_import as _abs
 import tvm
-from .. import generic
 from ..util import is_empty_shape
 
-@generic.schedule_injective_from_existing.register(["cpu"])
 def schedule_injective_from_existing(sch, out):
     """Schedule for injective op from existing schedule.
 
@@ -53,7 +51,6 @@ def schedule_injective_from_existing(sch, out):
         sch[out].vectorize(li)
     return sch
 
-@generic.schedule_injective.register(["cpu"])
 def schedule_injective(outs):
     """X86 schedule for injective op.
 
@@ -77,7 +74,6 @@ def schedule_injective(outs):
         schedule_injective_from_existing(s, x)
     return s
 
-@generic.schedule_concatenate.register(["cpu"])
 def schedule_concatenate(outs):
     """X86 schedule for concatenate op.
 
index 45cb17e..3d57b6b 100644 (file)
@@ -18,9 +18,7 @@
 """x86 nn operators"""
 from __future__ import absolute_import as _abs
 import tvm
-from .. import generic
 
-@generic.schedule_softmax.register(["cpu"])
 def schedule_softmax(outs):
     """Schedule for softmax
 
index ed7d525..a8251dd 100644 (file)
@@ -17,7 +17,6 @@
 # pylint: disable=invalid-name, unused-variable
 """Schedule for pooling operators"""
 import tvm
-from .. import generic
 from .. import tag
 
 def _parallel_sch(sch, oshape, do_vectorize=False):
@@ -59,7 +58,6 @@ def _parallel_sch(sch, oshape, do_vectorize=False):
     sch.parallel(fused)
 
 
-@generic.schedule_pool.register(["cpu"])
 def schedule_pool(outs, layout):
     """Schedule for pool
 
@@ -117,7 +115,6 @@ def schedule_pool(outs, layout):
     return s
 
 
-@generic.schedule_adaptive_pool.register(["cpu"])
 def schedule_adaptive_pool(outs):
     """Schedule for adaptive pool
 
index f704d49..b9dd4d4 100644 (file)
@@ -18,8 +18,8 @@
 """x86 declaration and schedules."""
 from __future__ import absolute_import as _abs
 import tvm
+from .injective import schedule_injective_from_existing
 from .. import tag
-from .. import generic
 from ..util import get_const_tuple
 
 def _schedule_reduce(sch, op, is_idx_reduce=False):
@@ -58,7 +58,6 @@ def _schedule_reduce(sch, op, is_idx_reduce=False):
             sch[out].parallel(fused)
 
 
-@generic.schedule_reduce.register(["cpu"])
 def schedule_reduce(outs):
     """X86 schedule for reduction op.
 
@@ -95,7 +94,7 @@ def schedule_reduce(outs):
         """Internal traverse function"""
         if tag.is_broadcast(operator.tag):
             if operator not in scheduled_ops:
-                generic.schedule_injective_from_existing(sch, operator)
+                schedule_injective_from_existing(sch, operator)
             for tensor in operator.input_tensors:
                 traverse_after_reduce(tensor.op)
         elif operator.tag == 'comm_reduce':
index 26b84be..203c3dd 100644 (file)
@@ -20,7 +20,6 @@ import math
 import tvm
 
 from tvm import hybrid
-from ..vision.rcnn import roi_align_nchw
 from ..tensor import full
 from ..util import get_const_tuple
 
@@ -185,8 +184,7 @@ def roi_align_nchw_ir(data, rois, w_pc, pos_pc, pooled_size, spatial_scale, samp
     return output
 
 
-@roi_align_nchw.register("cpu")
-def roi_align_nchw_cpu(data, rois, pooled_size, spatial_scale, sample_ratio=-1):
+def roi_align_nchw(data, rois, pooled_size, spatial_scale, sample_ratio=-1):
     """ROI align operator in NCHW layout.
 
     Parameters
index c9e0e38..898d0e5 100644 (file)
 """sparse_dense schedule on x86"""
 import tvm
 
-from .. import generic
 from ..util import traverse_inline, get_const_int
 from .util import get_fp32_len
 
 
-@generic.schedule_sparse_dense.register(["cpu"])
-def _schedule_sparse_dense(outs):
+def schedule_sparse_dense(outs):
+    """Create schedule for sparse dense"""
     s = tvm.create_schedule([x.op for x in outs])
 
     def _callback(op):
index a7b9160..79e223c 100644 (file)
@@ -677,7 +677,7 @@ TVM_REGISTER_GLOBAL("topi.rocm.schedule_softmax")
 
 TVM_REGISTER_GLOBAL("topi.rocm.schedule_lrn")
 .set_body([](TVMArgs args, TVMRetValue *rv) {
-  *rv = topi::rocm::schedule_lrn(args[0], args[1]);
+  *rv = topi::rocm::schedule_lrn(args[0]);
   });
 
 /* CUDA schedules */
@@ -723,7 +723,7 @@ TVM_REGISTER_GLOBAL("topi.cuda.schedule_softmax")
 
 TVM_REGISTER_GLOBAL("topi.cuda.schedule_lrn")
 .set_body([](TVMArgs args, TVMRetValue *rv) {
-  *rv = topi::cuda::schedule_lrn(args[0], args[1]);
+  *rv = topi::cuda::schedule_lrn(args[0]);
   });
 
 /* Utility functions */
index 4e0a45b..e03708c 100644 (file)
 # under the License.
 """Common utility for topi test"""
 
+import tvm
 from tvm import autotvm
 from tvm.autotvm.task.space import FallbackConfigEntity
-
+import topi
 
 def get_all_backend():
     """return all supported target
@@ -31,14 +32,12 @@ def get_all_backend():
     return ['llvm', 'cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'nvptx',
             'llvm -device=arm_cpu', 'opencl -device=mali', 'aocl_sw_emu']
 
-
 class Int8Fallback(autotvm.FallbackContext):
     def _query_inside(self, target, workload):
         key = (target, workload)
         if key in self.memory:
             return self.memory[key]
         cfg = FallbackConfigEntity()
-        cfg.template_key = 'int8'
         self.memory[key] = cfg
         cfg.is_fallback = False
         return cfg
index 022272f..34c389a 100644 (file)
 
 import tvm
 import topi
+import topi.testing
 import numpy as np
-from common import get_all_backend
 from tvm.contrib.pickle_memoize import memoize
 
+from common import get_all_backend
+
 def verify_fifo_buffer(buffer_shape, data_shape, axis, dtype='float32'):
     buffer = tvm.placeholder(buffer_shape, name='buffer', dtype=dtype)
     data = tvm.placeholder(data_shape, name='data', dtype=dtype)
@@ -52,7 +54,7 @@ def verify_fifo_buffer(buffer_shape, data_shape, axis, dtype='float32'):
 
         with tvm.target.create(device):
             out = topi.nn.fifo_buffer(data, buffer, axis=axis)
-            s = topi.generic.schedule_injective([out])
+            s = topi.testing.get_injective_schedule(device)([out])
 
         buffer_tvm = tvm.nd.array(buffer_np, ctx=ctx)
         data_tvm = tvm.nd.array(data_np, ctx=ctx)
@@ -126,29 +128,29 @@ def verify_conv1d_integration():
             return
         print('  Running on target: {}'.format(device))
 
+        conv2d_nchw, schedule_conv2d_nchw = topi.testing.get_conv2d_nchw_implement(device)
+
         with tvm.target.create(device):
             out = topi.nn.fifo_buffer(inc_input, context, axis=buffer_axis)
-            s = topi.generic.schedule_injective([out])
+            s = topi.testing.get_injective_schedule(device)([out])
             update_context = tvm.build(s, [inc_input, context, out], device, name='update_context')
 
-            out = topi.nn.conv2d(context, kernel, strides=stride, padding=padding, dilation=dilate,
-                                 layout='NCHW', out_dtype=dtype)
-            s = topi.generic.schedule_conv2d_nchw([out])
+            out = conv2d_nchw(context, kernel, stride, padding, dilate, dtype)
+            s = schedule_conv2d_nchw([out])
             conv2d_inc = tvm.build(s, [context, kernel, out], device, name='conv2d_inc')
 
             out = topi.nn.fifo_buffer(inc_output, output_window, axis=buffer_axis)
-            s = topi.generic.schedule_injective([out])
+            s = topi.testing.get_injective_schedule(device)([out])
             update_output_window = tvm.build(s, [inc_output, output_window, out], device,
                  name='update_output_window')
 
             out = topi.nn.fifo_buffer(inc_input, input_window, axis=buffer_axis)
-            s = topi.generic.schedule_injective([out])
+            s = topi.testing.get_injective_schedule(device)([out])
             update_input_window = tvm.build(s, [inc_input, input_window, out], device,
                                             name='update_input_window')
 
-            out = topi.nn.conv2d(input_window, kernel, strides=stride, padding=padding,
-                                 dilation=dilate, layout='NCHW', out_dtype=dtype)
-            s = topi.generic.schedule_conv2d_nchw([out])
+            out = conv2d_nchw(input_window, kernel, stride, padding, dilate, dtype)
+            s = schedule_conv2d_nchw([out])
             conv2d = tvm.build(s, [input_window, kernel, out], device, name='conv2d')
 
         input_window_tvm = tvm.nd.array(input_window_np, ctx=ctx)
index d1f50c8..1b38e90 100644 (file)
@@ -24,6 +24,12 @@ from tvm.contrib.pickle_memoize import memoize
 
 from common import get_all_backend
 
+_batch_matmul_implement = {
+    "generic": (topi.nn.batch_matmul, topi.generic.schedule_batch_matmul),
+    "cpu": (topi.x86.batch_matmul, topi.x86.schedule_batch_matmul),
+    "gpu": (topi.nn.batch_matmul, topi.cuda.schedule_batch_matmul),
+}
+
 def verify_batch_matmul(batch, M, N, K):
     x = tvm.placeholder((batch, M, K), name='x')
     y = tvm.placeholder((batch, N, K), name='y')
@@ -46,8 +52,9 @@ def verify_batch_matmul(batch, M, N, K):
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
-            out = topi.nn.batch_matmul(x, y)
-            s = topi.generic.schedule_batch_matmul([out])
+            fcompute, fschedule = topi.testing.dispatch(device, _batch_matmul_implement)
+            out = fcompute(x, y)
+            s = fschedule([out])
         a = tvm.nd.array(a_np, ctx)
         b = tvm.nd.array(b_np, ctx)
         c = tvm.nd.array(np.zeros(get_const_tuple(out.shape), dtype=dtype), ctx)
index eeaeed1..274743d 100644 (file)
@@ -35,9 +35,9 @@ def verify_bitserial_conv2d_nchw(batch, in_size, in_channel, num_filter, kernel,
     with tvm.target.create('llvm'):
         A = tvm.placeholder((batch, in_channel, in_height, in_width), dtype=input_dtype, name='A')
         W = tvm.placeholder((num_filter, in_channel, kernel, kernel), dtype=input_dtype, name='W')
-        B = topi.nn.bitserial_conv2d_nchw(A, W, stride, padding, activation_bits, weight_bits,
-                                          out_dtype=out_dtype, unipolar=unipolar)
-        s = topi.generic.schedule_bitserial_conv2d_nchw([B])
+        B = topi.x86.bitserial_conv2d_nchw(A, W, stride, padding, activation_bits, weight_bits,
+                                           input_dtype, out_dtype, unipolar)
+        s = topi.x86.schedule_bitserial_conv2d_nchw([B])
 
     a_shape = get_const_tuple(A.shape)
     w_shape = get_const_tuple(W.shape)
@@ -73,9 +73,9 @@ def verify_bitserial_conv2d_nhwc(batch, in_size, in_channel, num_filter, kernel,
     with tvm.target.create('llvm'):
         A = tvm.placeholder((batch, in_height, in_width, in_channel), dtype=input_dtype, name='A')
         W = tvm.placeholder((kernel, kernel, in_channel, num_filter), dtype=input_dtype, name='W')
-        B = topi.nn.bitserial_conv2d_nhwc(A, W, stride, padding, activation_bits, weight_bits,
-                                          out_dtype=out_dtype, unipolar=unipolar)
-        s = topi.generic.schedule_bitserial_conv2d_nhwc([B])
+        B = topi.x86.bitserial_conv2d_nhwc(A, W, stride, padding, activation_bits, weight_bits,
+                                           input_dtype, out_dtype, unipolar)
+        s = topi.x86.schedule_bitserial_conv2d_nhwc([B])
 
     a_shape = get_const_tuple(A.shape)
     w_shape = get_const_tuple(W.shape)
index 1b2f40d..1f87785 100644 (file)
@@ -39,9 +39,9 @@ def verify_bitserial_conv2d_nhwc(batch, in_size, in_channel, num_filter, kernel,
     with tvm.target.create(device):
         A = tvm.placeholder((batch, in_height, in_width, in_channel), dtype=input_type, name='A')
         W = tvm.placeholder((kernel, kernel, in_channel, num_filter), dtype=input_type, name='W')
-        B = topi.nn.bitserial_conv2d_nhwc(A, W, stride, padding, activation_bits, weight_bits,
-                                          pack_dtype='uint8', out_dtype='int16', unipolar=unipolar)
-        s = topi.generic.schedule_bitserial_conv2d_nhwc([B])
+        B = topi.arm_cpu.bitserial_conv2d_nhwc(A, W, stride, padding, activation_bits, weight_bits,
+                                               'uint8', out_dtype, unipolar)
+        s = topi.arm_cpu.schedule_bitserial_conv2d_nhwc([B])
 
     func = tvm.build(s, [A, W, B], device)
 
index f1bd023..505ce79 100644 (file)
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 """Test code for bitserial_dense operator"""
+import os
 import numpy as np
 import tvm
 import topi
@@ -22,27 +23,21 @@ import topi.testing
 from topi.util import get_const_tuple
 from tvm.contrib.pickle_memoize import memoize
 
+_bitserial_dense_implement = {
+    "generic": (topi.nn.bitserial_dense, topi.generic.schedule_bitserial_dense),
+    "cpu": (topi.x86.bitserial_dense, topi.x86.schedule_bitserial_dense),
+    "arm_cpu": (topi.arm_cpu.bitserial_dense, topi.arm_cpu.schedule_bitserial_dense),
+}
+
 def generate_quantized_np(shape, bits, out_dtype):
     min_val = 0
     max_val = 1 << bits
     return np.random.randint(min_val, max_val, size=shape).astype(out_dtype)
 
 def verify_bitserial_dense(batch, in_dim, out_dim, activation_bits, weight_bits, unipolar):
-    input_dtype = 'uint32'
     out_dtype = 'int16'
 
-    with tvm.target.create('llvm'):
-        A = tvm.placeholder((batch, in_dim), dtype=input_dtype, name='A')
-        B = tvm.placeholder((out_dim, in_dim), dtype=input_dtype, name='B')
-        C = topi.nn.bitserial_dense(A, B, activation_bits, weight_bits, out_dtype=out_dtype,
-                                    unipolar=unipolar)
-        s = topi.generic.schedule_bitserial_dense([C])
-
-    a_shape = get_const_tuple(A.shape)
-    b_shape = get_const_tuple(B.shape)
-
-    @memoize("topi.tests.test_topi_bitseral_dense")
-    def get_ref_data():
+    def get_ref_data(a_shape, b_shape, input_dtype):
         a_np = generate_quantized_np(get_const_tuple(a_shape), activation_bits, input_dtype)
         b_np = generate_quantized_np(get_const_tuple(b_shape), weight_bits, input_dtype)
         if unipolar:
@@ -53,15 +48,30 @@ def verify_bitserial_dense(batch, in_dim, out_dim, activation_bits, weight_bits,
         else:
             c_np = np.dot(a_np, b_np.T)
         return a_np, b_np, c_np
-    a_np, b_np, c_np = get_ref_data()
 
-    ctx = tvm.cpu(0)
-    a = tvm.nd.array(a_np, ctx)
-    b = tvm.nd.array(b_np, ctx)
-    c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
-    func = tvm.build(s, [A, B, C], "llvm")
-    func(a, b, c)
-    tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
+    for target in ["llvm", "llvm -device=arm_cpu"]:
+        if "arm_cpu" in target and 'arm' not in os.uname()[4]:
+            print ("Skipped running code, not an arm device")
+            continue
+        input_dtype = 'uint8' if "arm_cpu" in target else "uint32"
+        A = tvm.placeholder((batch, in_dim), dtype=input_dtype, name='A')
+        B = tvm.placeholder((out_dim, in_dim), dtype=input_dtype, name='B')
+        fcompute, fschedule = topi.testing.dispatch(target, _bitserial_dense_implement)
+        C = fcompute(A, B, activation_bits, weight_bits,
+                     input_dtype, out_dtype, unipolar)
+        s = fschedule([C])
+
+        a_shape = get_const_tuple(A.shape)
+        b_shape = get_const_tuple(B.shape)
+        a_np, b_np, c_np = get_ref_data(a_shape, b_shape, input_dtype)
+
+        ctx = tvm.cpu(0)
+        a = tvm.nd.array(a_np, ctx)
+        b = tvm.nd.array(b_np, ctx)
+        c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
+        func = tvm.build(s, [A, B, C], target)
+        func(a, b, c)
+        tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
 
 def test_bitserial_dense():
     verify_bitserial_dense(1, 1024, 1000, 1, 1, True)
index 13da6af..ce6a286 100644 (file)
@@ -33,9 +33,9 @@ def verify_binary_dense(batch, in_dim, out_dim):
     bnn_C = topi.nn.binary_dense(bnn_A1, bnn_B1)
     # schedule
     with tvm.target.create('llvm'):
-        s1 = topi.generic.schedule_binarize_pack(bnn_A)
-        s2 = topi.generic.schedule_binarize_pack(bnn_B)
-        s3 = topi.generic.schedule_binary_dense(bnn_C)
+        s1 = topi.x86.schedule_binarize_pack(bnn_A)
+        s2 = topi.x86.schedule_binarize_pack(bnn_B)
+        s3 = topi.x86.schedule_binary_dense(bnn_C)
 
     dtype = A.dtype
     @memoize("topi.tests.test_topi_binary_dense")
index 5a0a940..2bea9b0 100644 (file)
 # specific language governing permissions and limitations
 # under the License.
 """Test code for broadcasting operators."""
-from common import get_all_backend
 import numpy as np
 import tvm
 import topi
+import topi.testing
+from common import get_all_backend
 
 
 def verify_broadcast_to_ele(in_shape, out_shape, fbcast):
@@ -33,7 +34,7 @@ def verify_broadcast_to_ele(in_shape, out_shape, fbcast):
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
-            s = topi.generic.schedule_broadcast(B)
+            s = topi.testing.get_broadcast_schedule(device)(B)
         foo = tvm.build(s, [A, B], device, name="broadcast_to")
         data_npy = np.random.uniform(size=in_shape).astype(A.dtype)
         out_npy = np.broadcast_to(data_npy, out_shape)
@@ -81,7 +82,7 @@ def verify_broadcast_binary_ele(lhs_shape, rhs_shape,
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
-            s = topi.generic.schedule_broadcast(C)
+            s = topi.testing.get_broadcast_schedule(device)(C)
         foo = tvm.build(s, [A, B, C], device, name="broadcast_binary" + "_" + ftopi.__name__)
 
         lhs_npy, lhs_nd = gen_operand(lhs_shape, lhs_min, lhs_max, ctx)
@@ -252,7 +253,7 @@ def test_logical_single_ele():
                 return
             print("Running on target: %s" % device)
             with tvm.target.create(device):
-                s = topi.generic.schedule_broadcast(B)
+                s = topi.testing.get_broadcast_schedule(device)(B)
             foo = tvm.build(s, [A, B], device, name=name)
 
             data_npy = indata.astype(A.dtype)
@@ -293,7 +294,7 @@ def test_bitwise_not():
                 return
             print("Running on target: %s" % device)
             with tvm.target.create(device):
-                s = topi.generic.schedule_broadcast(B)
+                s = topi.testing.get_broadcast_schedule(device)(B)
             foo = tvm.build(s, [A, B], device, name=name)
 
             data_npy = np.random.uniform(size=shape).astype(A.dtype)
@@ -335,7 +336,7 @@ def test_logical_binary_ele():
                 return
             print("Running on target: %s" % device)
             with tvm.target.create(device):
-                s = topi.generic.schedule_broadcast(C)
+                s = topi.testing.get_broadcast_schedule(device)(C)
             foo = tvm.build(s, [A, B, C], device, name=name)
 
             lhs_nd = tvm.nd.array(lhs, ctx)
index 585374f..74034ce 100644 (file)
@@ -18,6 +18,7 @@
 import numpy as np
 import tvm
 import topi
+import topi.testing
 from topi.util import get_const_tuple
 from tvm.contrib.pickle_memoize import memoize
 
@@ -43,7 +44,7 @@ def verify_clip(N, a_min, a_max, dtype):
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
-            s = topi.generic.schedule_injective(B)
+            s = topi.testing.get_injective_schedule(device)(B)
 
         a = tvm.nd.array(a_np, ctx)
         b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx)
index d54742c..6e55a57 100644 (file)
@@ -25,6 +25,18 @@ from topi.util import get_const_tuple
 from common import get_all_backend
 
 
+_conv1d_ncw_implement = {
+    "generic": (topi.nn.conv1d_ncw, topi.generic.schedule_conv1d_ncw),
+    "cpu": (topi.nn.conv1d_ncw, topi.x86.schedule_conv1d_ncw),
+    "gpu": (topi.cuda.conv1d_ncw, topi.cuda.schedule_conv1d_ncw)
+}
+
+_conv1d_nwc_implement = {
+    "generic": (topi.nn.conv1d_nwc, topi.generic.schedule_conv1d_nwc),
+    "cpu": (topi.nn.conv1d_nwc, topi.x86.schedule_conv1d_nwc),
+    "gpu": (topi.cuda.conv1d_nwc, topi.cuda.schedule_conv1d_nwc)
+}
+
 def verify_conv1d(batch,
                   in_channels,
                   in_width,
@@ -66,12 +78,13 @@ def verify_conv1d(batch,
         if not ctx.exist:
             print("Skip because %s is not enabled" % device)
             return
+        if layout == "NCW":
+            fcompute, fschedule = topi.testing.dispatch(device, _conv1d_ncw_implement)
+        else:
+            fcompute, fschedule = topi.testing.dispatch(device, _conv1d_nwc_implement)
         with tvm.target.create(device):
-            B = topi.nn.conv1d(A, W, stride, padding, dilation, layout, 'float32')
-            if layout == 'NCW':
-                s = topi.generic.schedule_conv1d_ncw([B])
-            else:
-                s = topi.generic.schedule_conv1d_nwc([B])
+            B = fcompute(A, W, stride, padding, dilation, 'float32')
+            s = fschedule([B])
 
         a = tvm.nd.array(a_np, ctx)
         w = tvm.nd.array(w_np, ctx)
index 9d6e9db..64af254 100644 (file)
@@ -24,6 +24,11 @@ from tvm.contrib.pickle_memoize import memoize
 from topi.util import get_const_tuple
 from common import get_all_backend
 
+_conv1d_transpose_ncw_implement = {
+    "generic": (topi.nn.conv1d_transpose_ncw, topi.generic.schedule_conv1d_transpose_ncw),
+    "gpu": (topi.cuda.conv1d_transpose_ncw, topi.cuda.schedule_conv1d_transpose_ncw)
+}
+
 def verify_conv1d_transpose_ncw(batch, in_channel, in_size, num_filter, kernel, stride, padding):
     in_width = in_size
     A = tvm.placeholder((batch, in_channel, in_width), name='A')
@@ -49,10 +54,11 @@ def verify_conv1d_transpose_ncw(batch, in_channel, in_size, num_filter, kernel,
             print("Skip because %s is not enabled" % device)
             return
         with tvm.target.create(device):
-            B = topi.nn.conv1d_transpose_ncw(A, W, stride, padding, A.dtype)
+            fcompute, fschedule = topi.testing.dispatch(device, _conv1d_transpose_ncw_implement)
+            B = fcompute(A, W, stride, padding, A.dtype)
             C = topi.nn.relu(B)
-            s1 = topi.generic.schedule_conv1d_transpose_ncw([B])
-            s2 = topi.generic.schedule_conv1d_transpose_ncw([C])
+            s1 = fschedule([B])
+            s2 = fschedule([C])
         a = tvm.nd.array(a_np, ctx)
         w = tvm.nd.array(w_np, ctx)
         b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
index af58590..8a74b4f 100644 (file)
@@ -98,16 +98,16 @@ def verify_conv2d_NCHWc(batch, in_channel, in_size, num_filter, kernel, stride,
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
-            C = topi.nn.conv2d_NCHWc(A, W, (stride, stride), padding,
-                                     (dilation, dilation),
-                                     layout='NCHW%dc'%ic_block,
-                                     out_layout="NCHW%dc"%oc_block,
-                                     out_dtype=dtype)
+            C = topi.x86.conv2d_NCHWc(A, W, (stride, stride), padding,
+                                      (dilation, dilation),
+                                      'NCHW%dc'%ic_block,
+                                      "NCHW%dc"%oc_block,
+                                      dtype)
             if add_bias:
                 C = topi.add(C, bias)
             if add_relu:
                 C = topi.nn.relu(C)
-            s = topi.generic.schedule_conv2d_NCHWc([C])
+            s = topi.x86.schedule_conv2d_NCHWc([C])
 
         a = tvm.nd.array(a_np, ctx)
         w = tvm.nd.array(w_np, ctx)
index 35423a6..086523e 100644 (file)
@@ -24,6 +24,12 @@ from tvm.contrib.pickle_memoize import memoize
 from topi.util import get_const_tuple
 
 
+_conv2d_hwcn_implement = {
+    "generic": (topi.nn.conv2d_hwcn, topi.generic.schedule_conv2d_hwcn),
+    "gpu": (topi.cuda.conv2d_hwcn, topi.cuda.schedule_conv2d_hwcn),
+    "opencl": (topi.cuda.conv2d_hwcn, topi.cuda.schedule_conv2d_hwcn),
+}
+
 def verify_conv2d_hwcn(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1):
     in_height = in_width = in_size
 
@@ -56,12 +62,13 @@ def verify_conv2d_hwcn(batch, in_channel, in_size, num_filter, kernel, stride, p
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
-            t_conv = topi.nn.conv2d(A, W, stride, padding, dilation, layout='HWCN')
+            fcompute, fschedule = topi.testing.dispatch(device, _conv2d_hwcn_implement)
+            t_conv = fcompute(A, W, stride, padding, dilation)
             t_bias = topi.add(t_conv, B)
             t_relu = topi.nn.relu(t_bias)
-            s1 = topi.generic.schedule_conv2d_hwcn([t_conv])
-            s2 = topi.generic.schedule_conv2d_hwcn([t_bias])
-            s3 = topi.generic.schedule_conv2d_hwcn([t_relu])
+            s1 = fschedule([t_conv])
+            s2 = fschedule([t_bias])
+            s3 = fschedule([t_relu])
         a = tvm.nd.array(a_np, ctx)
         w = tvm.nd.array(w_np, ctx)
         b = tvm.nd.array(b_np, ctx)
index 6cb66d0..c36bfa3 100644 (file)
@@ -82,13 +82,13 @@ def verify_conv2d_NCHWc_int8(batch, in_channel, in_size, num_filter, kernel, str
 
         print("Running on target: %s" % device)
         with tvm.target.create(device):
-            C = topi.nn.conv2d(A, W, (stride, stride), padding, (dilation, dilation),
-                               layout='NCHW', out_dtype=dtype)
+            C = topi.cuda.conv2d_NCHWc_int8(A, W, (stride, stride), padding, (dilation, dilation),
+                                            'NCHW', dtype)
             if add_bias:
                 C = topi.add(C, bias)
             if add_relu:
                 C = topi.nn.relu(C)
-            s = topi.generic.schedule_conv2d_nchw([C])
+            s = topi.cuda.schedule_conv2d_NCHWc_int8([C])
 
         a = tvm.nd.array(a_np, ctx)
         w = tvm.nd.array(w_np, ctx)
index 85d733c..a0258ec 100644 (file)
@@ -66,18 +66,27 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p
             print("Skip because %s is not enabled" % device)
             return
         print("Running on target: %s" % device)
+
+        if "cudnn" in device:
+            fcompute, fschedule = topi.cuda.conv2d_cudnn, topi.cuda.schedule_conv2d_cudnn
+        else:
+            fcompute, fschedule = topi.testing.get_conv2d_nchw_implement(device)
+
         with tvm.target.create(device):
-            C = topi.nn.conv2d(A, W, (stride, stride), padding,
-                               (dilation, dilation), layout='NCHW', out_dtype=dtype)
+            if "cudnn" in device:
+                C = fcompute(A, W, (stride, stride), padding, (dilation, dilation), "NCHW", dtype)
+            else:
+                C = fcompute(A, W, (stride, stride), padding, (dilation, dilation), dtype)
             if add_bias:
                 C = topi.add(C, bias)
             if add_relu:
                 C = topi.nn.relu(C)
-            s = topi.generic.schedule_conv2d_nchw([C])
+            s = fschedule([C])
 
         a = tvm.nd.array(a_np, ctx)
         w = tvm.nd.array(w_np, ctx)
         b = tvm.nd.array(b_np, ctx)
+
         c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
         if add_bias:
             func = tvm.build(s, [A, W, bias, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))
index 342f319..2a5915e 100644 (file)
@@ -24,6 +24,16 @@ from tvm.contrib.pickle_memoize import memoize
 from topi.util import get_const_tuple
 
 
+
+_conv2d_nhwc_implement = {
+    "generic": (topi.nn.conv2d_nhwc, topi.generic.schedule_conv2d_nhwc),
+    "cpu": (topi.nn.conv2d_nhwc, topi.x86.schedule_conv2d_nhwc),
+    "arm_cpu": (topi.arm_cpu.conv2d_nhwc_spatial_pack,
+                topi.arm_cpu.schedule_conv2d_nhwc_spatial_pack),
+    "hls": (topi.nn.conv2d_nhwc, topi.hls.schedule_conv2d_nhwc)
+}
+
+
 def verify_conv2d_nhwc(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1):
     in_height = in_width = in_size
 
@@ -60,7 +70,8 @@ def verify_conv2d_nhwc(batch, in_channel, in_size, num_filter, kernel, stride, p
         func(a, w, b)
         tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
 
-    for device in ['llvm', 'cuda']:
+    # TODO(@alexgl-github): add cuda back after fix conv2d_nhwc for cuda
+    for device in ['llvm']:
         check_device(device)
 
 
index 763150a..8267aad 100644 (file)
@@ -15,8 +15,9 @@
 # specific language governing permissions and limitations
 # under the License.
 """Example code to do convolution."""
-import os
+import pytest
 import numpy as np
+
 import tvm
 from tvm import autotvm
 from tvm.autotvm.task.space import FallbackConfigEntity
@@ -56,7 +57,7 @@ def verify_conv2d_1x1_nhwc_pack_int8(batch, in_channel, in_size, num_filter, ker
 
         with tvm.target.create(device):
             B = topi.nn.conv2d(A, W, stride, padding, dilation, layout='NHWC', out_dtype="int32")
-            s = topi.generic.schedule_conv2d_nhwc_pack([B])
+            s = topi.x86.schedule_conv2d_nhwc_pack_int8([B])
         a = tvm.nd.array(a_np, ctx)
         w = tvm.nd.array(w_np, ctx)
         b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
@@ -69,22 +70,12 @@ def verify_conv2d_1x1_nhwc_pack_int8(batch, in_channel, in_size, num_filter, ker
         check_device(device)
 
 
-class DefaultFallback(autotvm.FallbackContext):
-    def _query_inside(self, target, workload):
-        key = (target, workload)
-        if key in self.memory:
-            return self.memory[key]
-        cfg = FallbackConfigEntity()
-        cfg.template_key = 'direct'
-        self.memory[key] = cfg
-        return cfg
-
-
+# TODO(@llyfacebook): Please fix https://github.com/apache/incubator-tvm/issues/4122 to enable this test.
+@pytest.mark.skip
 def test_conv2d_nhwc():
-    autotvm.DispatchContext.current.silent = True
-    with DefaultFallback():
-        verify_conv2d_1x1_nhwc_pack_int8(1, 256, 32, 256, 1, 1, 0)
+    verify_conv2d_1x1_nhwc_pack_int8(1, 256, 32, 256, 1, 1, 0)
 
 
 if __name__ == "__main__":
-    test_conv2d_nhwc()
+    # test_conv2d_nhwc()
+    pass
index fb836d4..e8aabc6 100644 (file)
@@ -24,6 +24,14 @@ from topi.util import get_const_tuple
 
 from common import get_all_backend
 
+_conv2d_transpose_nchw_implement = {
+    "generic": (topi.nn.conv2d_transpose_nchw, topi.generic.schedule_conv2d_transpose_nchw),
+    "cpu": (topi.x86.conv2d_transpose_nchw, topi.x86.schedule_conv2d_transpose_nchw),
+    "arm_cpu": (topi.arm_cpu.conv2d_transpose_nchw, topi.arm_cpu.schedule_conv2d_transpose_nchw),
+    "gpu": (topi.cuda.conv2d_transpose_nchw, topi.cuda.schedule_conv2d_transpose_nchw),
+    "hls": (topi.nn.conv2d_transpose_nchw, topi.hls.schedule_conv2d_transpose_nchw),
+}
+
 def verify_conv2d_transpose_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding):
     in_height, in_width = in_size
     kernel_height, kernel_width = kernel
@@ -54,13 +62,14 @@ def verify_conv2d_transpose_nchw(batch, in_channel, in_size, num_filter, kernel,
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
-            B = topi.nn.conv2d_transpose_nchw(A, W,
-                                              [stride_height, stride_width],
-                                              [pad_top, pad_left, pad_bottom, pad_right],
-                                              A.dtype)
+            fcompute, fschedule = topi.testing.dispatch(device, _conv2d_transpose_nchw_implement)
+            B = fcompute(A, W,
+                         [stride_height, stride_width],
+                         [pad_top, pad_left, pad_bottom, pad_right],
+                         A.dtype)
             C = topi.nn.relu(B)
-            s1 = topi.generic.schedule_conv2d_transpose_nchw([B])
-            s2 = topi.generic.schedule_conv2d_transpose_nchw([C])
+            s1 = fschedule([B])
+            s2 = fschedule([C])
         a = tvm.nd.array(a_np, ctx)
         w = tvm.nd.array(w_np, ctx)
         b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
index 350e621..2d12336 100644 (file)
@@ -27,6 +27,13 @@ from topi.nn.util import get_pad_tuple
 from topi.util import get_const_tuple
 
 
+_conv2d_nchw_winograd_implement = {
+    "arm_cpu": (topi.arm_cpu.conv2d_nchw_winograd, topi.arm_cpu.schedule_conv2d_nchw_winograd),
+    "cuda": (topi.cuda.conv2d_nchw_winograd, topi.cuda.schedule_conv2d_nchw_winograd),
+    "mali": (topi.mali.conv2d_nchw_winograd, topi.mali.schedule_conv2d_nchw_winograd),
+}
+
+
 def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1, add_bias=False, add_relu=False,
         devices=['cuda', 'llvm -device=arm_cpu', 'opencl -device=mali']):
     pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
@@ -67,12 +74,13 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
-            C = topi.nn.conv2d(A, W, stride, padding, dilation, layout='NCHW', out_dtype=dtype)
+            fcompute, fschedule = topi.testing.dispatch(device, _conv2d_nchw_winograd_implement)
+            C = fcompute(A, W, stride, padding, dilation, dtype)
             if add_bias:
                 C = topi.add(C, bias)
             if add_relu:
                 C = topi.nn.relu(C)
-            s = topi.generic.schedule_conv2d_nchw([C])
+            s = fschedule([C])
 
         a = tvm.nd.array(a_np, ctx)
         w = tvm.nd.array(w_np, ctx)
@@ -93,61 +101,45 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p
         check_device(device)
 
 
-class WinogradFallback(autotvm.FallbackContext):
-    def _query_inside(self, target, workload):
-        key = (target, workload)
-        if key in self.memory:
-            return self.memory[key]
-        cfg = FallbackConfigEntity()
-        cfg.template_key = 'winograd'
-        self.memory[key] = cfg
-        cfg.is_fallback = False
-        return cfg
-
-
 def test_conv2d_nchw():
-    autotvm.DispatchContext.current.silent = True
-
-    with WinogradFallback():
-
-        # inception v3 workloads
-        verify_conv2d_nchw(1, 128, 17, 192, 7, 1, 3, devices=['cuda'])
-        verify_conv2d_nchw(1, 128, 17, 128, 7, 1, 3, devices=['cuda'])
-        verify_conv2d_nchw(1, 160, 17, 160, 7, 1, 3, devices=['cuda'])
-
-        # resnet 18 workloads
-        verify_conv2d_nchw(1, 64, 56, 64, 3, 1, 1)
-        verify_conv2d_nchw(1, 128, 28, 128, 3, 1, 1)
-        verify_conv2d_nchw(1, 256, 14, 256, 3, 1, 1)
-        verify_conv2d_nchw(1, 512, 7, 512, 3, 1, 1)
-        verify_conv2d_nchw(1, 48,  35, 64, 5, 1, 2, devices=['cuda'])
-
-        # batch size = 2
-        verify_conv2d_nchw(2, 64, 56, 64, 3, 1, 1)
-
-        # relu, bias
-        verify_conv2d_nchw(2, 64, 56, 64, 3, 1, 1, add_bias=True)
-        verify_conv2d_nchw(2, 64, 56, 64, 3, 1, 1, add_relu=True)
-        verify_conv2d_nchw(2, 64, 56, 64, 3, 1, 1, add_relu=True, add_bias=True)
-
-        # werid workloads
-        verify_conv2d_nchw(1, 1, 1, 1, 3, 1, 1)
-        verify_conv2d_nchw(3, 3, 3, 3, 3, 1, 1)
-        verify_conv2d_nchw(2, 13, 71, 59, 3, 1, 1)
-
-        # Asymmetric padding
-        verify_conv2d_nchw(1,  48, 56,  48, 3, 1, (1, 1, 1, 1))
-        verify_conv2d_nchw(1,  64, 28,  64, 3, 1, (1, 1, 1, 1))
-        verify_conv2d_nchw(1, 128, 14, 128, 3, 1, (1, 1))
-        verify_conv2d_nchw(1, 512,  7, 512, 3, 1, "SAME")
-        verify_conv2d_nchw(2, 13,  71,  59, 3, 1, (1, 1, 1, 1))
-        verify_conv2d_nchw(2,  48, 56,  48, 3, 1, (1, 1, 1, 1), add_bias=True)
-        verify_conv2d_nchw(2,  48, 56,  48, 3, 1, (1, 1), add_relu=True)
-        verify_conv2d_nchw(2,  48, 56,  48, 3, 1, "SAME", add_relu=True, add_bias=True)
-        verify_conv2d_nchw(1,  64, 17, 192, 7, 1, (3, 1), devices=['cuda'])
-        verify_conv2d_nchw(1,  64, 17,  64, 7, 1, (3, 3, 2, 2), devices=['cuda'])
-        verify_conv2d_nchw(1, 160, 17, 160, 7, 1, "SAME", devices=['cuda'])
-        verify_conv2d_nchw(1,  48, 35,  48, 5, 1, "VALID", devices=['cuda'])
+    # inception v3 workloads
+    verify_conv2d_nchw(1, 128, 17, 192, 7, 1, 3, devices=['cuda'])
+    verify_conv2d_nchw(1, 128, 17, 128, 7, 1, 3, devices=['cuda'])
+    verify_conv2d_nchw(1, 160, 17, 160, 7, 1, 3, devices=['cuda'])
+
+    # resnet 18 workloads
+    verify_conv2d_nchw(1, 64, 56, 64, 3, 1, 1)
+    verify_conv2d_nchw(1, 128, 28, 128, 3, 1, 1)
+    verify_conv2d_nchw(1, 256, 14, 256, 3, 1, 1)
+    verify_conv2d_nchw(1, 512, 7, 512, 3, 1, 1)
+    verify_conv2d_nchw(1, 48,  35, 64, 5, 1, 2, devices=['cuda'])
+
+    # batch size = 2
+    verify_conv2d_nchw(2, 64, 56, 64, 3, 1, 1)
+
+    # relu, bias
+    verify_conv2d_nchw(2, 64, 56, 64, 3, 1, 1, add_bias=True)
+    verify_conv2d_nchw(2, 64, 56, 64, 3, 1, 1, add_relu=True)
+    verify_conv2d_nchw(2, 64, 56, 64, 3, 1, 1, add_relu=True, add_bias=True)
+
+    # weird workloads
+    verify_conv2d_nchw(1, 1, 1, 1, 3, 1, 1)
+    verify_conv2d_nchw(3, 3, 3, 3, 3, 1, 1)
+    verify_conv2d_nchw(2, 13, 71, 59, 3, 1, 1)
+
+    # Asymmetric padding
+    verify_conv2d_nchw(1,  48, 56,  48, 3, 1, (1, 1, 1, 1))
+    verify_conv2d_nchw(1,  64, 28,  64, 3, 1, (1, 1, 1, 1))
+    verify_conv2d_nchw(1, 128, 14, 128, 3, 1, (1, 1))
+    verify_conv2d_nchw(1, 512,  7, 512, 3, 1, "SAME")
+    verify_conv2d_nchw(2, 13,  71,  59, 3, 1, (1, 1, 1, 1))
+    verify_conv2d_nchw(2,  48, 56,  48, 3, 1, (1, 1, 1, 1), add_bias=True)
+    verify_conv2d_nchw(2,  48, 56,  48, 3, 1, (1, 1), add_relu=True)
+    verify_conv2d_nchw(2,  48, 56,  48, 3, 1, "SAME", add_relu=True, add_bias=True)
+    verify_conv2d_nchw(1,  64, 17, 192, 7, 1, (3, 1), devices=['cuda'])
+    verify_conv2d_nchw(1,  64, 17,  64, 7, 1, (3, 3, 2, 2), devices=['cuda'])
+    verify_conv2d_nchw(1, 160, 17, 160, 7, 1, "SAME", devices=['cuda'])
+    verify_conv2d_nchw(1,  48, 35,  48, 5, 1, "VALID", devices=['cuda'])
 
 
 if __name__ == "__main__":
index 92b1068..6c60c27 100644 (file)
@@ -27,6 +27,11 @@ from topi.util import get_const_tuple
 
 from common import get_all_backend
 
+_conv3d_ncdhw_implement = {
+    "generic": (topi.nn.conv3d_ncdhw, topi.generic.schedule_conv3d_ncdhw),
+    "gpu": (topi.cuda.conv3d_ncdhw, topi.cuda.schedule_conv3d_ncdhw),
+}
+
 def verify_conv3d_ncdhw(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1, add_bias=False, add_relu=False):
     pad_front, pad_top, pad_left, pad_back, pad_bottom, pad_right = get_pad_tuple3d(padding, (kernel, kernel, kernel))
     padding_sum = pad_front + pad_back + pad_top + pad_left + pad_bottom + pad_right
@@ -65,14 +70,15 @@ def verify_conv3d_ncdhw(batch, in_channel, in_size, num_filter, kernel, stride,
             print("Skip because %s is not enabled" % device)
             return
         print("Running on target: %s" % device)
+        fcompute, fschedule = topi.testing.dispatch(device, _conv3d_ncdhw_implement)
         with tvm.target.create(device):
-            C = topi.nn.conv3d(A, W, (stride, stride, stride), padding,
-                               (dilation, dilation, dilation), layout='NCDHW', out_dtype=dtype)
+            C = fcompute(A, W, (stride, stride, stride), padding,
+                         (dilation, dilation, dilation), dtype)
             if add_bias:
                 C = topi.add(C, bias)
             if add_relu:
                 C = topi.nn.relu(C)
-            s = topi.generic.schedule_conv3d_ncdhw([C])
+            s = fschedule([C])
 
         a = tvm.nd.array(a_np, ctx)
         w = tvm.nd.array(w_np, ctx)
index c613f68..7e2f02c 100644 (file)
@@ -23,6 +23,13 @@ import topi.testing
 from tvm.contrib.pickle_memoize import memoize
 from topi.util import get_const_tuple
 
+from common import get_all_backend
+
+_conv3d_ndhwc_implement = {
+    "generic": (topi.nn.conv3d_ndhwc, topi.generic.schedule_conv3d_ndhwc),
+    "cpu": (topi.x86.conv3d_ndhwc, topi.x86.schedule_conv3d_ndhwc),
+    "gpu": (topi.cuda.conv3d_ndhwc, topi.cuda.schedule_conv3d_ndhwc),
+}
 
 def verify_conv3d_ndhwc(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1):
     if isinstance(in_size, tuple):
@@ -51,13 +58,15 @@ def verify_conv3d_ndhwc(batch, in_channel, in_size, num_filter, kernel, stride,
     a_np, w_np, b_np = get_ref_data()
 
     def check_device(device):
-        if not tvm.runtime.enabled(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
             print("Skip because %s is not enabled" % device)
             return
         print("Running on target: %s" % device)
+        fcompute, fschedule = topi.testing.dispatch(device, _conv3d_ndhwc_implement)
         with tvm.target.create(device):
-            B = topi.nn.conv3d(A, W, stride, padding, dilation, layout="NDHWC")
-            s = topi.generic.schedule_conv3d_ndhwc([B])
+            B = fcompute(A, W, stride, padding, dilation, dtype)
+            s = fschedule([B])
         ctx = tvm.context(device, 0)
         a = tvm.nd.array(a_np, ctx)
         w = tvm.nd.array(w_np, ctx)
@@ -66,7 +75,7 @@ def verify_conv3d_ndhwc(batch, in_channel, in_size, num_filter, kernel, stride,
         func(a, w, b)
         tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
 
-    for device in ['llvm']:
+    for device in get_all_backend():
         check_device(device)
 
 
index 45222b6..1b1a032 100644 (file)
@@ -25,6 +25,11 @@ from topi.util import get_const_tuple
 from common import get_all_backend
 
 
+_deformable_conv2d_implement = {
+    "generic": (topi.nn.deformable_conv2d_nchw, topi.generic.schedule_deformable_conv2d_nchw),
+    "cuda": (topi.cuda.deformable_conv2d_nchw, topi.cuda.schedule_deformable_conv2d_nchw),
+}
+
 def verify_deformable_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1, deformable_groups=1, groups=1):
     print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d, %d, %d)" % (batch, in_channel, in_size,
             num_filter, kernel, stride, padding, dilation, deformable_groups, groups))
@@ -60,10 +65,11 @@ def verify_deformable_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel
             print("Skip because %s is not enabled" % device)
             return
         print("Running on target: %s" % device)
+        fcompute, fschedule = topi.testing.dispatch(device, _deformable_conv2d_implement)
         with tvm.target.create(device):
-            C = topi.nn.deformable_conv2d_nchw(A, Offset, W, stride, padding, dilation,
-                    deformable_groups, groups, out_dtype=dtype)
-            s = topi.generic.schedule_deformable_conv2d_nchw([C])
+            C = fcompute(A, Offset, W, stride, padding, dilation,
+                         deformable_groups, groups, dtype)
+            s = fschedule([C])
 
             a = tvm.nd.array(a_np, ctx)
             offset = tvm.nd.array(offset_np, ctx)
index 3b74771..d729e43 100644 (file)
@@ -24,6 +24,19 @@ from tvm.contrib.pickle_memoize import memoize
 
 from common import get_all_backend, Int8Fallback
 
+_dense_implement = {
+    "generic": [(topi.nn.dense, topi.generic.schedule_dense)],
+    "cpu": [(topi.x86.dense_nopack, topi.x86.schedule_dense_nopack),
+            (topi.x86.dense_pack, topi.x86.schedule_dense_pack)],
+    "gpu": [(topi.cuda.dense_small_batch, topi.cuda.schedule_dense_small_batch),
+            (topi.cuda.dense_large_batch, topi.cuda.schedule_dense_large_batch)],
+    "mali": [(topi.mali.dense, topi.mali.schedule_dense)],
+    "bifrost": [(topi.bifrost.dense, topi.bifrost.schedule_dense)],
+    "opengl": [(topi.nn.dense, topi.opengl.schedule_dense)],
+    "rocm": [(topi.rocm.dense, topi.rocm.schedule_dense)],
+    "hls": [(topi.nn.dense, topi.hls.schedule_dense)],
+}
+
 def verify_dense(batch, in_dim, out_dim, use_bias=True):
     A = tvm.placeholder((batch, in_dim), name='A')
     B = tvm.placeholder((out_dim, in_dim), name='B')
@@ -50,17 +63,18 @@ def verify_dense(batch, in_dim, out_dim, use_bias=True):
             print("Skip because %s is not enabled" % device)
             return
         print("Running on target: %s" % device)
-        with tvm.target.create(device):
-            D = topi.nn.dense(A, B, C if use_bias else None)
-            D = topi.nn.relu(D)
-            s = topi.generic.schedule_dense([D])
-        a = tvm.nd.array(a_np, ctx)
-        b = tvm.nd.array(b_np, ctx)
-        c = tvm.nd.array(c_np, ctx)
-        d = tvm.nd.array(np.zeros(get_const_tuple(D.shape), dtype=dtype), ctx)
-        f = tvm.build(s, [A, B, C, D], device, name="dense")
-        f(a, b, c, d)
-        tvm.testing.assert_allclose(d.asnumpy(), d_np, rtol=1e-5)
+        for fcompute, fschedule in topi.testing.dispatch(device, _dense_implement):
+            with tvm.target.create(device):
+                D = fcompute(A, B, C if use_bias else None)
+                D = topi.nn.relu(D)
+                s = fschedule([D])
+            a = tvm.nd.array(a_np, ctx)
+            b = tvm.nd.array(b_np, ctx)
+            c = tvm.nd.array(c_np, ctx)
+            d = tvm.nd.array(np.zeros(get_const_tuple(D.shape), dtype=dtype), ctx)
+            f = tvm.build(s, [A, B, C, D], device, name="dense")
+            f(a, b, c, d)
+            tvm.testing.assert_allclose(d.asnumpy(), d_np, rtol=1e-5)
 
     for device in get_all_backend():
         check_device(device)
@@ -99,9 +113,9 @@ def verify_dense_int8(batch, in_dim, out_dim, use_bias=True):
 
         print("Running on target: %s" % device)
         with tvm.target.create(device):
-            D = topi.nn.dense(A, B, C if use_bias else None, out_dtype=out_dtype)
+            D = topi.cuda.dense_int8(A, B, C if use_bias else None, out_dtype)
             D = topi.nn.relu(D)
-            s = topi.generic.schedule_dense([D])
+            s = topi.cuda.schedule_dense_int8([D])
         a = tvm.nd.array(a_np, ctx)
         b = tvm.nd.array(b_np, ctx)
         c = tvm.nd.array(c_np, ctx)
index 4e895cb..693bfb6 100644 (file)
@@ -56,7 +56,7 @@ def verify_depth_to_space(block_size, batch, in_channel, in_height, in_width, la
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
-            s = topi.generic.schedule_injective(B)
+            s = topi.testing.get_injective_schedule(device)(B)
         a = tvm.nd.array(a_np, ctx)
         b = tvm.nd.array(np.zeros(out_shape, dtype=dtype), ctx)
         f = tvm.build(s, [A, B], device)
index d34d56e..7efe5a2 100644 (file)
@@ -25,6 +25,24 @@ from tvm.contrib.pickle_memoize import memoize
 
 from common import get_all_backend
 
+_depthwise_conv2d_nchw_implement = {
+    "generic": [(topi.nn.depthwise_conv2d_nchw, topi.generic.schedule_depthwise_conv2d_nchw)],
+    "arm_cpu": [(topi.arm_cpu.depthwise_conv2d_nchw, topi.arm_cpu.schedule_depthwise_conv2d_nchw),
+                (topi.arm_cpu.depthwise_conv2d_nchw_spatial_pack,
+                 topi.arm_cpu.schedule_depthwise_conv2d_nchw_spatial_pack)],
+    "gpu": [(topi.cuda.depthwise_conv2d_nchw, topi.cuda.schedule_depthwise_conv2d_nchw)],
+    "mali": [(topi.mali.depthwise_conv2d_nchw, topi.mali.schedule_depthwise_conv2d_nchw)],
+    "bifrost": [(topi.nn.depthwise_conv2d_nchw, topi.bifrost.schedule_depthwise_conv2d_nchw)],
+    "intel_graphics": [(topi.intel_graphics.depthwise_conv2d_nchw,
+                        topi.intel_graphics.schedule_depthwise_conv2d_nchw)],
+}
+
+_depthwise_conv2d_nhwc_implement = {
+    "generic": (topi.nn.depthwise_conv2d_nhwc, topi.generic.schedule_depthwise_conv2d_nhwc),
+    "gpu": (topi.nn.depthwise_conv2d_nhwc, topi.cuda.schedule_depthwise_conv2d_nhwc),
+}
+
+
 def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_multiplier, filter_height, stride, padding, dilation=1):
     in_width = in_height
     filter_channel = in_channel
@@ -53,68 +71,75 @@ def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_mu
             print("Skip because %s is not enabled" % device)
             return
         print("Running on target: %s" % device)
-        with tvm.target.create(device):
-            # declare
-            DepthwiseConv2d = topi.nn.depthwise_conv2d_nchw(Input, Filter,
-                (stride_h, stride_w), padding_args, dilation, dtype)
-            ScaleShift = topi.nn.scale_shift_nchw(DepthwiseConv2d, Scale, Shift)
-            Relu = topi.nn.relu(ScaleShift)
-            # schedule
-            s1 = topi.generic.schedule_depthwise_conv2d_nchw(DepthwiseConv2d)
-            s2 = topi.generic.schedule_depthwise_conv2d_nchw(ScaleShift)
-            s3 = topi.generic.schedule_depthwise_conv2d_nchw(Relu)
-        # build the kernels
-        f1 = tvm.build(s1, [Input, Filter, DepthwiseConv2d], device)
-        f2 = tvm.build(s2, [Input, Filter, Scale, Shift, ScaleShift], device)
-        f3 = tvm.build(s3, [Input, Filter, Scale, Shift, Relu], device)
-
-        # Prepare pod type for test data closure
-        input_shape = get_const_tuple(Input.shape)
-        filter_shape = get_const_tuple(Filter.shape)
-        scale_shape = get_const_tuple(Scale.shape)
-        shift_shape = get_const_tuple(Shift.shape)
-        scale_shift_shape = get_const_tuple(ScaleShift.shape)
-
-        # Use memoize, pickle the test data for next time use.
-        @memoize("topi.tests.test_topi_depthwise_conv2d.nchw")
-        def get_ref_data():
-            input_np = np.random.uniform(size=input_shape).astype(dtype)
-            filter_np = np.random.uniform(size=filter_shape).astype(dtype)
-            dilated_filter_np = topi.testing.dilate_python(filter_np, (1, 1, dilation, dilation))
-            scale_np = np.random.uniform(size=scale_shape).astype(dtype)
-            shift_np = np.random.uniform(size=shift_shape).astype(dtype)
-            # correctness with scipy
-            depthwise_conv2d_scipy = topi.testing.depthwise_conv2d_python_nchw(
-                input_np, dilated_filter_np, stride, padding)
-            scale_shift_scipy = np.zeros(shape=scale_shift_shape)
-            for c in range(in_channel * channel_multiplier):
-                scale_shift_scipy[:,c,:,:] = depthwise_conv2d_scipy[:,c,:,:] * scale_np[c] + shift_np[c]
-                relu_scipy = np.maximum(scale_shift_scipy, 0)
-            return (input_np, filter_np, scale_np, shift_np,
-                    depthwise_conv2d_scipy, scale_shift_scipy, relu_scipy)
-        # Get the test data
-        (input_np, filter_np, scale_np, shift_np,
-         depthwise_conv2d_scipy, scale_shift_scipy, relu_scipy) = get_ref_data()
 
-        input_tvm = tvm.nd.array(input_np, ctx)
-        filter_tvm = tvm.nd.array(filter_np, ctx)
-        scale_tvm = tvm.nd.array(scale_np, ctx)
-        shift_tvm = tvm.nd.array(shift_np, ctx)
-        depthwise_conv2d_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(DepthwiseConv2d.shape), dtype=DepthwiseConv2d.dtype), ctx)
-        scale_shift_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(ScaleShift.shape), dtype=ScaleShift.dtype), ctx)
-        relu_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(Relu.shape), dtype=Relu.dtype), ctx)
-        # launch kernel 1 (depthwise_conv2d)
-        timer_1 = f1.time_evaluator(f1.entry_name, ctx, number=1)
-        tcost_1 = timer_1(input_tvm, filter_tvm, depthwise_conv2d_tvm).mean
-        # launch kernel 2 (depthwise_conv2d + scale_shift)
-        timer_2 = f2.time_evaluator(f2.entry_name, ctx, number=1)
-        tcost_2 = timer_2(input_tvm, filter_tvm, scale_tvm, shift_tvm, scale_shift_tvm).mean
-        # launch kernel 3 (depthwise_conv2d + scale_shift + relu)
-        timer_3 = f3.time_evaluator(f3.entry_name, ctx, number=1)
-        tcost_3 = timer_3(input_tvm, filter_tvm, scale_tvm, shift_tvm, relu_tvm).mean
-        tvm.testing.assert_allclose(depthwise_conv2d_tvm.asnumpy(), depthwise_conv2d_scipy, rtol=1e-5)
-        tvm.testing.assert_allclose(scale_shift_tvm.asnumpy(), scale_shift_scipy, rtol=1e-5)
-        tvm.testing.assert_allclose(relu_tvm.asnumpy(), relu_scipy, rtol=1e-5)
+        impl_list = topi.testing.dispatch(device, _depthwise_conv2d_nchw_implement)[:]
+        if device == "llvm" and channel_multiplier == 1 and dilation == 1:
+            impl_list.append((topi.x86.depthwise_conv2d_nchw, topi.x86.schedule_depthwise_conv2d_nchw))
+
+        for fcompute, fschedule in impl_list:
+            with tvm.target.create(device):
+                # declare
+                DepthwiseConv2d = fcompute(Input, Filter, (stride_h, stride_w),
+                                           padding_args, dilation, dtype)
+                ScaleShift = topi.nn.scale_shift_nchw(DepthwiseConv2d, Scale, Shift)
+                Relu = topi.nn.relu(ScaleShift)
+                # schedule
+                s1 = fschedule(DepthwiseConv2d)
+                s2 = fschedule(ScaleShift)
+                s3 = fschedule(Relu)
+            # build the kernels
+            f1 = tvm.build(s1, [Input, Filter, DepthwiseConv2d], device)
+            f2 = tvm.build(s2, [Input, Filter, Scale, Shift, ScaleShift], device)
+            f3 = tvm.build(s3, [Input, Filter, Scale, Shift, Relu], device)
+
+            # Prepare pod type for test data closure
+            input_shape = get_const_tuple(Input.shape)
+            filter_shape = get_const_tuple(Filter.shape)
+            scale_shape = get_const_tuple(Scale.shape)
+            shift_shape = get_const_tuple(Shift.shape)
+            scale_shift_shape = get_const_tuple(ScaleShift.shape)
+
+            # Use memoize, pickle the test data for next time use.
+            @memoize("topi.tests.test_topi_depthwise_conv2d.nchw")
+            def get_ref_data():
+                input_np = np.random.uniform(size=input_shape).astype(dtype)
+                filter_np = np.random.uniform(size=filter_shape).astype(dtype)
+                dilated_filter_np = topi.testing.dilate_python(filter_np, (1, 1, dilation, dilation))
+                scale_np = np.random.uniform(size=scale_shape).astype(dtype)
+                shift_np = np.random.uniform(size=shift_shape).astype(dtype)
+                # correctness with scipy
+                depthwise_conv2d_scipy = topi.testing.depthwise_conv2d_python_nchw(
+                    input_np, dilated_filter_np, stride, padding)
+                scale_shift_scipy = np.zeros(shape=scale_shift_shape)
+                for c in range(in_channel * channel_multiplier):
+                    scale_shift_scipy[:,c,:,:] = depthwise_conv2d_scipy[:,c,:,:] * scale_np[c] + shift_np[c]
+                    relu_scipy = np.maximum(scale_shift_scipy, 0)
+                return (input_np, filter_np, scale_np, shift_np,
+                        depthwise_conv2d_scipy, scale_shift_scipy, relu_scipy)
+
+            # Get the test data
+            (input_np, filter_np, scale_np, shift_np,
+             depthwise_conv2d_scipy, scale_shift_scipy, relu_scipy) = get_ref_data()
+
+            input_tvm = tvm.nd.array(input_np, ctx)
+            filter_tvm = tvm.nd.array(filter_np, ctx)
+            scale_tvm = tvm.nd.array(scale_np, ctx)
+            shift_tvm = tvm.nd.array(shift_np, ctx)
+            depthwise_conv2d_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(DepthwiseConv2d.shape), dtype=DepthwiseConv2d.dtype), ctx)
+            scale_shift_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(ScaleShift.shape), dtype=ScaleShift.dtype), ctx)
+            relu_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(Relu.shape), dtype=Relu.dtype), ctx)
+            # launch kernel 1 (depthwise_conv2d)
+            timer_1 = f1.time_evaluator(f1.entry_name, ctx, number=1)
+            tcost_1 = timer_1(input_tvm, filter_tvm, depthwise_conv2d_tvm).mean
+            # launch kernel 2 (depthwise_conv2d + scale_shift)
+            timer_2 = f2.time_evaluator(f2.entry_name, ctx, number=1)
+            tcost_2 = timer_2(input_tvm, filter_tvm, scale_tvm, shift_tvm, scale_shift_tvm).mean
+            # launch kernel 3 (depthwise_conv2d + scale_shift + relu)
+            timer_3 = f3.time_evaluator(f3.entry_name, ctx, number=1)
+            tcost_3 = timer_3(input_tvm, filter_tvm, scale_tvm, shift_tvm, relu_tvm).mean
+            tvm.testing.assert_allclose(depthwise_conv2d_tvm.asnumpy(), depthwise_conv2d_scipy, rtol=1e-5)
+            tvm.testing.assert_allclose(scale_shift_tvm.asnumpy(), scale_shift_scipy, rtol=1e-5)
+            tvm.testing.assert_allclose(relu_tvm.asnumpy(), relu_scipy, rtol=1e-5)
 
     for device in get_all_backend():
         with autotvm.tophub.context(device):  # load tophub pre-tuned parameters
@@ -150,16 +175,17 @@ def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_mu
             return
         print("Running on target: %s" % device)
 
+        fcompute, fschedule = topi.testing.dispatch(device, _depthwise_conv2d_nhwc_implement)
         with tvm.target.create(device):
             # declare
-            DepthwiseConv2d = topi.nn.depthwise_conv2d_nhwc(Input, Filter,
+            DepthwiseConv2d = fcompute(Input, Filter,
                 (stride_h, stride_w), padding_args, dilation, dtype)
             ScaleShift = topi.nn.scale_shift_nhwc(DepthwiseConv2d, Scale, Shift)
             Relu = topi.nn.relu(ScaleShift)
             # schedule
-            s1 = topi.generic.schedule_depthwise_conv2d_nhwc(DepthwiseConv2d)
-            s2 = topi.generic.schedule_depthwise_conv2d_nhwc(ScaleShift)
-            s3 = topi.generic.schedule_depthwise_conv2d_nhwc(Relu)
+            s1 = fschedule(DepthwiseConv2d)
+            s2 = fschedule(ScaleShift)
+            s3 = fschedule(Relu)
         # build the kernels
         f1 = tvm.build(s1, [Input, Filter, DepthwiseConv2d], device)
         f2 = tvm.build(s2, [Input, Filter, Scale, Shift, ScaleShift], device)
@@ -242,6 +268,7 @@ def depthwise_conv2d_with_workload_NCHWc(batch, in_channel, in_height, channel_m
     stride_h = stride_w = stride
 
     assert dilation == 1, "depthwise_conv2d_NCHWc currently does not support dilation."
+    assert channel_multiplier == 1, "depthwise_conv2d_NCHWc currently does not support channel multiplier > 1."
     pad_h, pad_w, _, _ = get_pad_tuple(padding, (filter_height, filter_width))
     padding_args = (pad_h, pad_w)
 
@@ -276,17 +303,17 @@ def depthwise_conv2d_with_workload_NCHWc(batch, in_channel, in_height, channel_m
         print("Running on target: %s" % device)
         with tvm.target.create(device):
             # declare
-            DepthwiseConv2d = topi.nn.depthwise_conv2d_NCHWc(Input, Filter,
-                                                             (stride_h, stride_w),
-                                                             padding_args,
-                                                             (dilation, dilation),
-                                                             in_layout,
-                                                             out_layout, dtype)
+            DepthwiseConv2d = topi.x86.depthwise_conv2d_NCHWc(Input, Filter,
+                                                              (stride_h, stride_w),
+                                                              padding_args,
+                                                              (dilation, dilation),
+                                                              in_layout,
+                                                              out_layout, dtype)
             # TODO: add scale_shift implement for NCHWc and add test here
             Relu = topi.nn.relu(DepthwiseConv2d)
             # schedule
-            s1 = topi.generic.schedule_depthwise_conv2d_nchw(DepthwiseConv2d)
-            s2 = topi.generic.schedule_depthwise_conv2d_nchw(Relu)
+            s1 = topi.x86.schedule_depthwise_conv2d_NCHWc(DepthwiseConv2d)
+            s2 = topi.x86.schedule_depthwise_conv2d_NCHWc(Relu)
         # build the kernels
         f1 = tvm.build(s1, [Input, Filter, DepthwiseConv2d], device)
         f2 = tvm.build(s2, [Input, Filter, Relu], device)
@@ -319,7 +346,6 @@ def depthwise_conv2d_with_workload_NCHWc(batch, in_channel, in_height, channel_m
                                                      dtype=DepthwiseConv2d.dtype), ctx)
         relu_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(Relu.shape), dtype=Relu.dtype), ctx)
         # launch kernel 1 (depthwise_conv2d)
-        print(filter_tvm.shape)
         f1(input_tvm, filter_tvm, depthwise_conv2d_tvm)
         # launch kernel 2 (depthwise_conv2d + relu)
         f2(input_tvm, filter_tvm, relu_tvm)
@@ -363,9 +389,7 @@ def test_depthwise_conv2d():
 
     # NCHW[x]c
     depthwise_conv2d_with_workload_NCHWc(1, 728, 32, 1, 3, 1, "SAME")
-    depthwise_conv2d_with_workload_NCHWc(4, 256, 64, 2, 5, 2, "SAME")
     depthwise_conv2d_with_workload_NCHWc(1, 728, 32, 1, 3, 1, "VALID")
-    depthwise_conv2d_with_workload_NCHWc(4, 256, 64, 2, 5, 2, "VALID")
 
 
 if __name__ == "__main__":
index 0e17678..3904db7 100644 (file)
@@ -28,6 +28,12 @@ from topi.util import get_const_tuple
 from common import get_all_backend, Int8Fallback
 
 
+_group_conv2d_nchw_implement = {
+    "generic": (topi.nn.group_conv2d_nchw, topi.generic.schedule_group_conv2d_nchw),
+    "gpu": (topi.cuda.group_conv2d_nchw, topi.cuda.schedule_group_conv2d_nchw),
+}
+
+
 def verify_group_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation, groups, add_bias=False, add_relu=False):
     print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d, %d)" %
         (batch, in_channel, in_size, num_filter,
@@ -70,12 +76,13 @@ def verify_group_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, str
 
         print("Running on target: %s" % device)
         with tvm.target.create(device):
-            C = topi.nn.group_conv2d_nchw(A, W, stride, padding, dilation, groups, out_dtype=dtype)
+            fcompute, fschedule = topi.testing.dispatch(device, _group_conv2d_nchw_implement)
+            C = fcompute(A, W, stride, padding, dilation, groups, dtype)
             if add_bias:
                 C = topi.add(C, bias)
             if add_relu:
                 C = topi.nn.relu(C)
-            s = topi.generic.schedule_group_conv2d_nchw([C])
+            s = fschedule([C])
 
         a = tvm.nd.array(a_np, ctx)
         w = tvm.nd.array(w_np, ctx)
@@ -149,12 +156,12 @@ def verify_group_conv2d_NCHWc_int8(batch, in_channel, in_size, num_filter, kerne
 
         print("Running on target: %s" % device)
         with tvm.target.create(device):
-            C = topi.nn.group_conv2d_nchw(A, W, stride, padding, dilation, groups, out_dtype=dtype)
+            C = topi.cuda.group_conv2d_NCHWc_int8(A, W, stride, padding, dilation, groups, dtype)
             if add_bias:
                 C = topi.add(C, bias)
             if add_relu:
                 C = topi.nn.relu(C)
-            s = topi.generic.schedule_group_conv2d_nchw([C])
+            s = topi.cuda.schedule_group_conv2d_NCHWc_int8([C])
 
         a = tvm.nd.array(a_np, ctx)
         w = tvm.nd.array(w_np, ctx)
index 3717534..08f136e 100644 (file)
@@ -81,12 +81,12 @@ def verify_group_conv2d_NCHWc_int8(batch, in_channel, groups, in_size, num_filte
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
-            C = topi.nn.conv2d_NCHWc(A, W, (stride, stride), (padding, padding),
-                                     (dilation, dilation),
-                                     layout='NCHW%dc'%ic_block,
-                                     out_layout="NCHW%dc"%oc_block,
-                                     out_dtype=dtype)
-            s = topi.generic.schedule_conv2d_NCHWc([C])
+            C = topi.x86.conv2d_NCHWc(A, W, (stride, stride), (padding, padding),
+                                      (dilation, dilation),
+                                      'NCHW%dc'%ic_block,
+                                      "NCHW%dc"%oc_block,
+                                      dtype)
+            s = topi.x86.schedule_conv2d_NCHWc([C])
 
         a = tvm.nd.array(a_np, ctx)
         w = tvm.nd.array(w_np, ctx)
index 21935cb..4297638 100644 (file)
@@ -52,7 +52,7 @@ def verify_resize(batch, in_channel, in_height, in_width, out_height, out_width,
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
-            s = topi.generic.schedule_injective(B)
+            s = topi.testing.get_injective_schedule(device)(B)
         a = tvm.nd.array(a_np, ctx)
         b = tvm.nd.array(np.zeros(out_shape, dtype=dtype), ctx)
         f = tvm.build(s, [A, B], device)
@@ -116,7 +116,7 @@ def verify_resize3d(batch, in_channel, in_depth, in_height, in_width, out_depth,
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
-            s = topi.generic.schedule_injective(B)
+            s = topi.testing.get_injective_schedule(device)(B)
         a = tvm.nd.array(a_np, ctx)
         b = tvm.nd.array(np.zeros(out_shape, dtype=dtype), ctx)
         f = tvm.build(s, [A, B], device)
@@ -176,7 +176,7 @@ def test_crop_and_resize():
                 return
             print("Running on target: %s" % device)
             with tvm.target.create(device):
-                s = topi.generic.schedule_injective(out)
+                s = topi.testing.get_injective_schedule(device)(out)
             tvm_images = tvm.nd.array(np_images, ctx)
             tvm_boxes = tvm.nd.array(np_boxes, ctx)
             tvm_indices = tvm.nd.array(np_box_indices, ctx)
index 53139cd..4cb3c75 100644 (file)
@@ -21,6 +21,16 @@ import topi
 from topi.util import get_const_tuple
 import topi.testing
 
+_lrn_schedule = {
+    "generic": topi.generic.schedule_lrn,
+    "gpu": topi.cuda.schedule_lrn,
+    "opencl": topi.cuda.schedule_lrn,
+    "metal": topi.cuda.schedule_lrn,
+    "rocm": topi.cuda.schedule_lrn,
+    "vulkan": topi.cuda.schedule_lrn,
+    "nvptx": topi.cuda.schedule_lrn,
+}
+
 def verify_lrn(shape, size, axis, bias, alpha, beta):
     A = tvm.placeholder(shape, name='A')
     B = topi.nn.lrn(A, size, axis, alpha, beta, bias)
@@ -35,10 +45,8 @@ def verify_lrn(shape, size, axis, bias, alpha, beta):
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
-            if device == 'llvm':
-                s = topi.generic.schedule_lrn([B])
-            else:
-                s = topi.cuda.schedule_lrn([B])
+            s_func = topi.testing.dispatch(device, _lrn_schedule)
+            s = s_func([B])
         ctx = tvm.context(device, 0)
         a = tvm.nd.array(a_np, ctx)
         b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx)
index 5bb95ba..debc3ef 100644 (file)
@@ -62,23 +62,15 @@ def test_ewise():
                 return
             print("Running on target: %s" % device)
             with tvm.target.create(device):
-                s = topi.generic.schedule_injective(B)
+                s = topi.testing.get_injective_schedule(device)(B)
             foo = tvm.build(s, [A, B], device, name=name)
             a = tvm.nd.array(a_np, ctx)
             b = tvm.nd.array(np.zeros_like(b_np), ctx)
             foo(a, b)
             tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5, atol=1e-5)
 
-        check_device('llvm')
-        check_device('cuda')
-        check_device('opencl')
-        check_device('metal')
-        check_device('rocm')
-        check_device('vulkan')
-        check_device('nvptx')
-        check_device('llvm -device=arm-cpu')
-        check_device('opencl -device=mali')
-        check_device('aocl_sw_emu')
+        for target in get_all_backend():
+            check_device(target)
 
     def test_isnan(
         low,
@@ -110,23 +102,15 @@ def test_ewise():
                 return
             print("Running on target: %s" % device)
             with tvm.target.create(device):
-                s = topi.generic.schedule_injective(B)
+                s = topi.testing.get_injective_schedule(device)(B)
             foo = tvm.build(s, [A, B], device, name="isnan")
             a = tvm.nd.array(a_np, ctx)
             b = tvm.nd.array(np.zeros_like(b_np), ctx)
             foo(a, b)
             tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5, atol=1e-5)
 
-        check_device('llvm')
-        check_device('cuda')
-        check_device('opencl')
-        check_device('metal')
-        check_device('rocm')
-        check_device('vulkan')
-        check_device('nvptx')
-        check_device('llvm -device=arm-cpu')
-        check_device('opencl -device=mali')
-        check_device('aocl_sw_emu')
+        for target in get_all_backend():
+            check_device(target)
 
     test_apply(topi.floor, "floor", np.floor, -100, 100)
     test_apply(topi.ceil, "ceil", np.ceil, -100, 100)
@@ -168,7 +152,7 @@ def test_cast():
                 continue
             print("Running on target: %s" % device)
             with tvm.target.create(device):
-                s = topi.generic.schedule_injective(B)
+                s = topi.testing.get_injective_schedule(device)(B)
             foo = tvm.build(s, [A, B], device)
             a = tvm.nd.array(a_np, ctx)
             b = tvm.nd.empty(shape=shape, dtype=to_dtype, ctx=ctx)
index cb81814..084a2c7 100644 (file)
@@ -23,6 +23,25 @@ import topi.testing
 from topi.util import get_const_tuple
 from common import get_all_backend
 
+_pool_schedule = {
+    "generic": topi.generic.schedule_pool,
+    "cpu": topi.x86.schedule_pool,
+    "gpu": topi.cuda.schedule_pool,
+    "hls": topi.hls.schedule_pool,
+}
+
+_adaptive_pool_schedule = {
+    "generic": topi.generic.schedule_adaptive_pool,
+    "cpu": topi.x86.schedule_adaptive_pool,
+    "gpu": topi.cuda.schedule_adaptive_pool,
+    "hls": topi.hls.schedule_adaptive_pool,
+}
+
+_pool_grad_schedule = {
+    "generic": topi.generic.schedule_pool_grad,
+    "gpu": topi.cuda.schedule_pool_grad,
+}
+
 def verify_pool(n, ic, ih, kh, sh, padding, pool_type, ceil_mode, count_include_pad=True):
     iw = ih
     kw = kh
@@ -74,7 +93,8 @@ def verify_pool(n, ic, ih, kh, sh, padding, pool_type, ceil_mode, count_include_
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
-            s = topi.generic.schedule_pool(B, layout)
+            s_func = topi.testing.dispatch(device, _pool_schedule)
+            s = s_func(B, layout)
 
         a = tvm.nd.array(a_np, ctx)
         b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx)
@@ -129,7 +149,8 @@ def verify_pool_grad(n, ic, ih, kh, sh, padding, pool_type, ceil_mode, count_inc
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
-            s = topi.generic.schedule_pool_grad(PoolGrad)
+            s_func = topi.testing.dispatch(device, _pool_grad_schedule)
+            s = s_func(PoolGrad)
 
         a = tvm.nd.array(a_np, ctx)
         out_grad = tvm.nd.array(out_grad_np, ctx)
@@ -201,7 +222,8 @@ def verify_global_pool(n, c, h, w, pool_type, layout='NCHW'):
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
-            s = topi.generic.schedule_adaptive_pool(B)
+            s_func = topi.testing.dispatch(device, _adaptive_pool_schedule)
+            s = s_func(B)
         a = tvm.nd.array(a_np, ctx)
         b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
         f = tvm.build(s, [A, B], device)
@@ -255,7 +277,8 @@ def verify_adaptive_pool(dshape, out_size, pool_type, layout="NCHW", dtype="floa
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
-            s = topi.generic.schedule_adaptive_pool(out)
+            s_func = topi.testing.dispatch(device, _adaptive_pool_schedule)
+            s = s_func(out)
         a = tvm.nd.array(np_data, ctx)
         b = tvm.nd.array(np.zeros(get_const_tuple(oshape), dtype=out.dtype), ctx)
         f = tvm.build(s, [data, out], device)
@@ -298,7 +321,8 @@ def verify_pool3d(n, ic, ih, kh, sh, padding, pool_type,
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
-            s = topi.generic.schedule_pool(B, layout)
+            s_func = topi.testing.dispatch(device, _pool_schedule)
+            s = s_func(B, layout)
 
         a = tvm.nd.array(input_np, ctx)
         b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx)
@@ -350,7 +374,8 @@ def verify_pool1d(n, ic, iw, kw, sw, padding, pool_type,
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
-            s = topi.generic.schedule_pool(B, layout)
+            s_func = topi.testing.dispatch(device, _pool_schedule)
+            s = s_func(B, layout)
 
         a = tvm.nd.array(input_np, ctx)
         b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx)
index d266cfc..751025b 100644 (file)
@@ -19,6 +19,7 @@ import os
 import numpy as np
 import tvm
 import topi
+import topi.testing
 
 from common import get_all_backend
 
@@ -74,7 +75,7 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum", dtype="float32")
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
-            s = topi.generic.schedule_reduce(B)
+            s = topi.testing.get_reduce_schedule(device)(B)
 
         foo = tvm.build(s, [A, B], device, name=type)
         # Test
index 8868d4e..8ef3549 100644 (file)
@@ -19,8 +19,10 @@ import os
 import numpy as np
 import tvm
 import topi
+import topi.testing
 from topi.util import get_const_tuple
 from tvm.contrib.nvcc import have_fp16
+
 from common import get_all_backend
 
 def verify_relu(m, n, dtype="float32"):
@@ -40,7 +42,7 @@ def verify_relu(m, n, dtype="float32"):
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
-            s = topi.generic.schedule_elemwise(B)
+            s = topi.testing.get_elemwise_schedule(device)(B)
 
         a = tvm.nd.array(a_np, ctx)
         b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
index 4edb0a1..c4cd2b5 100644 (file)
@@ -21,6 +21,11 @@ from topi.util import get_const_tuple
 import tvm
 import topi.testing
 
+_reorg_schedule = {
+    "generic": topi.generic.schedule_reorg,
+    "gpu": topi.cuda.schedule_reorg,
+}
+
 def verify_reorg(batch, in_size, in_channel, stride):
     '''Verify reorg operator by comparing outputs from tvm and numpy implementation'''
     in_height = in_width = in_size
@@ -46,10 +51,8 @@ def verify_reorg(batch, in_size, in_channel, stride):
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
-            if device == 'llvm':
-                s = topi.generic.schedule_reorg([B])
-            else:
-                s = topi.cuda.schedule_reorg([B])
+            s_func = topi.testing.dispatch(device, _reorg_schedule)
+            s = s_func([B])
         a = tvm.nd.array(a_np, ctx)
         b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
         func = tvm.build(s, [A, B], device)
index 4836eef..5396b6b 100644 (file)
@@ -25,6 +25,14 @@ from topi.util import get_const_tuple
 
 from common import get_all_backend
 
+_softmax_schedule = {
+    "generic": topi.generic.schedule_softmax,
+    "cpu": topi.x86.schedule_softmax,
+    "gpu": topi.cuda.schedule_softmax,
+    "hls": topi.hls.schedule_softmax,
+    "opengl": topi.opengl.schedule_softmax,
+}
+
 def check_device(A, B, a_np, b_np, device, name):
     ctx = tvm.context(device, 0)
     if not ctx.exist:
@@ -32,11 +40,12 @@ def check_device(A, B, a_np, b_np, device, name):
         return
     print("Running on target: %s" % device)
     with tvm.target.create(device):
-        s = topi.generic.schedule_softmax(B)
+        s_func = topi.testing.dispatch(device, _softmax_schedule)
+        s = s_func(B)
 
     a = tvm.nd.array(a_np, ctx)
     b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
-    f = tvm.build(s, [A, B], device, name="softmax")
+    f = tvm.build(s, [A, B], device, name=name)
     f(a, b)
     tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
 
@@ -50,7 +59,7 @@ def verify_softmax(m, n, dtype="float32"):
     a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
     b_np = topi.testing.softmax_python(a_np)
 
-    for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'nvptx']:
+    for device in get_all_backend():
         check_device(A, B, a_np, b_np, device, "softmax")
 
 def verify_softmax_4d(shape, dtype="float32"):
@@ -62,7 +71,7 @@ def verify_softmax_4d(shape, dtype="float32"):
     b_np = topi.testing.softmax_python(a_np.transpose(0, 2, 3, 1).reshape(h*w, c))
     b_np = b_np.reshape(1, h, w, c).transpose(0, 3, 1, 2)
 
-    for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'nvptx']:
+    for device in get_all_backend():
         check_device(A, B, a_np, b_np, device, "softmax")
 
 def test_softmax():
index 0ad4e98..74e55ec 100644 (file)
@@ -21,6 +21,15 @@ import tvm
 import topi
 import topi.testing
 
+_argsort_implement = {
+    "generic": (topi.argsort, topi.generic.schedule_argsort),
+    "gpu": (topi.cuda.argsort, topi.cuda.schedule_argsort),
+}
+
+_topk_implement = {
+    "generic": (topi.topk, topi.generic.schedule_topk),
+    "gpu": (topi.cuda.topk, topi.cuda.schedule_topk),
+}
 
 def verify_argsort(axis, is_ascend):
     dshape = (20, 100)
@@ -48,8 +57,9 @@ def verify_argsort(axis, is_ascend):
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
-            out = topi.argsort(data, axis=axis, is_ascend=is_ascend)
-            s = topi.generic.schedule_argsort(out)
+            fcompute, fschedule = topi.testing.dispatch(device, _argsort_implement)
+            out = fcompute(data, axis=axis, is_ascend=is_ascend)
+            s = fschedule(out)
 
         tvm_data = tvm.nd.array(np_data, ctx)
         tvm_out = tvm.nd.array(np.zeros(dshape, dtype=data_dtype), ctx)
@@ -91,9 +101,10 @@ def verify_topk(k, axis, ret_type, is_ascend, dtype):
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
-            outs = topi.topk(data, k, axis, ret_type, is_ascend, dtype)
+            fcompute, fschedule = topi.testing.dispatch(device, _topk_implement)
+            outs = fcompute(data, k, axis, ret_type, is_ascend, dtype)
             outs = outs if isinstance(outs, list) else [outs]
-            s = topi.generic.schedule_topk(outs)
+            s = fschedule(outs)
         tvm_data = tvm.nd.array(np_data, ctx)
         tvm_res = []
         for t in outs:
index b25cad1..99a798e 100644 (file)
@@ -56,7 +56,7 @@ def verify_space_to_depth(block_size, batch, in_channel, in_height, in_width, la
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
-            s = topi.generic.schedule_injective(B)
+            s = topi.testing.get_injective_schedule(device)(B)
         a = tvm.nd.array(a_np, ctx)
         b = tvm.nd.array(np.zeros(out_shape, dtype=dtype), ctx)
         f = tvm.build(s, [A, B], device)
index 8e7073f..0509842 100644 (file)
@@ -18,6 +18,7 @@
 import numpy as np
 import tvm
 import topi
+import topi.testing
 from tvm.contrib.pickle_memoize import memoize
 from tvm.contrib.nvcc import have_fp16
 
@@ -98,7 +99,7 @@ def verify_vectorization(n, m, dtype):
             A = tvm.placeholder((n, m), name='A', dtype=dtype)
             B = tvm.compute((n, m), lambda i, j:
                              A[i, j] + tvm.const(1, A.dtype), name='B')
-            S = topi.generic.schedule_elemwise(B)
+            S = topi.testing.get_elemwise_schedule(device)(B)
 
             fun = tvm.build(S, [A, B], device)
             np_A = tvm.nd.empty((n, m), A.dtype, ctx).copyfrom(
index fd04fc4..880e86d 100644 (file)
@@ -33,7 +33,7 @@ def verify_expand_dims(in_shape, out_shape, axis, num_newaxis):
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
-            s = topi.generic.schedule_broadcast(B)
+            s = topi.testing.get_broadcast_schedule(device)(B)
         foo = tvm.build(s, [A, B], device, name="expand_dims")
         data_npy = np.random.uniform(size=in_shape).astype(A.dtype)
         out_npy = data_npy.reshape(out_shape)
@@ -59,7 +59,7 @@ def verify_reinterpret(in_shape, in_dtype, out_dtype, generator):
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
-            s = topi.generic.schedule_elemwise(B)
+            s = topi.testing.get_elemwise_schedule(device)(B)
         foo = tvm.build(s, [A, B], device, name="reinterpret")
         data_npy = generator(in_shape).astype(in_dtype)
         out_npy = data_npy.view(B.dtype)
@@ -82,7 +82,7 @@ def verify_transpose(in_shape, axes):
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
-            s = topi.generic.schedule_injective(B)
+            s = topi.testing.get_injective_schedule(device)(B)
         foo = tvm.build(s, [A, B], device, name="transpose")
         data_npy = np.arange(np.prod(in_shape)).reshape(in_shape).astype(A.dtype)
         out_npy = data_npy.transpose(axes)
@@ -105,7 +105,7 @@ def verify_reshape(src_shape, dst_shape):
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
-            s = topi.generic.schedule_injective(B)
+            s = topi.testing.get_injective_schedule(device)(B)
         foo = tvm.build(s, [A, B], device, name="reshape")
         data_npy = np.random.normal(size=src_shape).astype(A.dtype)
         out_npy = np.reshape(data_npy, newshape=dst_shape)
@@ -128,7 +128,7 @@ def verify_squeeze(src_shape, axis):
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
-            s = topi.generic.schedule_injective(B)
+            s = topi.testing.get_injective_schedule(device)(B)
 
         foo = tvm.build(s, [A, B], device, name="squeeze")
         data_npy = np.random.normal(size=src_shape).astype(A.dtype)
@@ -143,6 +143,19 @@ def verify_squeeze(src_shape, axis):
         check_device(device)
 
 def verify_concatenate(shapes, axis):
+
+    def get_concat_schedule(target):
+        schedule_map = {
+            "cpu": topi.x86.schedule_concatenate,
+            "arm_cpu": topi.arm_cpu.schedule_concatenate,
+        }
+        if isinstance(target, str):
+            target = tvm.target.create(target)
+        for key in target.keys:
+            if key in schedule_map:
+                return schedule_map[key]
+        return topi.testing.get_injective_schedule(target)
+
     tensor_l = []
     for i, shape in enumerate(shapes):
         tensor_l.append(tvm.placeholder(shape, name="A" + str(i)))
@@ -154,7 +167,7 @@ def verify_concatenate(shapes, axis):
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
-            s = topi.generic.schedule_concatenate(out_tensor)
+            s = get_concat_schedule(device)(out_tensor)
 
         foo = tvm.build(s, tensor_l + [out_tensor], device, name="concatenate")
         data_npys = [np.random.normal(size=shape).astype(tensor_l[0].dtype) for shape in shapes]
@@ -179,7 +192,7 @@ def verify_stack(shapes, axis):
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
-            s = topi.generic.schedule_broadcast(out_tensor)
+            s = topi.testing.get_broadcast_schedule(device)(out_tensor)
 
         foo = tvm.build(s, tensor_l + [out_tensor], device, name="stack")
         data_npys = [np.random.normal(size=shape).astype(tensor_l[0].dtype) for shape in shapes]
@@ -203,7 +216,7 @@ def verify_split(src_shape, indices_or_sections, axis):
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
-            s = topi.generic.schedule_injective(tensor_l)
+            s = topi.testing.get_injective_schedule(device)(tensor_l)
 
         foo = tvm.build(s, [A] + list(tensor_l), device, name="split")
         data_npy = np.random.normal(size=src_shape).astype(A.dtype)
@@ -262,7 +275,7 @@ def verify_flip(in_shape, axis):
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
-            s = topi.generic.schedule_injective(B)
+            s = topi.testing.get_injective_schedule(device)(B)
 
         foo = tvm.build(s, [A, B], device, name="reverse")
         x_np = np.random.uniform(size=in_shape).astype(A.dtype)
@@ -293,7 +306,7 @@ def verify_take(src_shape, indices_src, axis=None, mode="clip"):
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
-            s = topi.generic.schedule_injective(out_tensor)
+            s = topi.testing.get_injective_schedule(device)(out_tensor)
 
         foo = tvm.build(s, [A] + [indices] + [out_tensor] , device, name="take")
         shape_size = 1
@@ -328,7 +341,7 @@ def verify_strided_slice(in_shape, begin, end, strides=None):
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
-            s = topi.generic.schedule_injective(B)
+            s = topi.testing.get_injective_schedule(device)(B)
 
         foo = tvm.build(s, [A, B], device, name="stride_slice")
         x_np = np.random.uniform(size=in_shape).astype(A.dtype)
@@ -360,7 +373,7 @@ def verify_strided_set(in_shape, v_shape, begin, end, strides=None):
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
-            s = topi.generic.schedule_injective(B)
+            s = topi.testing.get_injective_schedule(device)(B)
 
         if strides is not None:
             foo = tvm.build(s, [A, V, b, e, st, B], device, name="stride_set")
@@ -402,7 +415,7 @@ def verify_gather_nd(src_shape, indices_src, indices_dtype):
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
-            s = topi.generic.schedule_injective(out_tensor)
+            s = topi.testing.get_injective_schedule(device)(out_tensor)
 
         func = tvm.build(s, [A, indices, out_tensor] , device, name="take")
         shape_size = 1
@@ -441,7 +454,7 @@ def verify_arange(start, stop, step):
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
-            s = topi.generic.schedule_injective(A)
+            s = topi.testing.get_injective_schedule(device)(A)
         f = tvm.build(s, [A], device, name="arange")
         a_nd = tvm.nd.empty(a_np.shape, dtype='float32', ctx=ctx)
         f(a_nd)
@@ -460,7 +473,7 @@ def verify_repeat(in_shape, repeats, axis):
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
-            s = topi.generic.schedule_broadcast(B)
+            s = topi.testing.get_broadcast_schedule(device)(B)
         foo = tvm.build(s, [A, B], device, name="repeat")
         data_npy = np.random.uniform(size=in_shape).astype(A.dtype)
         out_npy = np.repeat(data_npy, repeats, axis)
@@ -482,7 +495,7 @@ def verify_tile(in_shape, reps):
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
-            s = topi.generic.schedule_broadcast(B)
+            s = topi.testing.get_broadcast_schedule(device)(B)
         foo = tvm.build(s, [A, B], device, name="tile")
         data_npy = np.random.uniform(size=in_shape).astype(A.dtype)
         out_npy = np.tile(data_npy, reps)
@@ -507,7 +520,7 @@ def verify_where(in_shape):
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
-            s = topi.generic.schedule_broadcast(C)
+            s = topi.testing.get_broadcast_schedule(device)(C)
         f = tvm.build(s, [Cond, A, B, C], device, name="where")
         cond_npy = np.random.uniform(low=-1, high=1, size=in_shape).astype(dtype)
         x_npy = np.random.uniform(size=in_shape).astype(dtype)
@@ -535,7 +548,7 @@ def verify_one_hot(indices_shape, depth, on_value, off_value, axis, dtype):
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
-            s = topi.generic.schedule_injective(one_hot_result)
+            s = topi.testing.get_injective_schedule(device)(one_hot_result)
         fn = tvm.build(s, [indices, one_hot_result], device, name="one_hot")
         indices_npy = np.random.randint(0, depth, size=indices_shape).astype(indices.dtype)
         out_npy = topi.testing.one_hot(indices_npy, on_value, off_value, depth, axis, dtype)
@@ -618,7 +631,7 @@ def test_squeeze():
         ctx = tvm.context(device, 0)
         if ctx.exist:
             with tvm.target.create(device):
-                s = topi.generic.schedule_injective(C)
+                s = topi.testing.get_injective_schedule(device)(C)
                 func = tvm.build(s, [A, C])
             a = tvm.nd.array(np.array((1, 2)).astype('float32'), ctx=ctx)
             c = tvm.nd.empty((1,), dtype='float32', ctx=ctx)
@@ -741,7 +754,7 @@ def test_layout_transform():
         tvm_output = tvm.nd.empty(output.shape, ctx=ctx, dtype=B.dtype)
         print("Running on target: %s" % device)
         with tvm.target.create(device):
-            s = topi.generic.schedule_injective(B)
+            s = topi.testing.get_injective_schedule(device)(B)
         f = tvm.build(s, [A, B], device, name="layout_transform")
         f(tvm_input, tvm_output)
         tvm.testing.assert_allclose(tvm_output.asnumpy(), output)
@@ -768,7 +781,7 @@ def test_shape():
         tvm_output = tvm.nd.empty(output.shape, ctx=ctx, dtype=dtype)
         print("Running on target: %s" % device)
         with tvm.target.create(device):
-            s = topi.generic.schedule_injective(B)
+            s = topi.testing.get_injective_schedule(device)(B)
         f = tvm.build(s, [A, B], device, name="shape")
         f(tvm_input, tvm_output)
         tvm.testing.assert_allclose(tvm_output.asnumpy(), output)
@@ -800,7 +813,7 @@ def test_sequence_mask():
                     tvm_C = tvm.nd.empty(in_shape, ctx=ctx, dtype="float32")
                     print("Running on target: %s" % device)
                     with tvm.target.create(device):
-                        s = topi.generic.schedule_injective(C)
+                        s = topi.testing.get_injective_schedule(device)(C)
                     f = tvm.build(s, [A, B, C], device, name="SequenceMask")
                     f(tvm_A, tvm_B, tvm_C)
                     tvm.testing.assert_allclose(tvm_C.asnumpy(), C_gt_data)
@@ -825,7 +838,7 @@ def test_ndarray_size():
         tvm_output = tvm.nd.empty((1,), ctx=ctx, dtype=B.dtype)
         print("Running on target: %s" % device)
         with tvm.target.create(device):
-            s = topi.generic.schedule_injective(B)
+            s = topi.testing.get_injective_schedule(device)(B)
         f = tvm.build(s, [A, B], device, name="ndarray_size")
         f(tvm_input, tvm_output)
         tvm.testing.assert_allclose(tvm_output.asnumpy(), output)
@@ -843,9 +856,10 @@ def test_where_fusion():
                 print("Skip because %s is not enabled" % device)
                 return
             print("Running on target: %s" % device)
+            conv2d_compute, conv2d_schedule = topi.testing.get_conv2d_nchw_implement(device)
             data = tvm.placeholder((2, 1, 2, 4), 'int8', 'data')
             w = tvm.placeholder((3, 1, 2, 2), 'int8', 'w')
-            conv1 = topi.nn.conv2d(data, w, 1, 0, 1, out_dtype='int32')
+            conv1 = conv2d_compute(data, w, 1, 0, 1, 'int32')
             zeros = topi.full((2, 3, 1, 3), 'int32', tvm.const(0, dtype='int32'))
             gt = topi.greater_equal(conv1, zeros)
             one = topi.full((2, 3, 1, 3), 'int32', tvm.const(1, dtype='int32'))
@@ -853,7 +867,7 @@ def test_where_fusion():
             where = topi.where(gt, one, two)
             add = topi.add(conv1, where)
             outs = [add]
-            s = topi.generic.schedule_conv2d_nchw(outs)
+            s = conv2d_schedule(outs)
             tvm.build(s, [data, w, add], target=backend)
 
     for backend in get_all_backend():
index 875b2f7..0037487 100644 (file)
@@ -64,7 +64,7 @@ def verify_upsampling(batch, in_channel, in_height, in_width, scale_h, scale_w,
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
-            s = topi.generic.schedule_injective(B)
+            s = topi.testing.get_injective_schedule(device)(B)
         a = tvm.nd.array(a_np, ctx)
         b = tvm.nd.array(np.zeros(out_shape, dtype=dtype), ctx)
         f = tvm.build(s, [A, B], device)
@@ -147,7 +147,7 @@ def verify_upsampling3d(batch, in_channel, in_depth, in_height, in_width, scale_
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
-            s = topi.generic.schedule_injective(B)
+            s = topi.testing.get_injective_schedule(device)(B)
         a = tvm.nd.array(a_np, ctx)
         b = tvm.nd.array(np.zeros(out_shape, dtype=dtype), ctx)
         f = tvm.build(s, [A, B], device)
index 85e4180..7d27b82 100644 (file)
@@ -26,6 +26,41 @@ from tvm.contrib.pickle_memoize import memoize
 from topi.util import get_const_tuple
 from topi.vision import ssd, non_max_suppression, get_valid_counts
 
+_get_valid_counts_implement = {
+    "generic": (topi.vision.get_valid_counts, topi.generic.schedule_get_valid_counts),
+    "gpu": (topi.cuda.get_valid_counts, topi.cuda.schedule_get_valid_counts),
+}
+
+_nms_implement = {
+    "generic": (topi.vision.non_max_suppression, topi.generic.schedule_nms),
+    "gpu": (topi.cuda.non_max_suppression, topi.cuda.schedule_nms),
+}
+
+_multibox_prior_implement = {
+    "generic": (topi.vision.ssd.multibox_prior, topi.generic.schedule_multibox_prior),
+    "gpu": (topi.cuda.multibox_prior, topi.cuda.schedule_multibox_prior),
+}
+
+_multibox_detection_implement = {
+    "generic": (topi.vision.ssd.multibox_detection, topi.generic.schedule_multibox_detection),
+    "gpu": (topi.cuda.multibox_detection, topi.cuda.schedule_multibox_detection),
+}
+
+_roi_align_implement = {
+    "generic": (topi.vision.roi_align_nchw, topi.generic.schedule_roi_align),
+    "cpu": (topi.x86.roi_align_nchw, topi.generic.schedule_roi_align),
+    "gpu": (topi.vision.roi_align_nchw, topi.cuda.schedule_roi_align),
+}
+
+_roi_pool_schedule = {
+    "generic": topi.generic.schedule_roi_pool,
+    "gpu": topi.cuda.schedule_roi_pool,
+}
+
+_proposal_implement = {
+    "generic": (topi.vision.rcnn.proposal, topi.generic.schedule_proposal),
+    "gpu": (topi.cuda.proposal, topi.cuda.schedule_proposal),
+}
 
 def verify_get_valid_counts(dshape, score_threshold, id_index, score_index):
     dtype = "float32"
@@ -54,9 +89,10 @@ def verify_get_valid_counts(dshape, score_threshold, id_index, score_index):
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
+            fcompute, fschedule = topi.testing.dispatch(device, _get_valid_counts_implement)
             data = tvm.placeholder(dshape, name="data", dtype=dtype)
-            outs = get_valid_counts(data, score_threshold, id_index, score_index)
-            s = topi.generic.schedule_get_valid_counts(outs)
+            outs = fcompute(data, score_threshold, id_index, score_index)
+            s = fschedule(outs)
 
         tvm_input_data = tvm.nd.array(np_data, ctx)
         tvm_out1 = tvm.nd.array(np.zeros(np_out1.shape, dtype="int32"), ctx)
@@ -95,20 +131,14 @@ def verify_non_max_suppression(np_data, np_valid_count, np_result, np_indices_re
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
-            if device == 'llvm':
-                out = non_max_suppression(data, valid_count, -1, iou_threshold, force_suppress, top_k,
-                                          coord_start=coord_start, score_index=score_index, id_index=id_index,
-                                          return_indices=False)
-                indices_out = non_max_suppression(data, valid_count, -1, iou_threshold, force_suppress, top_k,
-                                                  coord_start=coord_start, score_index=score_index, id_index=id_index)
-            else:
-                out = topi.cuda.non_max_suppression(data, valid_count, -1, iou_threshold, force_suppress, top_k,
-                                                    coord_start=coord_start, score_index=score_index, id_index=id_index,
-                                                    return_indices=False)
-                indices_out = topi.cuda.non_max_suppression(data, valid_count, -1, iou_threshold, force_suppress, top_k,
-                                                            coord_start=coord_start, score_index=score_index, id_index=id_index)
-            s = topi.generic.schedule_nms(out)
-            indices_s = topi.generic.schedule_nms(indices_out)
+            fcompute, fschedule = topi.testing.dispatch(device, _nms_implement)
+            out = fcompute(data, valid_count, -1, iou_threshold, force_suppress, top_k,
+                           coord_start=coord_start, score_index=score_index, id_index=id_index,
+                           return_indices=False)
+            indices_out = fcompute(data, valid_count, -1, iou_threshold, force_suppress, top_k,
+                                   coord_start=coord_start, score_index=score_index, id_index=id_index)
+            s = fschedule(out)
+            indices_s = fschedule(indices_out)
 
         tvm_data = tvm.nd.array(np_data, ctx)
         tvm_valid_count = tvm.nd.array(np_valid_count, ctx)
@@ -193,12 +223,11 @@ def verify_multibox_prior(dshape, sizes=(1,), ratios=(1,), steps=(-1, -1), offse
             print("Skip because %s is not enabled" % device)
             return
         print("Running on target: %s" % device)
+        
+        fcompute, fschedule = topi.testing.dispatch(device, _multibox_prior_implement)
         with tvm.target.create(device):
-            if device == 'llvm':
-                out = ssd.multibox_prior(data, sizes, ratios, steps, offsets, clip)
-            else:
-                out = topi.cuda.ssd.multibox_prior(data, sizes, ratios, steps, offsets, clip)
-            s = topi.generic.schedule_multibox_prior(out)
+            out = fcompute(data, sizes, ratios, steps, offsets, clip)
+            s = fschedule(out)
 
         tvm_input_data = tvm.nd.array(input_data, ctx)
         tvm_out = tvm.nd.array(np.zeros(oshape, dtype=dtype), ctx)
@@ -239,12 +268,11 @@ def test_multibox_detection():
             print("Skip because %s is not enabled" % device)
             return
         print("Running on target: %s" % device)
+
+        fcompute, fschedule = topi.testing.dispatch(device, _multibox_detection_implement)
         with tvm.target.create(device):
-            if device == 'llvm':
-                out = ssd.multibox_detection(cls_prob, loc_preds, anchors)
-            else:
-                out = topi.cuda.ssd.multibox_detection(cls_prob, loc_preds, anchors)
-            s = topi.generic.schedule_multibox_detection(out)
+            out = fcompute(cls_prob, loc_preds, anchors)
+            s = fschedule(out)
 
         tvm_cls_prob = tvm.nd.array(np_cls_prob.astype(cls_prob.dtype), ctx)
         tvm_loc_preds = tvm.nd.array(np_loc_preds.astype(loc_preds.dtype), ctx)
@@ -286,10 +314,11 @@ def verify_roi_align(batch, in_channel, in_size, num_roi, pooled_size, spatial_s
         print("Running on target: %s" % device)
 
         with tvm.target.create(device):
-            b = topi.vision.rcnn.roi_align_nchw(a, rois, pooled_size=pooled_size,
-                                                spatial_scale=spatial_scale,
-                                                sample_ratio=sample_ratio)
-            s = topi.generic.schedule_roi_align(b)
+            fcompute, fschedule = topi.testing.dispatch(device, _roi_align_implement)
+            b = fcompute(a, rois, pooled_size=pooled_size,
+                         spatial_scale=spatial_scale,
+                         sample_ratio=sample_ratio)
+            s = fschedule(b)
 
         tvm_a = tvm.nd.array(a_np, ctx)
         tvm_rois = tvm.nd.array(rois_np, ctx)
@@ -338,7 +367,8 @@ def verify_roi_pool(batch, in_channel, in_size, num_roi, pooled_size, spatial_sc
         with tvm.target.create(device):
             b = topi.vision.rcnn.roi_pool_nchw(a, rois, pooled_size=pooled_size,
                                                 spatial_scale=spatial_scale)
-            s = topi.generic.schedule_roi_pool(b)
+            s_func = topi.testing.dispatch(device, _roi_pool_schedule)
+            s = s_func(b)
 
         tvm_a = tvm.nd.array(a_np, ctx)
         tvm_rois = tvm.nd.array(rois_np, ctx)
@@ -368,8 +398,9 @@ def verify_proposal(np_cls_prob, np_bbox_pred, np_im_info, np_out, attrs):
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
-            out = topi.vision.proposal(cls_prob, bbox_pred, im_info, **attrs)
-            s = topi.generic.schedule_proposal(out)
+            fcompute, fschedule = topi.testing.dispatch(device, _proposal_implement)
+            out = fcompute(cls_prob, bbox_pred, im_info, **attrs)
+            s = fschedule(out)
             f = tvm.build(s, [cls_prob, bbox_pred, im_info, out], device)
             tvm_cls_prob = tvm.nd.array(np_cls_prob, ctx=ctx)
             tvm_bbox_pred = tvm.nd.array(np_bbox_pred, ctx=ctx)
@@ -428,4 +459,5 @@ if __name__ == "__main__":
     test_multibox_prior()
     test_multibox_detection()
     test_roi_align()
+    test_roi_pool()
     test_proposal()
index 09b5604..0e26dcb 100644 (file)
@@ -78,7 +78,7 @@ from tvm import autotvm
 # can be very large (at the level of 10^9 for some input shapes)
 #
 
-@autotvm.template
+@autotvm.register_customized_task("tutorial/conv2d_no_batching")
 def conv2d_no_batching(N, H, W, CO, CI, KH, KW, stride, padding):
     assert N == 1, "Only consider batch_size = 1 in this template"
 
@@ -180,7 +180,7 @@ logging.getLogger('autotvm').addHandler(logging.StreamHandler(sys.stdout))
 
 # the last layer in resnet
 N, H, W, CO, CI, KH, KW, strides, padding = 1, 7, 7, 512, 512, 3, 3, (1, 1), (1, 1)
-task = autotvm.task.create(conv2d_no_batching,
+task = autotvm.task.create("tutorial/conv2d_no_batching",
                            args=(N, H, W, CO, CI, KH, KW, strides, padding),
                            target='cuda')
 print(task.config_space)
index 4cbdf52..9aba937 100644 (file)
@@ -248,30 +248,7 @@ def tune_tasks(tasks,
                n_trial=1000,
                early_stopping=None,
                log_filename='tuning.log',
-               use_transfer_learning=True,
-               try_winograd=True,
-               try_spatial_pack_depthwise=False):
-    if try_winograd:
-        for i in range(len(tasks)):
-            try:  # try winograd template
-                tsk = autotvm.task.create(tasks[i].name, tasks[i].args,
-                                          tasks[i].target, tasks[i].target_host, 'winograd')
-                input_channel = tsk.workload[1][1]
-                if input_channel >= 64:
-                    tasks[i] = tsk
-            except Exception:
-                pass
-
-    # if we want to use spatial pack for depthwise convolution
-    if try_spatial_pack_depthwise:
-        tuner = 'xgb_knob'
-        for i in range(len(tasks)):
-            if tasks[i].name == 'topi_nn_depthwise_conv2d_nchw':
-                tsk = autotvm.task.create(tasks[i].name, tasks[i].args,
-                                          tasks[i].target, tasks[i].target_host,
-                                          'contrib_spatial_pack')
-                tasks[i] = tsk
-
+               use_transfer_learning=True):
     # create tmp log file
     tmp_log_file = log_filename + ".tmp"
     if os.path.exists(tmp_log_file):
@@ -322,7 +299,7 @@ def tune_and_evaluate(tuning_opt):
     mod, params, input_shape, _ = get_network(network, batch_size=1)
     tasks = autotvm.task.extract_from_program(mod["main"], target=target,
                                               params=params,
-                                              ops=(relay.op.nn.conv2d,))
+                                              ops=(relay.op.get("nn.conv2d"),))
 
     # run tuning tasks
     print("Tuning...")
index 72fc2be..58c8751 100644 (file)
@@ -163,19 +163,7 @@ def tune_tasks(tasks,
                n_trial=1000,
                early_stopping=None,
                log_filename='tuning.log',
-               use_transfer_learning=True,
-               try_winograd=True):
-    if try_winograd:
-        for i in range(len(tasks)):
-            try:  # try winograd template
-                tsk = autotvm.task.create(tasks[i].name, tasks[i].args,
-                                          tasks[i].target, tasks[i].target_host, 'winograd')
-                input_channel = tsk.workload[1][1]
-                if input_channel >= 64:
-                    tasks[i] = tsk
-            except Exception:
-                pass
-
+               use_transfer_learning=True):
     # create tmp log file
     tmp_log_file = log_filename + ".tmp"
     if os.path.exists(tmp_log_file):
@@ -223,7 +211,8 @@ def tune_and_evaluate(tuning_opt):
     print("Extract tasks...")
     mod, params, input_shape, out_shape = get_network(network, batch_size=1)
     tasks = autotvm.task.extract_from_program(mod["main"], target=target,
-                                              params=params, ops=(relay.op.nn.conv2d,))
+                                              params=params,
+                                              ops=(relay.op.get("nn.conv2d"),))
 
     # run tuning tasks
     print("Tuning...")
index 3c56524..5425f1b 100644 (file)
@@ -247,17 +247,7 @@ def tune_tasks(tasks,
                n_trial=1000,
                early_stopping=None,
                log_filename='tuning.log',
-               use_transfer_learning=True,
-               try_winograd=True):
-    if try_winograd:
-        for i in range(len(tasks)):
-            try:  # try winograd template
-                tsk = autotvm.task.create(tasks[i].name, tasks[i].args,
-                                          tasks[i].target, tasks[i].target_host, 'winograd')
-                tasks.append(tsk)
-            except Exception:
-                pass
-
+               use_transfer_learning=True):
     # create tmp log file
     tmp_log_file = log_filename + ".tmp"
     if os.path.exists(tmp_log_file):
@@ -307,7 +297,8 @@ def tune_and_evaluate(tuning_opt):
     tasks = autotvm.task.extract_from_program(mod["main"],
                                               target=target,
                                               target_host=target_host,
-                                              params=params, ops=(relay.op.nn.conv2d,))
+                                              params=params,
+                                              ops=(relay.op.get("nn.conv2d"),))
 
     # run tuning tasks
     print("Tuning...")
index 5e26f58..87d07f9 100644 (file)
@@ -132,22 +132,9 @@ def tune_kernels(tasks,
                  early_stopping=None,
                  log_filename='tuning.log'):
 
-    for i, tsk in enumerate(tasks):
+    for i, task in enumerate(tasks):
         prefix = "[Task %2d/%2d] " % (i+1, len(tasks))
 
-        # converting conv2d tasks to conv2d_NCHWc tasks
-        op_name = tsk.workload[0]
-        if op_name == 'conv2d':
-            func_create = 'topi_x86_conv2d_NCHWc'
-        elif op_name == 'depthwise_conv2d_nchw':
-            func_create = 'topi_x86_depthwise_conv2d_NCHWc_from_nchw'
-        else:
-            raise ValueError("Tuning {} is not supported on x86".format(op_name))
-
-        task = autotvm.task.create(func_create, args=tsk.args,
-                                   target=target, template_key='direct')
-        task.workload = tsk.workload
-
         # create tuner
         if tuner == 'xgb' or tuner == 'xgb-rank':
             tuner_obj = XGBTuner(task, loss_type='rank')
@@ -173,7 +160,7 @@ def tune_kernels(tasks,
 # Use graph tuner to achieve graph level optimal schedules
 # Set use_DP=False if it takes too long to finish.
 def tune_graph(graph, dshape, records, opt_sch_file, use_DP=True):
-    target_op = [relay.nn.conv2d]
+    target_op = [relay.op.get("nn.conv2d"),]
     Tuner = DPTuner if use_DP else PBQPTuner
     executor = Tuner(graph, {input_name: dshape}, records, target_op, target)
     executor.benchmark_layout_transform(min_exec_num=2000)
@@ -189,10 +176,10 @@ def tune_and_evaluate(tuning_opt):
     print("Extract tasks...")
     mod, params, data_shape, out_shape = get_network(model_name, batch_size)
     tasks = autotvm.task.extract_from_program(mod["main"], target=target,
-                                              params=params, ops=(relay.op.nn.conv2d,))
+                                              params=params,
+                                              ops=(relay.op.get("nn.conv2d"),))
 
     # run tuning tasks
-    print("Tuning...")
     tune_kernels(tasks, **tuning_opt)
     tune_graph(mod["main"], data_shape, log_file, graph_opt_sch_file)
 
index b6ad7e9..8efeed4 100644 (file)
@@ -102,7 +102,7 @@ def matmul_v0(N, L, M, dtype):
 # In autotvm, we can define a tunable parameter, or a "knob" for such kind of value.
 
 # Matmul V1: List candidate values
-@autotvm.template  # 1. use a decorator
+@autotvm.register_customized_task("tutorial/matmul_v1")  # 1. use a decorator
 def matmul_v1(N, L, M, dtype):
     A = tvm.placeholder((N, L), name='A', dtype=dtype)
     B = tvm.placeholder((L, M), name='B', dtype=dtype)
@@ -182,7 +182,7 @@ def matmul_v1(N, L, M, dtype):
 # When the high level API cannot meet your requirement, you can always fall
 # back to use low level API.
 
-@autotvm.template
+@autotvm.register_customized_task("tutorial/matmul")
 def matmul(N, L, M, dtype):
     A = tvm.placeholder((N, L), name='A', dtype=dtype)
     B = tvm.placeholder((L, M), name='B', dtype=dtype)
@@ -272,7 +272,7 @@ def matmul(N, L, M, dtype):
 # In this case, for a 512x512 square matrix multiplication, the space size
 # is 10x10=100
 N, L, M = 512, 512, 512
-task = autotvm.task.create(matmul, args=(N, L, M, 'float32'), target='llvm')
+task = autotvm.task.create("tutorial/matmul", args=(N, L, M, 'float32'), target='llvm')
 print(task.config_space)
 
 ################################################################
index d27e236..494593e 100644 (file)
@@ -78,7 +78,7 @@ def example():
 # the scope of this tutorial.
 
 @relay.op.register_alter_op_layout("nn.conv2d", level=101)
-def alter_conv2d(attrs, inputs, tinfos):
+def alter_conv2d(attrs, inputs, tinfos, out_type):
     data, weight = inputs
     new_attrs = dict(attrs)
     new_attrs['data_layout'] = 'NCHW16c'
@@ -245,10 +245,10 @@ print(mod3)
 f = example()
 mod = tvm.IRModule.from_expr(f)
 seq = relay.transform.Sequential([relay.transform.FoldConstant(),
-                                  relay.transform.PrintIR(),
+                                  relay.transform.PrintIR(False),
                                   relay.transform.EliminateCommonSubexpr(),
                                   relay.transform.FuseOps(),
-                                  relay.transform.PrintIR()])
+                                  relay.transform.PrintIR(False)])
 with relay.build_config(opt_level=3):
     mod = seq(mod)
 
index f7cdae2..a4658eb 100644 (file)
@@ -94,7 +94,7 @@ def matmul_nn(A, B, L, dtype='float16', layout='NN'):
 #
 # We use AutoTVM to search for best configurations in this schedule.
 
-@autotvm.template
+@autotvm.register_customized_task("tutorial/test_gemm")
 def test_gemm(N, L, M, dtype, layout):
     if (layout == "NN"):
       shape_a = (N, L)
@@ -264,7 +264,8 @@ elif dtype == 'int4' or dtype == 'int1':
   assert(major == 7 and minor == 5 and layout == 'TN')
 
 def tune_and_evaluate(M, N, L, dtype, layout):
-  task = autotvm.task.create(test_gemm, args=(N, L, M, dtype, layout), target='cuda')
+  task = autotvm.task.create("tutorial/test_gemm", args=(N, L, M, dtype, layout),
+                             target='cuda')
   print(task.config_space)
 
   logging.getLogger('autotvm').setLevel(logging.DEBUG)
index 390085e..2e04982 100644 (file)
@@ -85,7 +85,7 @@ e = topi.elemwise_sum([c, d])
 f = e / 2.0
 g = topi.sum(f)
 with tvm.target.cuda():
-    sg = topi.generic.schedule_reduce(g)
+    sg = topi.cuda.schedule_reduce(g)
     print(tvm.lower(sg, [a, b], simple_mode=True))
 
 ######################################################################
@@ -113,7 +113,7 @@ tvm.testing.assert_allclose(g_nd.asnumpy(), g_np, rtol=1e-5)
 tarray = tvm.placeholder((512, 512), name="tarray")
 softmax_topi = topi.nn.softmax(tarray)
 with tvm.target.create("cuda"):
-    sst = topi.generic.schedule_softmax(softmax_topi)
+    sst = topi.cuda.schedule_softmax(softmax_topi)
     print(tvm.lower(sst, [tarray], simple_mode=True))
 
 ######################################################################
@@ -133,9 +133,9 @@ data = tvm.placeholder((1, 3, 224, 224))
 kernel = tvm.placeholder((10, 3, 5, 5))
 
 with tvm.target.create("cuda"):
-    conv = topi.nn.conv2d(data, kernel, strides=1, padding=2, dilation=1)
+    conv = topi.cuda.conv2d_nchw(data, kernel, 1, 2, 1)
     out = topi.nn.relu(conv)
-    sconv = topi.generic.nn.schedule_conv2d_nchw([out])
+    sconv = topi.cuda.schedule_conv2d_nchw([out])
     print(tvm.lower(sconv, [data, kernel], simple_mode=True))
 
 ######################################################################
index 36d8e41..0c9b2ea 100644 (file)
@@ -662,8 +662,12 @@ def inject_conv2d_transpose_skip(stmt_in):
                                          0, 0,
                                          0, 0, 0))
                 inner = irb.get()
-                args = op.body.body.args
-                res_tensor = op.body.body.func.output(0)
+                # TODO(@tmoreau89): This is only a temporary fix, please take a look.
+                body = op.body.body
+                while isinstance(body, tvm.stmt.IfThenElse):
+                    body = body.then_case
+                args = body.args
+                res_tensor = body.func.output(0)
                 tpl = (args[0], 1, args[1], 1, args[2], 1, args[3], 1, 0, 1, 0, env.BLOCK_OUT)
                 inner = tvm.tir.AttrStmt(
                     [dout, res_tensor], 'buffer_bind_scope',
index 7fdf27f..6f62aff 100644 (file)
@@ -20,8 +20,8 @@
 from . import bitpack
 from .graphpack import graph_pack
 from . import op
-from . import vta_conv2d
-from . import vta_conv2d_transpose
-from . import vta_group_conv2d
-from . import vta_dense
+from .vta_conv2d import conv2d_packed, schedule_conv2d_packed
+from .vta_conv2d_transpose import conv2d_transpose_packed, schedule_conv2d_transpose_packed
+from .vta_group_conv2d import group_conv2d_packed, schedule_group_conv2d_packed
+from .vta_dense import dense_packed, schedule_dense_packed
 from . import util
index d4748fa..6e9d57b 100644 (file)
@@ -22,9 +22,8 @@ from __future__ import absolute_import as _abs
 import tvm
 from topi import util
 
-from tvm.relay.op.op import register_compute, register_schedule
+from tvm.relay.op.op import register_compute, register_injective_schedule
 from tvm.relay.op.op import register_pattern, OpPattern
-from tvm.relay.op.op import schedule_injective
 
 def bitpack(data, bits, pack_type="int8", name="bitpack"):
     """Packs lowest dimension into format needed by VTA
@@ -86,5 +85,5 @@ def compute_bitpack(attrs, inputs):
     bits = 8 // lanes
     return bitpack(inputs[0], bits, dtype)
 
-register_schedule("bitpack", schedule_injective)
+register_injective_schedule("bitpack")
 register_pattern("bitpack", OpPattern.INJECTIVE)
index bf6409c..04e14b1 100644 (file)
@@ -22,19 +22,22 @@ import tvm
 import topi
 
 from tvm.relay.op import op as reg
-from tvm.relay.op.op import OpPattern
-from tvm.relay.op.nn import _nn
+from tvm.relay.op import strategy as _strategy
+from tvm.relay.op.op import OpPattern, OpStrategy
 
 from .util import is_packed_layout
+from .vta_conv2d import conv2d_packed, schedule_conv2d_packed
+from .vta_conv2d_transpose import conv2d_transpose_packed, schedule_conv2d_transpose_packed
+from .vta_group_conv2d import group_conv2d_packed, schedule_group_conv2d_packed
+from .vta_dense import dense_packed, schedule_dense_packed
 from ..environment import get_env
 
 
 # override to force partition at copy
 reg.register_pattern("copy", OpPattern.INJECTIVE, level=15)
 
-
-@reg.register_compute("clip", level=15)
-def compute_clip(attrs, inputs, output_type, target):
+# add clip vta strategy
+def compute_clip_vta(attrs, inputs, output_type):
     """ Clip operator. """
     x = inputs[0]
     a_min = attrs.a_min
@@ -48,139 +51,79 @@ def compute_clip(attrs, inputs, output_type, target):
             x.shape, lambda *i: tvm.max(x(*i), const_min), name="clipB")
     return [x]
 
-
-@reg.register_compute("nn.conv2d", level=15)
-def compute_conv2d(attrs, inputs, output_type, target):
-    """ Compute definition of conv2d """
-    padding = topi.util.get_const_tuple(attrs.padding)
-    strides = topi.util.get_const_tuple(attrs.strides)
-    dilation = tuple([int(d) for d in attrs.dilation])
+def clip_strategy_vta(attrs, inputs, out_type, target):
+    strategy = OpStrategy()
+    strategy.add_implementation(
+        compute_clip_vta,
+        _strategy.wrap_topi_schedule(topi.generic.schedule_injective),
+        name="clip.vta")
+    return strategy
+
+reg.get("clip").get_attr("FTVMStrategy").register(clip_strategy_vta, "vta")
+
+@_strategy.conv2d_strategy.register("vta")
+def conv2d_strategy_vta(attrs, inputs, out_type, target):
+    """conv2d vta strategy"""
+    strategy = OpStrategy()
+    kernel = inputs[1]
+    dilation = topi.util.get_const_tuple(attrs.dilation)
     groups = attrs.groups
     layout = attrs.data_layout
-    out_dtype = attrs.out_dtype
-
-    if target.device_name == "vta":
-        assert dilation == (1, 1), "support for dilation limited to (1, 1)"
-        if is_packed_layout(layout):
-            if groups == 1:
-                assert groups == 1
-                env = get_env()
-                assert env.LOG_INP_WIDTH == 3, "only support 8bit inp for now"
-                assert env.LOG_WGT_WIDTH == 3, "only support 8bit wgt for now"
-                inputs = list(inputs)
-                assert inputs[1].dtype == "int8"
-                return [topi.nn.conv2d(inputs[0],
-                                       inputs[1],
-                                       strides,
-                                       padding,
-                                       dilation,
-                                       layout,
-                                       out_dtype)]
-            return [topi.nn.group_conv2d_nchw(inputs[0],
-                                              inputs[1],
-                                              strides,
-                                              padding,
-                                              dilation,
-                                              groups,
-                                              out_dtype)]
-        # If it's not packed, run on ARM CPU
-        with tvm.target.arm_cpu(tvm.target.Target.current().model):
-            return _nn.compute_conv2d(attrs, inputs, output_type, target)
-
-    # If VTA is not the target, default to _nn def
-    return _nn.compute_conv2d(attrs, inputs, output_type, target)
-
-
-@reg.register_schedule("nn.conv2d", level=15)
-def schedule_conv2d(attrs, outs, target):
-    """ Schedule definition of conv2d """
-    groups = attrs.groups
-    layout = attrs.data_layout
-
-    if target.device_name == "vta":
-        if is_packed_layout(layout):
-            target = tvm.target.create(target)
-            assert target.device_name == "vta"
-            if groups == 1:
-                return topi.generic.schedule_conv2d_nchw(outs)
-            return topi.generic.schedule_group_conv2d_nchw(outs)
-        # If it's not packed, run on ARM CPU
-        with tvm.target.arm_cpu(tvm.target.Target.current().model):
-            return _nn.schedule_conv2d(attrs, outs, tvm.target.Target.current())
-
-    # If VTA is not the target, default to _nn def
-    return _nn.schedule_conv2d(attrs, outs, target)
-
-
-@reg.register_compute("nn.conv2d_transpose", level=15)
-def compute_conv2d_transpose(attrs, inputs, output_type, target):
-    """ 2D convolution algorithm.
-    """
-    padding = topi.util.get_const_tuple(attrs.padding)
-    strides = topi.util.get_const_tuple(attrs.strides)
-    dilation = tuple([int(d) for d in attrs.dilation])
-    layout = attrs.data_layout
-    out_dtype = attrs.out_dtype
-
-    if target.device_name == "vta":
-        assert dilation == (1, 1), "support for dilation limited to (1, 1)"
-        if is_packed_layout(layout):
-            return [topi.nn.conv2d_transpose_nchw(
-                inputs[0], inputs[1], strides, padding, out_dtype)]
-        # If it's not packed, run on ARM CPU
-        with tvm.target.arm_cpu(tvm.target.Target.current().model):
-            return _nn.compute_conv2d_transpose(attrs, inputs, output_type, target)
-
-    # If VTA is not the target, default to _nn def
-    return _nn.compute_conv2d_transpose(attrs, inputs, output_type, target)
 
-
-@reg.register_schedule("nn.conv2d_transpose", level=15)
-def schedule_conv2d_transpose(attrs, outputs, target):
-    """ 2D convolution schedule.
-    """
+    assert dilation == (1, 1), "support for dilation limited to (1, 1)"
+    if is_packed_layout(layout):
+        if groups == 1:
+            env = get_env()
+            assert env.LOG_INP_WIDTH == 3, "only support 8bit inp for now"
+            assert env.LOG_WGT_WIDTH == 3, "only support 8bit wgt for now"
+            assert kernel.dtype == "int8"
+
+            strategy.add_implementation(
+                _strategy.wrap_compute_conv2d(conv2d_packed, True),
+                _strategy.wrap_topi_schedule(schedule_conv2d_packed),
+                name="conv2d_packed.vta")
+        else: # group_conv2d
+            strategy.add_implementation(
+                _strategy.wrap_compute_conv2d(group_conv2d_packed, has_groups=True),
+                _strategy.wrap_topi_schedule(schedule_group_conv2d_packed),
+                name="group_conv2d_packed.vta")
+        return strategy
+
+    # If it's not packed, run on ARM CPU
+    arm_tgt = tvm.target.arm_cpu(target.model)
+    return _strategy.arm_cpu.conv2d_strategy_arm_cpu(attrs, inputs, out_type, arm_tgt)
+
+
+@_strategy.conv2d_transpose_strategy.register("vta")
+def conv2d_transpose_strategy_vta(attrs, inputs, out_type, target):
+    """conv2d_transpose vta strategy"""
+    dilation = topi.util.get_const_tuple(attrs.dilation)
     layout = attrs.data_layout
-
-    if target.device_name == "vta":
-        if is_packed_layout(layout):
-            return topi.nn.schedule_conv2d_transpose_nchw(outputs)
-        # If it's not packed, run on ARM CPU
-        with tvm.target.arm_cpu(tvm.target.Target.current().model):
-            return _nn.schedule_conv2d_transpose(attrs, outputs, tvm.target.Target.current())
-
-    # If VTA is not the target, default to _nn def
-    return _nn.schedule_conv2d_transpose(attrs, outputs, tvm.target.Target.current())
-
-
-@reg.register_compute("nn.dense", level=15)
-def compute_dense(attrs, inputs, out_type, target):
-    """Compute definition of dense"""
-    out_dtype = attrs.out_dtype
-    out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype
-
-    if target.device_name == "vta":
-        if inputs[0].shape == 4: # this implies the layout is packed
-            target = tvm.target.create(target)
-            return [topi.nn.dense(inputs[0], inputs[1], None, out_dtype)]
-        # If it's not packed, run on ARM CPU
-        with tvm.target.arm_cpu(tvm.target.Target.current().model):
-            return _nn.compute_dense(attrs, inputs, out_type, target)
-
-    # If VTA is not the target, default to _nn def
-    return _nn.compute_dense(attrs, inputs, out_type, target)
-
-
-@reg.register_schedule("nn.dense", level=15)
-def schedule_dense(attrs, outs, target):
-    """Schedule definition of dense"""
-    if target.device_name == "vta":
-        if outs[0].shape == 4: # this implies the layout is packed
-            target = tvm.target.create(target)
-            assert target.device_name == "vta"
-            return topi.generic.schedule_dense(outs)
-        # If it's not packed, run on ARM CPU
-        with tvm.target.arm_cpu(tvm.target.Target.current().model):
-            return _nn.schedule_dense(attrs, outs, tvm.target.Target.current())
-
-    # If VTA is not the target, default to _nn def
-    return _nn.schedule_dense(attrs, outs, target)
+    assert dilation == (1, 1), "support for dilation limited to (1, 1)"
+
+    if is_packed_layout(layout):
+        strategy = OpStrategy()
+        strategy.add_implementation(
+            _strategy.wrap_compute_conv2d_transpose(conv2d_transpose_packed),
+            _strategy.wrap_topi_schedule(schedule_conv2d_transpose_packed),
+            name="conv2d_transpose_packed.vta")
+        return strategy
+
+    # If it's not packed, run on ARM CPU
+    arm_tgt = tvm.target.arm_cpu(target.model)
+    return _strategy.arm_cpu.conv2d_transpose_strategy_arm_cpu(attrs, inputs, out_type, arm_tgt)
+
+
+@_strategy.dense_strategy.register("vta")
+def dense_strategy_vta(attrs, inputs, out_type, target):
+    """dense vta strategy"""
+    if inputs[0].shape == 4: # this implies the layout is packed
+        strategy = OpStrategy()
+        strategy.add_implementation(
+            _strategy.wrap_compute_dense(dense_packed),
+            _strategy.wrap_topi_schedule(schedule_dense_packed),
+            name="dense_packed.vta")
+        return strategy
+    # If it's not packed, run on ARM CPU
+    arm_tgt = tvm.target.arm_cpu(target.model)
+    return _strategy.x86.dense_strategy_cpu(attrs, inputs, out_type, arm_tgt)
index e15f6c1..ba93b05 100644 (file)
@@ -25,15 +25,8 @@ import topi
 from .util import is_packed_layout
 from ..environment import get_env
 
-@autotvm.register_topi_compute(topi.nn.conv2d, 'vta', 'direct')
-def _declaration_conv2d(cfg,
-                        data,
-                        kernel,
-                        strides,
-                        padding,
-                        dilation,
-                        layout,
-                        out_dtype):
+@autotvm.register_topi_compute("conv2d_packed.vta")
+def conv2d_packed(cfg, data, kernel, strides, padding, dilation, layout, out_dtype):
     """ Packed conv2d function."""
     if not is_packed_layout(layout):
         raise topi.InvalidShapeError()
@@ -69,8 +62,9 @@ def _declaration_conv2d(cfg,
 
     return res
 
-@autotvm.register_topi_schedule(topi.generic.schedule_conv2d_nchw, 'vta', 'direct')
-def _schedule_conv2d(cfg, outs):
+@autotvm.register_topi_schedule("conv2d_packed.vta")
+def schedule_conv2d_packed(cfg, outs):
+    """Schedule packed conv2d"""
     assert len(outs) == 1
     output = outs[0]
     const_ops = []
index a2750dc..a3fd7ac 100644 (file)
@@ -26,13 +26,9 @@ from topi.nn.util import get_pad_tuple
 
 from ..environment import get_env
 
-@autotvm.register_topi_compute(topi.nn.conv2d_transpose_nchw, 'vta', 'direct')
-def _declatation_conv2d_transpose(cfg,
-                                  data,
-                                  kernel,
-                                  strides,
-                                  padding,
-                                  out_dtype):
+@autotvm.register_topi_compute("conv2d_transpose_packed.vta")
+def conv2d_transpose_packed(cfg, data, kernel, strides, padding, out_dtype):
+    """Packed conv2d_transpose compute"""
     ishape = get_const_tuple(data.shape)
     kshape = get_const_tuple(kernel.shape)
     b, c_i, i_h, i_w, t_b, t_ci = ishape
@@ -75,8 +71,9 @@ def _declatation_conv2d_transpose(cfg,
 
     return out
 
-@autotvm.register_topi_schedule(topi.generic.schedule_conv2d_transpose_nchw, 'vta', 'direct')
-def _schedule_conv2d_transpose(cfg, outs):
+@autotvm.register_topi_schedule("conv2d_transpose_packed.vta")
+def schedule_conv2d_transpose_packed(cfg, outs):
+    """Schedule packed conv2d_transpose"""
     assert len(outs) == 1
     output = outs[0]
     ewise_inputs = []
index 9d6c19c..e239104 100644 (file)
@@ -32,12 +32,8 @@ def is_packed_layout(layout):
         return True
     return False
 
-@autotvm.register_topi_compute(topi.nn.dense, 'vta', 'direct')
-def _declaration_dense(cfg,
-                       data,
-                       weight,
-                       bias=None,
-                       out_dtype=None):
+@autotvm.register_topi_compute("dense_packed.vta")
+def dense_packed(cfg, data, weight, bias=None, out_dtype=None):
     """Dense function declaration."""
 
     # Make sure that the dense operator is packed
@@ -67,8 +63,8 @@ def _declaration_dense(cfg,
 
     return res
 
-@autotvm.register_topi_schedule(topi.generic.schedule_dense, 'vta', 'direct')
-def _schedule_dense(cfg, outs):
+@autotvm.register_topi_schedule("dense_packed.vta")
+def schedule_dense_packed(cfg, outs):
     """Packed dense schedule."""
 
     assert len(outs) == 1
index e54637f..aa06c61 100644 (file)
@@ -24,8 +24,8 @@ import topi
 
 from ..environment import get_env
 
-@autotvm.register_topi_compute(topi.nn.group_conv2d_nchw, 'vta', 'direct')
-def packed_group_conv2d(cfg,
+@autotvm.register_topi_compute("group_conv2d_packed.vta")
+def group_conv2d_packed(cfg,
                         data,
                         kernel,
                         strides,
@@ -74,8 +74,8 @@ def packed_group_conv2d(cfg,
     return out
 
 
-@autotvm.register_topi_schedule(topi.generic.schedule_group_conv2d_nchw, 'vta', 'direct')
-def schedule_packed_group_conv2d(cfg, outs):
+@autotvm.register_topi_schedule("group_conv2d_packed.vta")
+def schedule_group_conv2d_packed(cfg, outs):
     """ Schedule the packed conv2d.
     """
     assert len(outs) == 1
index b9edc30..cf6f426 100644 (file)
@@ -246,7 +246,7 @@ if __name__ == '__main__':
     print("Extracting tasks...")
     tasks = extract_from_program(func=relay_prog,
                                  params=params,
-                                 ops=(tvm.relay.op.nn.conv2d,),
+                                 ops=(relay.op.get("nn.conv2d"),),
                                  target=target,
                                  target_host=env.target_host)
 
index af71561..6935e47 100644 (file)
 import json
 import os
 
+import pytest
 import numpy as np
 from collections import namedtuple
 
 import tvm
+from tvm import relay
 from tvm import autotvm
 from tvm.contrib import util
 from tvm.contrib.pickle_memoize import memoize
@@ -79,9 +81,13 @@ def run_conv2d(env, remote, wl, target,
     if "arm_cpu" in target.keys:
         data_pack = False
         layout = "NCHW"
+        conv2d_fcompute = topi.arm_cpu.conv2d_nchw_spatial_pack
+        conv2d_fschedule = topi.arm_cpu.schedule_conv2d_nchw_spatial_pack
     elif "vta" in target.keys:
         data_pack = True
         layout = "NCHW%dn%dc" % (env.BATCH, env.BLOCK_IN)
+        conv2d_fcompute = vta.top.conv2d_packed
+        conv2d_fschedule = vta.top.schedule_conv2d_packed
 
     # Derive shapes depending upon packing
     a_shape = (wl.batch, wl.in_filter, wl.height, wl.width)
@@ -101,18 +107,24 @@ def run_conv2d(env, remote, wl, target,
     data = tvm.placeholder(data_shape, name="data", dtype=env.inp_dtype)
     kernel = tvm.placeholder(kernel_shape, name="kernel", dtype=env.wgt_dtype)
     bias = tvm.placeholder(bias_shape, name="bias", dtype=env.acc_dtype)
+    padding = relay.nn.get_pad_tuple2d((wl.hpad, wl.wpad))
 
     # Define base computation schedule
     with target:
-        res = topi.nn.conv2d(
-            data, kernel, (wl.hstride, wl.wstride), (wl.hpad, wl.wpad), (1, 1),
-            layout, env.acc_dtype)
+        if data_pack:
+            res = conv2d_fcompute(
+                data, kernel, (wl.hstride, wl.wstride), padding, (1, 1),
+                layout, env.acc_dtype)
+        else:
+            res = conv2d_fcompute(
+                data, kernel, (wl.hstride, wl.wstride), padding, (1, 1),
+                env.acc_dtype)
         res = topi.right_shift(res, 8)
         res = topi.add(res, bias)
         res = my_clip(res, 0, (1 << env.OUT_WIDTH - 1) - 1)
         res = topi.cast(res, env.out_dtype)
         # Derive base schedule
-        s = topi.generic.schedule_conv2d_nchw([res])
+        s = conv2d_fschedule([res])
         if print_ir:
             print(vta.lower(s, [data, kernel, bias, res], simple_mode=True))
 
@@ -222,7 +234,8 @@ def run_conv2d(env, remote, wl, target,
 
     return correct, cost, stats
 
-def test_conv2d(device="vta"):
+@pytest.mark.parametrize("device", ["vta", "arm_cpu"])
+def test_conv2d(device):
     def _run(env, remote):
         if device == "vta":
             target = env.target
index d729fa5..2d96a73 100644 (file)
 import json
 import os
 
+import pytest
 import numpy as np
 from collections import namedtuple
 
 import tvm
+from tvm import relay
 from tvm import autotvm
 from tvm.contrib import util
 from tvm.contrib.pickle_memoize import memoize
@@ -80,14 +82,18 @@ def run_conv2d_transpose(env, remote, wl, target,
     if "arm_cpu" in target.keys:
         data_pack = False
         layout = "NCHW"
+        fcompute = topi.arm_cpu.conv2d_transpose_nchw
+        fschedule = topi.arm_cpu.schedule_conv2d_transpose_nchw
     elif "vta" in target.keys:
         data_pack = True
         layout = "NCHW%dn%dc" % (env.BATCH, env.BLOCK_IN)
+        fcompute = vta.top.conv2d_transpose_packed
+        fschedule = vta.top.schedule_conv2d_transpose_packed
 
     # Derive shapes depending upon packing
 
     a_shape = (wl.batch, wl.in_filter, wl.height, wl.width)
-    w_shape = (wl.out_filter, wl.in_filter, wl.hkernel, wl.wkernel)
+    w_shape = (wl.in_filter, wl.out_filter, wl.hkernel, wl.wkernel)
     if data_pack:
         data_shape = (wl.batch//env.BATCH, wl.in_filter//env.BLOCK_IN,
                       wl.height, wl.width, env.BATCH, env.BLOCK_IN)
@@ -98,16 +104,17 @@ def run_conv2d_transpose(env, remote, wl, target,
         kernel_shape = w_shape
     data = tvm.placeholder(data_shape, name="data", dtype=env.inp_dtype)
     kernel = tvm.placeholder(kernel_shape, name="kernel", dtype=env.wgt_dtype)
+    padding = relay.nn.get_pad_tuple2d((wl.hpad, wl.wpad))
 
     # Define base computation schedule
     with target:
-        res = topi.nn.conv2d_transpose_nchw(
-            data, kernel, (wl.hstride, wl.wstride), (wl.hpad, wl.wpad), env.acc_dtype)
+        res = fcompute(
+            data, kernel, (wl.hstride, wl.wstride), padding, env.acc_dtype)
         res = topi.right_shift(res, env.WGT_WIDTH)
         res = my_clip(res, 0, (1 << env.OUT_WIDTH - 1) - 1)
         res = topi.cast(res, env.out_dtype)
         # Derive base schedule
-        s = topi.generic.schedule_conv2d_transpose_nchw([res])
+        s = fschedule([res])
         if print_ir:
             print(vta.lower(s, [data, kernel, res], simple_mode=True))
 
@@ -210,7 +217,8 @@ def run_conv2d_transpose(env, remote, wl, target,
 
     return correct, cost, stats
 
-def test_conv2d_transpose(device="vta"):
+@pytest.mark.parametrize("device", ["vta", "arm_cpu"])
+def test_conv2d_transpose(device):
     def _run(env, remote):
         if device == "vta":
             target = env.target
@@ -227,5 +235,5 @@ def test_conv2d_transpose(device="vta"):
     vta.testing.run(_run)
 
 if __name__ == "__main__":
-    test_conv2d_transpose(device="arm_cpu")
+    test_conv2d_transpose(device="arm_cpu")
     test_conv2d_transpose(device="vta")
index b0ee2f5..a0acdc3 100644 (file)
@@ -63,21 +63,25 @@ def run_gemm(env, remote, target,
                       env.BATCH, env.BLOCK_IN)
         kernel_shape = (out_feat//env.BLOCK_OUT, in_feat//env.BLOCK_IN,
                         env.BLOCK_OUT, env.BLOCK_IN)
+        fcompute = vta.top.dense_packed
+        fschedule = vta.top.schedule_dense_packed
     else:
         data_shape = a_shape
         kernel_shape = w_shape
+        fcompute = topi.x86.dense_nopack
+        fschedule = topi.x86.schedule_dense_nopack
     data = tvm.placeholder(data_shape, name="data", dtype=env.inp_dtype)
     kernel = tvm.placeholder(kernel_shape, name="kernel", dtype=env.wgt_dtype)
 
     # Define base computation schedule
     with target:
-        res = topi.nn.dense(
-            data, kernel, out_dtype=env.acc_dtype)
+        res = fcompute(
+            data, kernel, None, env.acc_dtype)
         res = topi.right_shift(res, 8)
         res = my_clip(res, 0, (1 << env.OUT_WIDTH - 1) - 1)
         res = topi.cast(res, env.out_dtype)
         # Derive base schedule
-        s = topi.generic.schedule_dense([res])
+        s = fschedule([res])
         if print_ir:
             print(vta.lower(s, [data, kernel, res], simple_mode=True))
 
index 7bba244..31fef49 100644 (file)
 import json
 import os
 
+import pytest
 import numpy as np
 from collections import namedtuple
 
 import tvm
+from tvm import relay
 from tvm import autotvm
 from tvm.contrib import util
 import topi
@@ -75,9 +77,13 @@ def run_group_conv2d(env, remote, wl, target,
     if "arm_cpu" in target.keys:
         data_pack = False
         layout = "NCHW"
+        fcompute = topi.nn.group_conv2d_nchw
+        fschedule = topi.generic.schedule_group_conv2d_nchw
     elif "vta" in target.keys:
         data_pack = True
         layout = "NCHW%dn%dc" % (env.BATCH, env.BLOCK_IN)
+        fcompute = vta.top.group_conv2d_packed
+        fschedule = vta.top.schedule_group_conv2d_packed
 
     # Derive shapes depending upon packing
     CI_G = wl.in_filter // wl.groups
@@ -98,17 +104,19 @@ def run_group_conv2d(env, remote, wl, target,
     data = tvm.placeholder(data_shape, name="data", dtype=env.inp_dtype)
     kernel = tvm.placeholder(kernel_shape, name="kernel", dtype=env.wgt_dtype)
     bias = tvm.placeholder(bias_shape, name="bias", dtype=env.acc_dtype)
+    padding = relay.nn.get_pad_tuple2d((wl.hpad, wl.wpad))
+
     # Define base computation schedule
     with target:
-        res = topi.nn.group_conv2d_nchw(
-            data, kernel, (wl.hstride, wl.wstride), (wl.hpad, wl.wpad), (1, 1),
+        res = fcompute(
+            data, kernel, (wl.hstride, wl.wstride), padding, (1, 1),
             wl.groups, env.acc_dtype)
         res = topi.right_shift(res, 8)
         res = topi.add(res, bias)
         res = my_clip(res, 0, (1 << env.OUT_WIDTH - 1) - 1)
         res = topi.cast(res, env.out_dtype)
         # Derive base schedule
-        s = topi.generic.schedule_group_conv2d_nchw([res])
+        s = fschedule([res])
         if print_ir:
             print(vta.lower(s, [data, kernel, bias, res], simple_mode=True))
 
@@ -219,7 +227,8 @@ def run_group_conv2d(env, remote, wl, target,
 
     return correct, cost, stats
 
-def test_conv2d(device="vta"):
+@pytest.mark.parametrize("device", ["vta", "arm_cpu"])
+def test_conv2d(device):
     def _run(env, remote):
         if device == "vta":
             target = env.target
index 94fba3d..a20b8ec 100644 (file)
@@ -295,7 +295,7 @@ def tune_tasks(tasks,
 
 
 def register_vta_tuning_tasks():
-    from tvm.autotvm.task.topi_integration import TaskExtractEnv, deserialize_args
+    from tvm.autotvm.task import TaskExtractEnv
 
     @tvm.tag_scope(tag=topi.tag.ELEMWISE)
     def my_clip(x, a_min, a_max):
@@ -309,20 +309,19 @@ def register_vta_tuning_tasks():
     # init autotvm env to register VTA operator
     TaskExtractEnv()
 
-    @autotvm.task.register("topi_nn_conv2d", override=True)
+    @autotvm.register_customized_task("conv2d_packed.vta")
     def _topi_nn_conv2d(*args, **kwargs):
         assert not kwargs, "Do not support kwargs in template function call"
-        args = deserialize_args(args)
         A, W = args[:2]
 
         with tvm.target.vta():
-            res = topi.nn.conv2d(*args, **kwargs)
+            res = vta.top.conv2d_packed(*args, **kwargs)
             res = topi.right_shift(res, 8)
             res = my_clip(res, 0, 127)
             res = topi.cast(res, "int8")
 
         if tvm.target.Target.current().device_name == 'vta':
-            s = topi.generic.schedule_conv2d_nchw([res])
+            s = vta.top.schedule_conv2d_packed([res])
         else:
             s = tvm.create_schedule([res.op])
         return s, [A, W, res]
@@ -356,10 +355,13 @@ def tune_and_evaluate(tuning_opt):
     mod = tvm.IRModule.from_expr(relay_prog)
     tasks = autotvm.task.extract_from_program(mod,
                                               params=params,
-                                              ops=(tvm.relay.op.nn.conv2d, ),
+                                              ops=(relay.op.get("nn.conv2d"),),
                                               target=target,
                                               target_host=env.target_host)
 
+    # filter out non-packed conv2d task
+    tasks = list(filter(lambda t: len(t.args[0][1]) > 4, tasks))
+
     # We should have extracted 10 convolution tasks
     assert len(tasks) == 10
     print("Extracted {} conv2d tasks:".format(len(tasks)))